Coverage for src / hodoku / solver / als.py: 89%
611 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 08:35 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 08:35 +0000
1"""ALS solver: ALS-XZ, ALS-XY-Wing, ALS-XY-Chain, Death Blossom.
3Mirrors Java's AlsSolver, Als, RestrictedCommon, and the ALS/RC collection
4methods in SudokuStepFinder.
5"""
7from __future__ import annotations
9from functools import cmp_to_key
10from typing import TYPE_CHECKING
12from hodoku.core.grid import ALL_UNITS, BUDDIES, Grid
13from hodoku.core.solution_step import SolutionStep
14from hodoku.core.types import SolutionType
16if TYPE_CHECKING:
17 from hodoku.config import StepSearchConfig
# Upper bound on the number of RCs in one ALS-XY-Chain (same limit as Java's MAX_RC).
_MAX_RC = 50

# Declaration-order position of every SolutionType; _als_cmp uses it as the
# last tiebreaker when two steps are otherwise equivalent.
_SOL_TYPE_ORDINALS = {member: position for position, member in enumerate(SolutionType)}
25# ---------------------------------------------------------------------------
26# Helpers
27# ---------------------------------------------------------------------------
29def _get_buddies_of_set(cell_set: int) -> int:
30 """Intersection of BUDDIES[c] for every cell c in cell_set.
32 Returns the set of all cells that see every cell in cell_set.
33 Returns 0 if cell_set is empty.
34 """
35 if not cell_set:
36 return 0
37 result = (1 << 81) - 1
38 tmp = cell_set
39 while tmp:
40 lsb = tmp & -tmp
41 result &= BUDDIES[lsb.bit_length() - 1]
42 tmp ^= lsb
43 return result
46# ---------------------------------------------------------------------------
47# ALS data structure
48# ---------------------------------------------------------------------------
50class Als:
51 """Almost Locked Set: N cells in one house with exactly N+1 candidates."""
53 __slots__ = (
54 'indices', 'candidates',
55 'indices_per_cand', 'buddies_per_cand', 'buddies_als_per_cand',
56 'buddies',
57 )
59 def __init__(self, indices: int, candidates: int) -> None:
60 self.indices = indices
61 self.candidates = candidates
62 self.indices_per_cand: list[int] = [0] * 10
63 self.buddies_per_cand: list[int] = [0] * 10
64 self.buddies_als_per_cand: list[int] = [0] * 10
65 self.buddies: int = 0
67 def compute_fields(self, grid: Grid) -> None:
68 """Compute derived sets after collection."""
69 cands = self.candidates
70 tmp = cands
71 while tmp:
72 lsb = tmp & -tmp
73 d = lsb.bit_length() # digit (1-9)
74 tmp ^= lsb
75 ipc = self.indices & grid.candidate_sets[d]
76 self.indices_per_cand[d] = ipc
77 bpc = _get_buddies_of_set(ipc) & ~self.indices & grid.candidate_sets[d]
78 self.buddies_per_cand[d] = bpc
79 self.buddies_als_per_cand[d] = bpc | ipc
80 self.buddies |= bpc
82 def get_chain_penalty(self) -> int:
83 """Chain distance penalty for using this ALS in a tabling chain.
85 Mirrors Als.getChainPenalty() in Java.
86 """
87 cand_size = self.candidates.bit_count()
88 if cand_size <= 1:
89 return 0
90 if cand_size == 2:
91 return 1
92 return (cand_size - 1) * 2
94 def __eq__(self, other: object) -> bool:
95 return isinstance(other, Als) and self.indices == other.indices
97 def __hash__(self) -> int:
98 return hash(self.indices)
101# ---------------------------------------------------------------------------
102# Restricted Common data structure
103# ---------------------------------------------------------------------------
class RestrictedCommon:
    """Restricted Common(s) linking two ALSes, addressed by list position.

    ``als1``/``als2`` index the ALS list, ``cand1`` is the first RC digit and
    ``cand2`` the second one (0 for a singly linked pair).

    actual_rc encoding: 0=none, 1=cand1 only, 2=cand2 only, 3=both.
    """

    __slots__ = ('als1', 'als2', 'cand1', 'cand2', 'actual_rc')

    def __init__(self, als1: int, als2: int, cand1: int,
                 cand2: int = 0, actual_rc: int = 0) -> None:
        self.als1, self.als2 = als1, als2
        self.cand1, self.cand2 = cand1, cand2
        self.actual_rc = actual_rc

    def check_rc(self, prev: RestrictedCommon | None, first_try: bool) -> bool:
        """Set ``actual_rc`` given the preceding RC *prev*; True if any digit stays usable.

        At the chain start (*prev* is None) a doubly linked RC uses cand1 on the
        first try and cand2 otherwise. An unknown/zero ``prev.actual_rc``
        leaves every RC digit active.
        """
        doubly_linked = self.cand2 != 0
        if prev is None:
            if doubly_linked:
                self.actual_rc = 1 if first_try else 2
            else:
                self.actual_rc = 1
            return self.actual_rc != 0

        mode = prev.actual_rc
        if mode == 1:
            self.actual_rc = _check_rc_int(prev.cand1, 0, self.cand1, self.cand2)
        elif mode == 2:
            self.actual_rc = _check_rc_int(prev.cand2, 0, self.cand1, self.cand2)
        elif mode == 3:
            # Java passes cand1 twice here (case 3) — replicated faithfully.
            self.actual_rc = _check_rc_int(prev.cand1, prev.cand1, self.cand1, self.cand2)
        else:
            self.actual_rc = 3 if doubly_linked else 1
        return self.actual_rc != 0
