Coverage for src/hodoku/solver/chains.py: 98%
302 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 08:35 +0000
1"""Chain solver: X-Chain, XY-Chain, Remote Pair, Turbot Fish.
3Nice Loop / AIC / Grouped Nice Loop are handled by tabling.py (TablingSolver),
4not here. This module mirrors only the portions of Java's ChainSolver that are
5actually wired up in step_finder.py.
6"""
8from __future__ import annotations
10from typing import TYPE_CHECKING
12from hodoku.core.grid import (
13 ALL_UNITS, BUDDIES, CELL_CONSTRAINTS,
14 CONSTRAINTS, Grid,
15)
16from hodoku.core.solution_step import SolutionStep
17from hodoku.core.types import SolutionType
19if TYPE_CHECKING:
20 from hodoku.config import StepSearchConfig
# Default upper bound on chain length, counted in chain nodes (list entries).
_MAX_CHAIN: int = 20
25# ---------------------------------------------------------------------------
26# Link building
27# ---------------------------------------------------------------------------
def _build_x_links(grid: Grid, digit: int) -> list[list[tuple[int, bool]]]:
    """Build the per-cell adjacency list of single-digit links.

    The result is indexed by cell (0..80). ``links[c]`` holds
    ``(neighbor, is_strong)`` tuples for every other cell that shares a house
    with ``c`` and still has ``digit`` as a candidate. Cells without ``digit``
    get an empty list. A link is strong when its house has exactly two
    candidates for ``digit`` (``grid.free[house][digit] == 2``).

    Inside the block house, neighbours that also share the source cell's row
    or column are skipped, because the row/column pass already recorded them
    (mirrors Java's constr==2 deduplication in getAllLinks).

    NOTE(review): because of that skip, a pair in the same row/col AND block
    takes its strength from the row/col house only. A pair that is conjugate
    only inside the block is therefore recorded as weak. Confirm this matches
    HoDoKu before changing it.
    """
    digit_cells = grid.candidate_sets[digit]
    adjacency: list[list[tuple[int, bool]]] = [[] for _ in range(81)]

    for src in range(81):
        if not (digit_cells >> src) & 1:
            continue
        src_row, src_col, _ = CONSTRAINTS[src]
        row_house, col_house, box_house = CELL_CONSTRAINTS[src]
        out = adjacency[src]

        for house in (row_house, col_house, box_house):
            strong = grid.free[house][digit] == 2
            in_box = house >= 18  # house indices >= 18 are blocks in this module
            for other in ALL_UNITS[house]:
                if other == src or not (digit_cells >> other) & 1:
                    continue
                if in_box:
                    other_row, other_col, _ = CONSTRAINTS[other]
                    if other_row == src_row or other_col == src_col:
                        continue  # the row/col pass already linked this pair
                out.append((other, strong))

    return adjacency
63# ---------------------------------------------------------------------------
64# Sort key matching HoDoKu's SolutionStep.compareTo
65# ---------------------------------------------------------------------------
def _elim_sort_key(step: SolutionStep) -> int:
    """Return HoDoKu's weighted index sum over a step's eliminations.

    HoDoKu uses this sum to break ties between steps that delete the same
    number of candidates. Java's getCandidateString() sorts
    candidatesToDelete by (value, index) with Collections.sort before
    compareTo/getIndexSumme is ever called. The sum is therefore taken over
    that sorted order, not insertion order.

    The k-th candidate (0-based, in sorted order) contributes
    ``index * (1 + 80 * k) + value``. That is the Java loop
    ``sum += cand.index * offset + cand.value`` with offset starting at 1
    and growing by 80 each step.
    """
    ordered = sorted(step.candidates_to_delete, key=lambda cand: (cand.value, cand.index))
    return sum(cand.index * (1 + 80 * k) + cand.value for k, cand in enumerate(ordered))
def _step_sort_key(step: SolutionStep) -> tuple[int, int]:
    """Ascending sort key: more eliminations first, then lower weighted index sum."""
    elim_count = len(step.candidates_to_delete)
    return -elim_count, _elim_sort_key(step)