138def _check_rc_int(c11: int, c12: int, c21: int, c22: int) -> int:
139 """Remove ARC candidates {c11, c12} from PRC candidates {c21, c22}.
141 Returns 0 (none left), 1 (c21), 2 (c22), or 3 (both).
142 """
143 if c12 == 0:
144 # one ARC
145 if c22 == 0:
146 return 0 if c11 == c21 else 1
147 else:
148 if c11 == c22:
149 return 1
150 elif c11 == c21:
151 return 2
152 else:
153 return 3
154 else:
155 # two ARCs
156 if c22 == 0:
157 return 0 if (c11 == c21 or c12 == c21) else 1
158 else:
159 if (c11 == c21 and c12 == c22) or (c11 == c22 and c12 == c21):
160 return 0
161 elif c11 == c22 or c12 == c22:
162 return 1
163 elif c11 == c21 or c12 == c21:
164 return 2
165 else:
166 return 3
169# ---------------------------------------------------------------------------
170# ALS collection
171# ---------------------------------------------------------------------------
def _collect_alses(grid: Grid) -> list[Als]:
    """Enumerate all ALSes in the grid (including single bivalue cells).

    Mirrors SudokuStepFinder.doGetAlses(onlyLargerThanOne=false) in content
    and order. Java restarts the recursive subset search at every start
    position of a house, but a search started at position j > 0 only revisits
    subsets already produced by the search started at position 0 of the same
    house (and ``seen`` discards them). One search per house therefore yields
    the same ALSes in the same order while roughly halving the recursion.
    """
    alses: list[Als] = []
    seen: set[int] = set()

    for unit in ALL_UNITS:
        _check_als_recursive(0, 0, unit, len(unit), grid, alses, seen, 0, 0)

    for als in alses:
        als.compute_fields(grid)

    return alses
193def _check_als_recursive(
194 anzahl: int,
195 start_idx: int,
196 unit: tuple[int, ...],
197 n: int,
198 grid: Grid,
199 alses: list[Als],
200 seen: set[int],
201 index_set: int,
202 cand_acc: int,
203) -> None:
204 """Recursive ALS search over one house.
206 anzahl: number of cells already in the current set (0-based on entry).
207 """
208 anzahl += 1
209 if anzahl > n - 1:
210 return
211 for i in range(start_idx, n):
212 cell = unit[i]
213 if grid.values[cell] != 0:
214 continue
215 new_index_set = index_set | (1 << cell)
216 new_cands = cand_acc | grid.candidates[cell]
217 if new_cands.bit_count() - anzahl == 1:
218 if new_index_set not in seen:
219 seen.add(new_index_set)
220 alses.append(Als(new_index_set, new_cands))
221 _check_als_recursive(anzahl, i + 1, unit, n, grid, alses, seen,
222 new_index_set, new_cands)
225# ---------------------------------------------------------------------------
226# RC collection (forward-only, no overlap)
227# ---------------------------------------------------------------------------
229def _collect_rcs(
230 alses: list[Als],
231 allow_overlap: bool = False,
232) -> tuple[list[RestrictedCommon], list[int], list[int]]:
233 """Find all Restricted Commons between ALS pairs.
235 Forward-only (als2 index > als1 index).
236 When allow_overlap is False (default), overlapping ALS pairs are skipped.
237 When allow_overlap is True, overlapping pairs are allowed provided the RC
238 candidate does not appear in the overlap region (mirrors Java withOverlap).
240 Returns (rcs, start_indices, end_indices) where start_indices[i]..end_indices[i]
241 is the slice of rcs whose als1 == i.
242 """
243 rcs: list[RestrictedCommon] = []
244 n = len(alses)
245 start_indices = [0] * n
246 end_indices = [0] * n
248 for i in range(n):
249 als1 = alses[i]
250 start_indices[i] = len(rcs)
251 for j in range(i + 1, n):
252 als2 = alses[j]
253 overlap = als1.indices & als2.indices
254 if overlap and not allow_overlap:
255 continue
256 # Must share at least one candidate
257 common = als1.candidates & als2.candidates
258 if not common:
259 continue
260 rc_count = 0
261 new_rc: RestrictedCommon | None = None
262 tmp = common
263 while tmp:
264 lsb = tmp & -tmp
265 cand = lsb.bit_length() # digit
266 tmp ^= lsb
267 all_cand_cells = als1.indices_per_cand[cand] | als2.indices_per_cand[cand]
268 # RC candidate must not appear in the overlap region
269 if overlap and all_cand_cells & overlap:
270 continue
271 # RC check: all instances of cand in both ALSes must see each other.
272 common_buddies = (als1.buddies_als_per_cand[cand]
273 & als2.buddies_als_per_cand[cand])
274 if all_cand_cells & ~common_buddies:
275 continue # some cell doesn't see all others
276 if rc_count == 0:
277 new_rc = RestrictedCommon(i, j, cand)
278 rcs.append(new_rc)
279 else:
280 assert new_rc is not None
281 new_rc.cand2 = cand
282 rc_count += 1
283 if rc_count == 2:
284 break # max 2 RCs per pair
285 end_indices[i] = len(rcs)
287 return rcs, start_indices, end_indices
290# ---------------------------------------------------------------------------
291# Shared elimination helper
292# ---------------------------------------------------------------------------
294def _check_candidates_to_delete(
295 als1: Als, als2: Als,
296 r1: int = 0, r2: int = 0, r3: int = 0, r4: int = 0,
297) -> list[tuple[int, int]]:
298 """Find (cell, digit) eliminations common to als1 and als2, minus RC digits.
300 r1..r4 are RC candidates to exclude (0 means unused).
301 Returns list of (cell_index, digit) in ascending cell order.
302 """
303 elim_mask = als1.candidates & als2.candidates
304 for r in (r1, r2, r3, r4):
305 if r:
306 elim_mask &= ~(1 << (r - 1))
307 if not elim_mask:
308 return []
309 # Quick pre-check: common buddies must exist
310 if not (als1.buddies & als2.buddies):
311 return []
312 result: list[tuple[int, int]] = []
313 tmp = elim_mask
314 while tmp:
315 lsb = tmp & -tmp
316 cand = lsb.bit_length()
317 tmp ^= lsb
318 elim_cells = als1.buddies_per_cand[cand] & als2.buddies_per_cand[cand]
319 c = elim_cells
320 while c:
321 cl = c & -c
322 result.append((cl.bit_length() - 1, cand))
323 c ^= cl
324 return result
327def _check_doubly_linked_als(
328 als1: Als, als2: Als, rc1: int, rc2: int,
329) -> list[tuple[int, int]]:
330 """Locked-set eliminations when als1 and als2 share two RCs.
332 als1 minus {rc1, rc2} forms a locked set; eliminate its remaining
333 candidates from cells outside als2 that see all corresponding als1 cells.
334 """
335 remaining = als1.candidates & ~(1 << (rc1 - 1)) & ~(1 << (rc2 - 1))
336 if not remaining:
337 return []
338 result: list[tuple[int, int]] = []
339 tmp = remaining
340 while tmp:
341 lsb = tmp & -tmp
342 cand = lsb.bit_length()
343 tmp ^= lsb
344 elim_cells = als1.buddies_per_cand[cand] & ~als2.indices
345 c = elim_cells
346 while c:
347 cl = c & -c
348 result.append((cl.bit_length() - 1, cand))
349 c ^= cl
350 return result
353# ---------------------------------------------------------------------------
354# AlsComparator sort key (mirrors Java's AlsComparator)
355# ---------------------------------------------------------------------------
357def _als_index_count(step: SolutionStep) -> int:
358 return sum(bin(a[0]).count('1') for a in step.alses)
361def _als_index_summe(candidates: list) -> int:
362 """Weighted sum matching Java's getIndexSumme().
364 Java sorts candidatesToDelete by (value, index) before computing:
365 sum += index * offset + value, with offset starting at 1, incrementing by 80.
366 """
367 total = 0
368 offset = 1
369 for c in sorted(candidates, key=lambda c: (c.value, c.index)):
370 total += c.index * offset + c.value
371 offset += 80
372 return total
375def _als_cmp(s1: SolutionStep, s2: SolutionStep) -> int:
376 # 1. Most eliminations (descending)
377 d = len(s2.candidates_to_delete) - len(s1.candidates_to_delete)
378 if d:
379 return d
380 # 2. Equivalence check: same elimination set?
381 k1 = tuple(sorted((c.index, c.value) for c in s1.candidates_to_delete))
382 k2 = tuple(sorted((c.index, c.value) for c in s2.candidates_to_delete))
383 if k1 != k2:
384 # Not equivalent: sort by weighted index sum (ascending), matching Java's getIndexSumme
385 return (_als_index_summe(s1.candidates_to_delete)
386 - _als_index_summe(s2.candidates_to_delete))
387 # Equivalent: 3. Fewer ALSes
388 d = len(s1.alses) - len(s2.alses)
389 if d:
390 return d
391 # 4. Fewer total cells across all ALSes
392 d = _als_index_count(s1) - _als_index_count(s2)
393 if d:
394 return d
395 # 5. Type ordinal (ascending)
396 return _SOL_TYPE_ORDINALS[s1.type] - _SOL_TYPE_ORDINALS[s2.type]
399def _best_step(deletes_map: dict) -> SolutionStep | None:
400 steps = [step for _, step in deletes_map.values()]
401 if not steps:
402 return None
403 steps.sort(key=cmp_to_key(_als_cmp))
404 return steps[0]
407def _record_step(
408 step: SolutionStep,
409 als_count: int,
410 deletes_map: dict,
411) -> None:
412 """Store step if it's shortest (fewest ALS cells) for its elimination set."""
413 key = tuple(sorted((c.index, c.value) for c in step.candidates_to_delete))
414 old = deletes_map.get(key)
415 if old is None or old[0] > als_count:
416 deletes_map[key] = (als_count, step)
419# ---------------------------------------------------------------------------
420# Death Blossom RC-per-cell index
421# ---------------------------------------------------------------------------
423class _RCForDeathBlossom:
424 """Per-stem-cell index of which ALSes cover each candidate."""
426 __slots__ = ('cand_mask', 'als_per_candidate')
428 def __init__(self) -> None:
429 self.cand_mask: int = 0
430 self.als_per_candidate: list[list[int]] = [[] for _ in range(10)]
432 def add_als_for_candidate(self, als_idx: int, cand: int) -> None:
433 self.als_per_candidate[cand].append(als_idx)
434 self.cand_mask |= 1 << (cand - 1)
437# ---------------------------------------------------------------------------
438# Solver
439# ---------------------------------------------------------------------------
class AlsSolver:
    """ALS-XZ, ALS-XY-Wing, ALS-XY-Chain, Death Blossom.

    Each technique has a "first" finder (used by ``get_step``) and an "all"
    finder (used by ``find_all``). Both modes share one search routine per
    technique, so they cannot drift apart.
    """

    def __init__(self, grid: Grid, search_config: StepSearchConfig | None = None) -> None:
        self.grid = grid
        # Default overlap policy when callers pass allow_overlap=None.
        # ALS-XZ ignores it and never allows overlap.
        if search_config is not None:
            self._allow_overlap = search_config.als_allow_overlap
        else:
            self._allow_overlap = False

    def get_step(
        self, sol_type: SolutionType, allow_overlap: bool | None = None
    ) -> SolutionStep | None:
        """Return one step of *sol_type*, or None (also for unsupported types).

        ALS-XZ, ALS-XY-Wing and Death Blossom return the first step found;
        ALS-XY-Chain returns the best step according to ``_als_cmp``.
        """
        if allow_overlap is None:
            allow_overlap = self._allow_overlap
        if sol_type == SolutionType.ALS_XZ:
            return self._find_als_xz()
        if sol_type == SolutionType.ALS_XY_WING:
            return self._find_als_xy_wing(allow_overlap=allow_overlap)
        if sol_type == SolutionType.ALS_XY_CHAIN:
            return self._find_als_xy_chain(allow_overlap=allow_overlap)
        if sol_type == SolutionType.DEATH_BLOSSOM:
            return self._find_death_blossom(allow_overlap=allow_overlap)
        return None

    def find_all(
        self, sol_type: SolutionType, allow_overlap: bool | None = None
    ) -> list[SolutionStep]:
        """Return all steps of *sol_type* ([] for unsupported types).

        ALS-XZ, ALS-XY-Wing and ALS-XY-Chain steps are deduplicated by
        elimination set (keeping the one with the fewest ALS cells); Death
        Blossom steps are returned as found.
        """
        if allow_overlap is None:
            allow_overlap = self._allow_overlap
        if sol_type == SolutionType.ALS_XZ:
            return self._find_als_xz_all()
        if sol_type == SolutionType.ALS_XY_WING:
            return self._find_als_xy_wing_all(allow_overlap=allow_overlap)
        if sol_type == SolutionType.ALS_XY_CHAIN:
            return self._find_als_xy_chain_all(allow_overlap=allow_overlap)
        if sol_type == SolutionType.DEATH_BLOSSOM:
            return self._find_death_blossom_all(allow_overlap=allow_overlap)
        return []

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_step(
        sol_type: SolutionType,
        step_alses: list[Als],
        elims: list[tuple[int, int]],
    ) -> SolutionStep:
        """Create a *sol_type* step holding *step_alses* and *elims*, in order."""
        step = SolutionStep(sol_type)
        for als in step_alses:
            step.add_als(als.indices, als.candidates)
        for cell, cand in elims:
            step.add_candidate_to_delete(cell, cand)
        return step

    @staticmethod
    def _deduplicated(steps: list[SolutionStep]) -> list[SolutionStep]:
        """Keep one step per elimination set (fewest ALS cells, first wins ties)."""
        deletes_map: dict = {}
        for step in steps:
            _record_step(step, _als_index_count(step), deletes_map)
        return [step for _, step in deletes_map.values()]

    # ------------------------------------------------------------------
    # ALS-XZ
    # ------------------------------------------------------------------

    def _als_xz_steps(self, first_only: bool) -> list[SolutionStep]:
        """Build ALS-XZ steps in RC order; stop after the first if *first_only*."""
        alses = _collect_alses(self.grid)
        rcs, _, _ = _collect_rcs(alses)
        steps: list[SolutionStep] = []

        for rc in rcs:
            if rc.als1 >= rc.als2:
                continue  # guard only: _collect_rcs yields forward RCs
            als1 = alses[rc.als1]
            als2 = alses[rc.als2]

            # Singly linked elimination: exclude rc.cand1.
            elims = _check_candidates_to_delete(als1, als2, rc.cand1)
            # Doubly linked: also exclude rc.cand2, plus locked-set eliminations.
            if rc.cand2:
                elims.extend(_check_candidates_to_delete(als1, als2, rc.cand2))
                elims.extend(_check_doubly_linked_als(als1, als2, rc.cand1, rc.cand2))
                elims.extend(_check_doubly_linked_als(als2, als1, rc.cand1, rc.cand2))
            if not elims:
                continue

            steps.append(self._build_step(SolutionType.ALS_XZ, [als1, als2], elims))
            if first_only:
                break
        return steps

    def _find_als_xz(self) -> SolutionStep | None:
        """Return the FIRST ALS-XZ step found (mirrors Java's onlyOne=true mode)."""
        steps = self._als_xz_steps(first_only=True)
        return steps[0] if steps else None

    def _find_als_xz_all(self) -> list[SolutionStep]:
        """Return ALL ALS-XZ steps (deduplicated by elimination set)."""
        return self._deduplicated(self._als_xz_steps(first_only=False))

    # ------------------------------------------------------------------
    # ALS-XY-Wing
    # ------------------------------------------------------------------

    def _als_xy_wing_steps(
        self, allow_overlap: bool, first_only: bool,
    ) -> list[SolutionStep]:
        """Build ALS-XY-Wing steps over all RC pairs; stop early if *first_only*."""
        alses = _collect_alses(self.grid)
        rcs, _, _ = _collect_rcs(alses, allow_overlap=allow_overlap)
        n_rcs = len(rcs)
        steps: list[SolutionStep] = []

        for i in range(n_rcs):
            rc1 = rcs[i]
            for j in range(i + 1, n_rcs):
                rc2 = rcs[j]
                # Both singly linked with the same candidate -> skip.
                if rc1.cand2 == 0 and rc2.cand2 == 0 and rc1.cand1 == rc2.cand1:
                    continue
                # Pivot C is the ALS shared by both RCs.
                c_idx, a_idx, b_idx = _identify_pivot(rc1, rc2)
                if c_idx is None:
                    continue
                als_a = alses[a_idx]
                als_b = alses[b_idx]
                # A and B must not overlap (unless allow_overlap=True) ...
                if not allow_overlap and (als_a.indices & als_b.indices):
                    continue
                # ... and neither may be a subset of the other.
                union_ab = als_a.indices | als_b.indices
                if union_ab == als_a.indices or union_ab == als_b.indices:
                    continue
                elims = _check_candidates_to_delete(
                    als_a, als_b,
                    rc1.cand1, rc1.cand2, rc2.cand1, rc2.cand2,
                )
                if not elims:
                    continue
                steps.append(self._build_step(
                    SolutionType.ALS_XY_WING, [als_a, als_b, alses[c_idx]], elims,
                ))
                if first_only:
                    return steps
        return steps

    def _find_als_xy_wing(self, allow_overlap: bool = False) -> SolutionStep | None:
        """Return the FIRST ALS-XY-Wing step found (mirrors Java's onlyOne=true mode)."""
        steps = self._als_xy_wing_steps(allow_overlap, first_only=True)
        return steps[0] if steps else None

    def _find_als_xy_wing_all(self, allow_overlap: bool = False) -> list[SolutionStep]:
        """Return ALL ALS-XY-Wing steps (deduplicated by elimination set)."""
        return self._deduplicated(self._als_xy_wing_steps(allow_overlap, first_only=False))

    # ------------------------------------------------------------------
    # ALS-XY-Chain
    # ------------------------------------------------------------------

    def _collect_als_xy_chains(self, allow_overlap: bool) -> dict:
        """Search chains from every start ALS; return the deduplication map.

        NOTE(review): ``_collect_rcs`` only yields RCs with als2 > als1, so a
        chain can only visit ALSes in increasing list order. Confirm this
        matches the intended Java search.
        """
        alses = _collect_alses(self.grid)
        rcs, start_indices, end_indices = _collect_rcs(alses, allow_overlap=allow_overlap)
        deletes_map: dict = {}
        als_in_chain = [False] * len(alses)
        chain: list[RestrictedCommon] = []

        for i, start_als in enumerate(alses):
            als_in_chain[i] = True
            self._chain_recursive(
                i, None, True,
                alses, rcs, start_indices, end_indices,
                als_in_chain, chain,
                start_als, deletes_map,
            )
            als_in_chain[i] = False
        return deletes_map

    def _find_als_xy_chain(self, allow_overlap: bool = False) -> SolutionStep | None:
        """Return the best ALS-XY-Chain step according to ``_als_cmp``."""
        return _best_step(self._collect_als_xy_chains(allow_overlap))

    def _find_als_xy_chain_all(self, allow_overlap: bool = False) -> list[SolutionStep]:
        """Return ALL ALS-XY-Chain steps (all entries from the deduplication map)."""
        deletes_map = self._collect_als_xy_chains(allow_overlap)
        return [step for _, step in deletes_map.values()]

    def _chain_recursive(
        self,
        als_idx: int,
        last_rc: RestrictedCommon | None,
        first_try: bool,
        alses: list[Als],
        rcs: list[RestrictedCommon],
        start_indices: list[int],
        end_indices: list[int],
        als_in_chain: list[bool],
        chain: list[RestrictedCommon],
        start_als: Als,
        deletes_map: dict,
    ) -> None:
        """Depth-first extension of the chain ending at ``alses[als_idx]``.

        An RC leaving the current end ALS is appended if ``check_rc`` accepts it
        after *last_rc* and its target ALS is not yet in the chain. Once the
        chain has at least three RCs, its eliminations are recorded in
        *deletes_map*. At the chain start (``last_rc is None``), a doubly linked
        RC is tried twice: with ``first_try=True`` and then ``False``.

        The *first_try* parameter is unused (attempts are tracked per RC).
        It is kept for signature compatibility.
        """
        # The chain length is restored after every attempt, so this single
        # check bounds the whole loop below.
        if len(chain) >= _MAX_RC:
            return

        for idx in range(start_indices[als_idx], end_indices[als_idx]):
            rc = rcs[idx]
            attempts = (True, False) if last_rc is None and rc.cand2 != 0 else (True,)
            for attempt in attempts:
                # check_rc must run first: it sets rc.actual_rc.
                if not rc.check_rc(last_rc, attempt) or als_in_chain[rc.als2]:
                    break

                chain.append(rc)
                als_in_chain[rc.als2] = True
                if len(chain) >= 3:
                    self._record_chain_step(alses, chain, start_als, deletes_map)
                self._chain_recursive(
                    rc.als2, rc, True,
                    alses, rcs, start_indices, end_indices,
                    als_in_chain, chain,
                    start_als, deletes_map,
                )
                als_in_chain[rc.als2] = False
                chain.pop()

    @staticmethod
    def _active_rc_cands(rc: RestrictedCommon) -> tuple[int, int]:
        """Return the digits of *rc* in use per ``actual_rc``; 0 marks an unused slot."""
        if rc.actual_rc == 1:
            return rc.cand1, 0
        if rc.actual_rc == 2:
            return rc.cand2, 0
        if rc.actual_rc == 3:
            return rc.cand1, rc.cand2
        return 0, 0

    def _record_chain_step(
        self,
        alses: list[Als],
        chain: list[RestrictedCommon],
        start_als: Als,
        deletes_map: dict,
    ) -> None:
        """Record the current chain's eliminations in *deletes_map*, if there are any.

        The active digits of the first and last RC are excluded. Their order
        does not matter to ``_check_candidates_to_delete``.
        """
        last_rc = chain[-1]
        end_als = alses[last_rc.als2]
        elims = _check_candidates_to_delete(
            start_als, end_als,
            *self._active_rc_cands(chain[0]),
            *self._active_rc_cands(last_rc),
        )
        if not elims:
            return
        chain_alses = [start_als]
        chain_alses.extend(alses[link.als2] for link in chain)
        step = self._build_step(SolutionType.ALS_XY_CHAIN, chain_alses, elims)
        _record_step(step, _als_index_count(step), deletes_map)

    # ------------------------------------------------------------------
    # Death Blossom
    # ------------------------------------------------------------------

    def _death_blossom_steps(
        self, allow_overlap: bool, find_all: bool,
    ) -> list[SolutionStep]:
        """Run the petal search for every eligible stem cell.

        With ``find_all=False`` the search stops after the first step.
        """
        grid = self.grid
        alses = _collect_alses(grid)
        rcdb = self._collect_rcs_for_death_blossom(alses)
        result: list[SolutionStep] = []

        for stem in range(81):
            entry = rcdb[stem]
            if grid.values[stem] != 0 or entry is None:
                continue
            stem_cands = grid.candidates[stem]
            # Every stem candidate must be covered by at least one ALS.
            if entry.cand_mask != stem_cands:
                continue
            # Highest stem digit == bit length of the digit mask.
            self._db_recursive(
                1, stem_cands.bit_length(), stem, entry, alses, grid,
                _DBState(stem), result,
                allow_overlap=allow_overlap, find_all=find_all,
            )
            if not find_all and result:
                break
        return result

    def _find_death_blossom_all(self, allow_overlap: bool = False) -> list[SolutionStep]:
        """Return ALL Death Blossom steps."""
        return self._death_blossom_steps(allow_overlap, find_all=True)

    def _find_death_blossom(self, allow_overlap: bool = False) -> SolutionStep | None:
        """Return the FIRST Death Blossom step found (mirrors Java's onlyOne=true mode)."""
        steps = self._death_blossom_steps(allow_overlap, find_all=False)
        return steps[0] if steps else None

    def _collect_rcs_for_death_blossom(
        self, alses: list[Als],
    ) -> list[_RCForDeathBlossom | None]:
        """Index, per cell, the ALSes that could act as a petal for each digit.

        ALS ``i`` is registered at cell ``c`` for digit ``d`` when ``c`` is in
        ``alses[i].buddies_per_cand[d]``. Cells without entries stay None.
        """
        rcdb: list[_RCForDeathBlossom | None] = [None] * 81

        for i, als in enumerate(alses):
            tmp = als.candidates
            while tmp:
                lsb = tmp & -tmp
                cand = lsb.bit_length()
                tmp ^= lsb
                c = als.buddies_per_cand[cand]
                while c:
                    cl = c & -c
                    cell = cl.bit_length() - 1
                    c ^= cl
                    entry = rcdb[cell]
                    if entry is None:
                        entry = rcdb[cell] = _RCForDeathBlossom()
                    entry.add_als_for_candidate(i, cand)

        return rcdb

    def _db_recursive(
        self,
        cand: int,
        max_cand: int,
        stem: int,
        rcdb_entry: _RCForDeathBlossom,
        alses: list[Als],
        grid: Grid,
        state: _DBState,
        result: list,
        allow_overlap: bool = False,
        find_all: bool = False,
    ) -> None:
        """Choose a petal ALS for each digit ``cand..max_cand`` of the stem, depth first.

        A registered ALS is accepted for *cand* if all of these hold:

        - it does not contain the stem;
        - it does not overlap earlier petals (unless *allow_overlap*);
        - it still shares a digit with all earlier petals.

        After the last digit, eliminations are checked. With
        ``find_all=False`` the search stops once *result* is non-empty.
        """
        if cand > max_cand:
            return

        petal_options = rcdb_entry.als_per_candidate[cand]
        if not petal_options:
            # No ALS covers this digit at the stem: leave it without a petal.
            state.akt_db_als[cand] = -1
            self._db_recursive(
                cand + 1, max_cand, stem, rcdb_entry,
                alses, grid, state, result,
                allow_overlap=allow_overlap, find_all=find_all,
            )
            return

        stem_bit = 1 << stem
        for als_idx in petal_options:
            if not find_all and result:
                return  # early exit
            als = alses[als_idx]

            # The ALS must never contain the stem cell itself.
            if als.indices & stem_bit:
                continue
            # No petal-petal overlap unless allow_overlap=True.
            petal_indices = state.db_indices & ~stem_bit
            if not allow_overlap and (als.indices & petal_indices):
                continue
            # Must share at least one candidate with all earlier petals.
            if not (state.db_candidates & als.candidates):
                continue

            # Save db_indices instead of undoing it with &= ~als.indices.
            # With overlapping petals, that undo would also erase cells that
            # belong to earlier petals.
            saved_indices = state.db_indices
            state.akt_db_als[cand] = als_idx
            inc = state.db_candidates & ~als.candidates
            state.inc_db_cand[cand] = inc
            state.db_candidates &= als.candidates
            state.db_indices |= als.indices

            if cand < max_cand:
                self._db_recursive(
                    cand + 1, max_cand, stem, rcdb_entry,
                    alses, grid, state, result,
                    allow_overlap=allow_overlap, find_all=find_all,
                )
            else:
                self._db_check_eliminations(stem, alses, grid, state, result)

            # Backtrack.
            state.db_candidates |= inc
            state.db_indices = saved_indices
            state.akt_db_als[cand] = -1

            if not find_all and result:
                return  # early exit after finding one

    def _db_check_eliminations(
        self,
        stem: int,
        alses: list[Als],
        grid: Grid,
        state: _DBState,
        result: list,
    ) -> None:
        """Append a Death Blossom step to *result* if the chosen petals eliminate anything.

        A candidate digit qualifies if it is common to all petals and has no
        petal of its own. It is removed from every cell that:

        - is in ``grid.candidate_sets`` for that digit;
        - sees every petal cell holding the digit;
        - is neither a petal cell nor the stem.
        """
        petals = [alses[idx] for idx in state.akt_db_als[1:] if idx != -1]
        elims: list[tuple[int, int]] = []

        tmp = state.db_candidates
        while tmp:
            lsb = tmp & -tmp
            check_cand = lsb.bit_length()
            tmp ^= lsb

            if state.akt_db_als[check_cand] != -1:
                continue  # digit has its own petal (a stem digit); not eliminated here

            union_cells = 0
            for petal in petals:
                union_cells |= petal.indices_per_cand[check_cand]
            if not union_cells:
                continue

            targets = (_get_buddies_of_set(union_cells)
                       & ~state.db_indices
                       & ~(1 << stem)
                       & grid.candidate_sets[check_cand])
            while targets:
                cl = targets & -targets
                elims.append((cl.bit_length() - 1, check_cand))
                targets ^= cl

        if not elims:
            return

        step = SolutionStep(SolutionType.DEATH_BLOSSOM)
        step.add_index(stem)
        for petal in petals:
            step.add_als(petal.indices, petal.candidates)
        for cell, cand in elims:
            step.add_candidate_to_delete(cell, cand)
        result.append(step)