88# ---------------------------------------------------------------------------
89# Solver
90# ---------------------------------------------------------------------------
class ChainSolver:
    """Finder for Turbot Fish, X-Chain, XY-Chain and Remote Pair steps.

    Every search enumerates chains depth-first. For each distinct
    elimination set it keeps only the shortest chain that produced it. The
    surviving steps are ranked with ``_step_sort_key``: most eliminations
    first, then lowest weighted index sum.
    """

    def __init__(self, grid: Grid, search_config: StepSearchConfig | None = None) -> None:
        self.grid = grid
        # Chain-length cap (in chain nodes) for X-Chain / XY-Chain / Remote
        # Pair. Turbot Fish always uses its own fixed 4-node cap.
        if search_config is not None:
            self._max_chain = search_config.chain_max_length
        else:
            self._max_chain = _MAX_CHAIN

    def get_step(self, sol_type: SolutionType) -> SolutionStep | None:
        """Return the best-ranked step of ``sol_type``, or ``None``.

        Also returns ``None`` when ``sol_type`` is not one of the four types
        handled here.
        """
        if sol_type == SolutionType.TURBOT_FISH:
            return self._find_turbot_fish()
        if sol_type == SolutionType.X_CHAIN:
            return self._find_x_chain()
        if sol_type == SolutionType.XY_CHAIN:
            return self._find_xy_chain()
        if sol_type == SolutionType.REMOTE_PAIR:
            return self._find_remote_pair()
        return None

    def find_all(self, sol_type: SolutionType) -> list[SolutionStep]:
        """Return every distinct step of ``sol_type``, ranked best-first.

        Returns an empty list for types not handled by this solver.
        """
        if sol_type == SolutionType.TURBOT_FISH:
            return self._find_x_chain_impl_all(SolutionType.TURBOT_FISH, max_nodes=4)
        if sol_type == SolutionType.X_CHAIN:
            return self._find_x_chain_impl_all(SolutionType.X_CHAIN, max_nodes=self._max_chain)
        if sol_type == SolutionType.XY_CHAIN:
            return self._find_xy_type_all(SolutionType.XY_CHAIN)
        if sol_type == SolutionType.REMOTE_PAIR:
            return self._find_xy_type_all(SolutionType.REMOTE_PAIR)
        return []

    # ------------------------------------------------------------------
    # Turbot Fish (X-Chain restricted to 3 links / 4 nodes)
    # ------------------------------------------------------------------

    def _find_turbot_fish(self) -> SolutionStep | None:
        """Find the best Turbot Fish (X-Chain with at most 3 links)."""
        return self._find_x_chain_impl(SolutionType.TURBOT_FISH, max_nodes=4)

    # ------------------------------------------------------------------
    # X-Chain
    # ------------------------------------------------------------------

    def _find_x_chain(self) -> SolutionStep | None:
        """Find the best X-Chain elimination.

        Collects all valid chains and deduplicates them by elimination set,
        keeping the shortest per set. Returns the one ranked first by
        HoDoKu's comparator (most eliminations, then lowest weighted index
        sum).
        """
        return self._find_x_chain_impl(SolutionType.X_CHAIN, max_nodes=self._max_chain)

    def _find_x_chain_impl(self, sol_type: SolutionType, max_nodes: int) -> SolutionStep | None:
        """Return the best-ranked single-digit chain step, or ``None`` if none exists."""
        # Delegates to the exhaustive search so the two cannot drift apart.
        steps = self._find_x_chain_impl_all(sol_type, max_nodes)
        return steps[0] if steps else None

    def _find_x_chain_impl_all(self, sol_type: SolutionType, max_nodes: int) -> list[SolutionStep]:
        """Return all single-digit chain steps, ranked best-first.

        Tries every digit and every start cell that holds it, provided some
        buddy of the start also holds it (otherwise nothing can be
        eliminated). A chain is opened along each strong link out of the
        start cell and extended by ``_dfs_x``.

        Args:
            sol_type: type stamped on produced steps (X_CHAIN or TURBOT_FISH).
            max_nodes: maximum number of cells in a chain.
        """
        grid = self.grid
        # elimination key -> (chain length, step): shortest chain per key wins
        deletes_map: dict[tuple, tuple[int, SolutionStep]] = {}

        for digit in range(1, 10):
            cand_set = grid.candidate_sets[digit]
            if not cand_set:
                continue
            links = _build_x_links(grid, digit)

            tmp = cand_set
            while tmp:
                lsb = tmp & -tmp
                start = lsb.bit_length() - 1
                tmp ^= lsb

                # Any elimination must see the start cell.
                start_buddies = cand_set & BUDDIES[start]
                if not start_buddies:
                    continue

                for nb, is_strong in links[start]:
                    if not is_strong:
                        continue  # an X-Chain must start with a strong link
                    chain: list[int] = [start, nb]
                    chain_set: set[int] = {nb}  # tracks chain[1:] for lasso check
                    self._dfs_x(
                        digit, links, chain, chain_set,
                        strong_only=False,  # just placed a strong link; next: weak
                        start=start,
                        start_buddies=start_buddies,
                        deletes_map=deletes_map,
                        sol_type=sol_type,
                        max_nodes=max_nodes,
                    )

        steps = [step for _, step in deletes_map.values()]
        steps.sort(key=_step_sort_key)
        return steps

    def _dfs_x(
        self,
        digit: int,
        links: list[list[tuple[int, bool]]],
        chain: list[int],
        chain_set: set[int],
        strong_only: bool,
        start: int,
        start_buddies: int,
        deletes_map: dict,
        sol_type: SolutionType = SolutionType.X_CHAIN,
        max_nodes: int = _MAX_CHAIN,
    ) -> None:
        """Extend a single-digit chain by one link and recurse.

        Links alternate. When ``strong_only`` is True the next link must be
        strong; otherwise any link may be taken, but it is treated as weak.
        Whenever a strong link sits in a strong position and the chain has
        more than two cells, eliminations are recorded. These are the cells
        holding ``digit`` that see both the start and the new end cell.

        ``chain`` and ``chain_set`` are mutated in place and restored before
        returning.

        Args:
            digit: the digit the chain runs on.
            links: adjacency from ``_build_x_links`` for ``digit``.
            chain: cells of the current chain; ``chain[0]`` is the start.
            chain_set: the cells in ``chain[1:]``. The start is excluded so a
                return to it counts as a loop, not a lasso.
            strong_only: whether the next link must be strong.
            start: the start cell.
            start_buddies: cells holding ``digit`` that see ``start``.
            deletes_map: elimination key -> (chain length, step) accumulator.
            sol_type: type stamped on recorded steps.
            max_nodes: chains are never extended beyond this many cells.
        """
        if len(chain) >= max_nodes:
            return

        current = chain[-1]

        for nb, link_is_strong in links[current]:
            if strong_only and not link_is_strong:
                continue  # must use strong link here

            is_loop = False
            if nb == start:
                # Loop back to start: still valid as an X-Chain ending
                # (Java: isLoop = true, falls through to chain check).
                is_loop = True
            elif nb in chain_set:
                continue  # lasso: revisiting a middle cell

            # When not strong_only, a strong link is treated as weak (no chain check).
            effective_strong = link_is_strong and strong_only

            chain.append(nb)
            if not is_loop:
                chain_set.add(nb)

            # Valid chain end: ended on a strong link, at least 3 nodes total.
            if effective_strong and len(chain) > 2:
                elim = start_buddies & BUDDIES[nb]
                if elim:
                    self._record(digit, sol_type, chain, elim, deletes_map)

            # Don't recurse further if this was a loop (Java: isLoop -> stop).
            if not is_loop:
                self._dfs_x(
                    digit, links, chain, chain_set,
                    strong_only=not strong_only,
                    start=start,
                    start_buddies=start_buddies,
                    deletes_map=deletes_map,
                    sol_type=sol_type,
                    max_nodes=max_nodes,
                )

            chain.pop()
            if not is_loop:
                chain_set.discard(nb)

    # ------------------------------------------------------------------
    # XY-Chain
    # ------------------------------------------------------------------

    def _find_xy_chain(self) -> SolutionStep | None:
        """Find the best XY-Chain step, or ``None``."""
        return self._find_xy_type(SolutionType.XY_CHAIN)

    # ------------------------------------------------------------------
    # Remote Pair
    # ------------------------------------------------------------------

    def _find_remote_pair(self) -> SolutionStep | None:
        """Find the best Remote Pair step, or ``None``."""
        return self._find_xy_type(SolutionType.REMOTE_PAIR)

    # ------------------------------------------------------------------
    # Shared XY/RP search
    # ------------------------------------------------------------------

    def _find_xy_type(self, sol_type: SolutionType) -> SolutionStep | None:
        """Return the best-ranked XY-Chain or Remote Pair step, or ``None``."""
        # Delegates to the exhaustive search so the two cannot drift apart.
        steps = self._find_xy_type_all(sol_type)
        return steps[0] if steps else None

    def _find_xy_type_all(self, sol_type: SolutionType) -> list[SolutionStep]:
        """Return all XY-Chain or Remote Pair steps, ranked best-first.

        XY-Chain: every cell is bivalue. The strong link is within a cell;
        the weak link goes to another cell holding the current candidate.
        Remote Pair: like an XY-Chain, but every cell must carry the same
        candidate pair as the start cell. A step needs at least 4 cells
        (8 chain entries).

        Each bivalue start cell is tried with each of its two candidates as
        the start candidate, provided some buddy of the start holds it.
        """
        grid = self.grid
        is_rp = sol_type == SolutionType.REMOTE_PAIR
        # elimination key -> (chain length, step): shortest chain per key wins
        deletes_map: dict[tuple, tuple[int, SolutionStep]] = {}

        for start in range(81):
            if grid.values[start] != 0:
                continue
            # Both XY-Chain and RP require a bivalue start cell.
            cands = [d for d in range(1, 10) if grid.candidate_sets[d] >> start & 1]
            if len(cands) != 2:
                continue
            start_mask = grid.candidates[start]  # candidate mask for RP matching
            d1, d2 = cands

            for start_cand, other_cand in ((d1, d2), (d2, d1)):
                start_buddies = grid.candidate_sets[start_cand] & BUDDIES[start]
                if not start_buddies:
                    continue
                start_buddies2 = (
                    grid.candidate_sets[other_cand] & BUDDIES[start] if is_rp else 0
                )
                # chain stores (cell, candidate) pairs: entry then exit per cell
                chain: list[tuple[int, int]] = [(start, start_cand), (start, other_cand)]
                visited: set[int] = {start}
                self._dfs_xy(
                    grid, chain, visited,
                    strong_only=False,  # just used within-cell strong; next: inter-cell
                    start=start,
                    start_cand=start_cand,
                    start_buddies=start_buddies,
                    start_buddies2=start_buddies2,
                    start_mask=start_mask,
                    sol_type=sol_type,
                    deletes_map=deletes_map,
                )

        steps = [step for _, step in deletes_map.values()]
        steps.sort(key=_step_sort_key)
        return steps

    def _dfs_xy(
        self,
        grid: Grid,
        chain: list[tuple[int, int]],
        visited: set[int],
        strong_only: bool,
        start: int,
        start_cand: int,
        start_buddies: int,
        start_buddies2: int,
        start_mask: int,
        sol_type: SolutionType,
        deletes_map: dict,
    ) -> None:
        """Extend an XY-Chain / Remote Pair chain by one link and recurse.

        When ``strong_only`` is True, the step is the within-cell strong link
        to the other candidate of the current cell. Otherwise it is a weak
        link to another unvisited bivalue cell that holds the current
        candidate and, for Remote Pair, has the same candidate mask as the
        start cell. After a strong step whose candidate equals
        ``start_cand``, the current chain is checked for eliminations.

        ``chain`` and ``visited`` are mutated in place and restored before
        returning. Chains never grow beyond ``self._max_chain`` entries.
        """
        if len(chain) >= self._max_chain:
            return

        current_cell, current_cand = chain[-1]
        is_rp = sol_type == SolutionType.REMOTE_PAIR

        if strong_only:
            # Within-cell strong link: move to the other candidate in current_cell.
            # (bivalue cell guaranteed; current_cell is always bivalue for XY/RP)
            other_cand = next(
                d for d in range(1, 10)
                if d != current_cand and grid.candidate_sets[d] >> current_cell & 1
            )

            chain.append((current_cell, other_cand))

            # Check: valid chain end when other_cand == start_cand and len > 2
            # (mirrors: stackLevel > 1 && newLinkIsStrong && newLinkCandidate == startCandidate)
            if other_cand == start_cand and len(chain) > 2:
                if is_rp and len(chain) >= 8:
                    # Java only enters checkRemotePairs when startCandidate
                    # eliminations exist (m1/m2 != 0). Mirror that guard so
                    # the DFS exploration order, and thus which startCandidate
                    # wins the dedup race, matches HoDoKu's.
                    elim_rp = start_buddies & BUDDIES[current_cell]
                    if elim_rp:
                        self._check_rp(chain, start_cand, start_buddies, start_buddies2, deletes_map)
                elif not is_rp:
                    elim = start_buddies & BUDDIES[current_cell]
                    if elim:
                        self._record_xy(start_cand, chain, elim, deletes_map)

            self._dfs_xy(
                grid, chain, visited, strong_only=False,
                start=start, start_cand=start_cand,
                start_buddies=start_buddies, start_buddies2=start_buddies2,
                start_mask=start_mask, sol_type=sol_type, deletes_map=deletes_map,
            )
            chain.pop()

        else:
            # Inter-cell weak link: move to another bivalue cell with current_cand.
            r_c, c_c, _ = CONSTRAINTS[current_cell]
            row_u, col_u, blk_u = CELL_CONSTRAINTS[current_cell]
            cand_set = grid.candidate_sets[current_cand]

            for unit_idx in (row_u, col_u, blk_u):
                is_block = unit_idx >= 18
                for nb in ALL_UNITS[unit_idx]:
                    if nb == current_cell:
                        continue
                    if not (cand_set >> nb & 1):
                        continue
                    # Must be bivalue
                    if grid.candidates[nb].bit_count() != 2:
                        continue
                    # Remote Pair: must have same two candidates
                    if is_rp and grid.candidates[nb] != start_mask:
                        continue
                    # Block dedup: skip if also connected via row/col
                    if is_block:
                        r_nb, c_nb, _ = CONSTRAINTS[nb]
                        if r_nb == r_c or c_nb == c_c:
                            continue
                    if nb == start:
                        continue  # nice loop
                    if nb in visited:
                        continue  # lasso

                    chain.append((nb, current_cand))
                    visited.add(nb)

                    self._dfs_xy(
                        grid, chain, visited, strong_only=True,
                        start=start, start_cand=start_cand,
                        start_buddies=start_buddies, start_buddies2=start_buddies2,
                        start_mask=start_mask, sol_type=sol_type, deletes_map=deletes_map,
                    )

                    chain.pop()
                    visited.discard(nb)

    def _check_rp(
        self,
        chain: list[tuple[int, int]],
        start_cand: int,
        start_buddies: int,
        start_buddies2: int,
        deletes_map: dict,
    ) -> None:
        """Record a Remote Pair step that eliminates both pair candidates from shared buddies.

        ``start_buddies`` and ``start_buddies2`` are not used here; they are
        kept for signature compatibility.
        """
        # Positions in chain: 0,1 = cell0; 2,3 = cell1; 4,5 = cell2; ...
        # Even indices (0,2,4,...) are the "entry" candidates of each cell.
        # Pairs with opposite polarity: i and j where (j-i) ≡ 2 (mod 4) and j-i >= 6.
        # For a 4-cell chain (len=8) that is only start cell vs end cell.
        step = SolutionStep(SolutionType.REMOTE_PAIR)
        other_cand = chain[1][1]  # the other start candidate

        step.add_value(start_cand)
        step.add_value(other_cand)

        n = len(chain)
        elim1_mask = 0
        elim2_mask = 0
        cand_set1 = self.grid.candidate_sets[start_cand]
        cand_set2 = self.grid.candidate_sets[other_cand]
        for i in range(0, n, 2):
            cell_i = chain[i][0]
            for j in range(i + 6, n, 4):
                cell_j = chain[j][0]
                shared = BUDDIES[cell_i] & BUDDIES[cell_j]
                elim1_mask |= shared & cand_set1
                elim2_mask |= shared & cand_set2

        if not elim1_mask and not elim2_mask:
            return

        self._add_deletions(step, elim1_mask, start_cand)
        self._add_deletions(step, elim2_mask, other_cand)
        self._store_shortest(step, len(chain), deletes_map)

    def _record_xy(
        self,
        digit: int,
        chain: list[tuple[int, int]],
        elim_mask: int,
        deletes_map: dict,
    ) -> None:
        """Build an XY-Chain step and store it if it is the shortest for its elimination set."""
        self._record(digit, SolutionType.XY_CHAIN, chain, elim_mask, deletes_map)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _record(
        self,
        digit: int,
        sol_type: SolutionType,
        chain: list,
        elim_mask: int,
        deletes_map: dict,
    ) -> None:
        """Build a ``sol_type`` step deleting ``digit`` from every cell in ``elim_mask``.

        Only ``len(chain)`` is used from ``chain``. The step is stored only if
        it is the shortest chain seen for its elimination set.
        """
        step = SolutionStep(sol_type)
        step.add_value(digit)
        self._add_deletions(step, elim_mask, digit)
        self._store_shortest(step, len(chain), deletes_map)

    @staticmethod
    def _add_deletions(step: SolutionStep, mask: int, digit: int) -> None:
        """Add ``digit`` as a candidate to delete for each cell bit set in ``mask``, lowest first."""
        while mask:
            lsb = mask & -mask
            step.add_candidate_to_delete(lsb.bit_length() - 1, digit)
            mask ^= lsb

    @staticmethod
    def _store_shortest(step: SolutionStep, chain_len: int, deletes_map: dict) -> None:
        """Store ``step`` under its elimination set unless an equal or shorter chain already did."""
        key = tuple(sorted((c.index, c.value) for c in step.candidates_to_delete))
        old = deletes_map.get(key)
        if old is None or old[0] > chain_len:
            deletes_map[key] = (chain_len, step)