999class _DBState:
1000 """Mutable state for the Death Blossom recursive search."""
1002 __slots__ = ('db_indices', 'db_candidates', 'akt_db_als', 'inc_db_cand')
1004 def __init__(self, stem: int) -> None:
1005 self.db_indices: int = 1 << stem # start with stem cell excluded
1006 self.db_candidates: int = 0x1ff # all 9 candidates (MAX_MASK)
1007 self.akt_db_als: list[int] = [-1] * 10
1008 self.inc_db_cand: list[int] = [0] * 10
1011# ---------------------------------------------------------------------------
1012# Pivot identification for ALS-XY-Wing
1013# ---------------------------------------------------------------------------
1015def _identify_pivot(
1016 rc1: RestrictedCommon, rc2: RestrictedCommon,
1017) -> tuple[int | None, int | None, int | None]:
1018 """Find the pivot ALS (C) shared by rc1 and rc2; return (c_idx, a_idx, b_idx).
1020 Returns (None, None, None) if no valid pivot exists (not exactly 3 distinct ALSes).
1021 """
1022 # Four possible shared-ALS configurations (mirrors Java's if-chain exactly)
1023 c_idx = a_idx = b_idx = None
1025 if rc1.als1 == rc2.als1 and rc1.als2 != rc2.als2:
1026 c_idx, a_idx, b_idx = rc1.als1, rc1.als2, rc2.als2
1027 elif rc1.als1 == rc2.als2 and rc1.als2 != rc2.als1:
1028 c_idx, a_idx, b_idx = rc1.als1, rc1.als2, rc2.als1
1029 elif rc1.als2 == rc2.als1 and rc1.als1 != rc2.als2:
1030 c_idx, a_idx, b_idx = rc1.als2, rc1.als1, rc2.als2
1031 elif rc1.als2 == rc2.als2 and rc1.als1 != rc2.als1:
1032 c_idx, a_idx, b_idx = rc1.als2, rc1.als1, rc2.als1
1034 return c_idx, a_idx, b_idx