Coverage for src / hodoku / solver / als.py: 89%

611 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 08:35 +0000

1"""ALS solver: ALS-XZ, ALS-XY-Wing, ALS-XY-Chain, Death Blossom. 

2 

3Mirrors Java's AlsSolver, Als, RestrictedCommon, and the ALS/RC collection 

4methods in SudokuStepFinder. 

5""" 

6 

7from __future__ import annotations 

8 

9from functools import cmp_to_key 

10from typing import TYPE_CHECKING 

11 

12from hodoku.core.grid import ALL_UNITS, BUDDIES, Grid 

13from hodoku.core.solution_step import SolutionStep 

14from hodoku.core.types import SolutionType 

15 

16if TYPE_CHECKING: 

17 from hodoku.config import StepSearchConfig 

18 

# Upper bound on the number of RCs in one ALS-XY-Chain (same limit as Java's MAX_RC).
_MAX_RC = 50

# SolutionType member -> declaration position; the final tie-break used by _als_cmp.
_SOL_TYPE_ORDINALS = {member: position for position, member in enumerate(SolutionType)}

23 

24 

25# --------------------------------------------------------------------------- 

26# Helpers 

27# --------------------------------------------------------------------------- 

28 

29def _get_buddies_of_set(cell_set: int) -> int: 

30 """Intersection of BUDDIES[c] for every cell c in cell_set. 

31 

32 Returns the set of all cells that see every cell in cell_set. 

33 Returns 0 if cell_set is empty. 

34 """ 

35 if not cell_set: 

36 return 0 

37 result = (1 << 81) - 1 

38 tmp = cell_set 

39 while tmp: 

40 lsb = tmp & -tmp 

41 result &= BUDDIES[lsb.bit_length() - 1] 

42 tmp ^= lsb 

43 return result 

44 

45 

46# --------------------------------------------------------------------------- 

47# ALS data structure 

48# --------------------------------------------------------------------------- 

49 

50class Als: 

51 """Almost Locked Set: N cells in one house with exactly N+1 candidates.""" 

52 

53 __slots__ = ( 

54 'indices', 'candidates', 

55 'indices_per_cand', 'buddies_per_cand', 'buddies_als_per_cand', 

56 'buddies', 

57 ) 

58 

59 def __init__(self, indices: int, candidates: int) -> None: 

60 self.indices = indices 

61 self.candidates = candidates 

62 self.indices_per_cand: list[int] = [0] * 10 

63 self.buddies_per_cand: list[int] = [0] * 10 

64 self.buddies_als_per_cand: list[int] = [0] * 10 

65 self.buddies: int = 0 

66 

67 def compute_fields(self, grid: Grid) -> None: 

68 """Compute derived sets after collection.""" 

69 cands = self.candidates 

70 tmp = cands 

71 while tmp: 

72 lsb = tmp & -tmp 

73 d = lsb.bit_length() # digit (1-9) 

74 tmp ^= lsb 

75 ipc = self.indices & grid.candidate_sets[d] 

76 self.indices_per_cand[d] = ipc 

77 bpc = _get_buddies_of_set(ipc) & ~self.indices & grid.candidate_sets[d] 

78 self.buddies_per_cand[d] = bpc 

79 self.buddies_als_per_cand[d] = bpc | ipc 

80 self.buddies |= bpc 

81 

82 def get_chain_penalty(self) -> int: 

83 """Chain distance penalty for using this ALS in a tabling chain. 

84 

85 Mirrors Als.getChainPenalty() in Java. 

86 """ 

87 cand_size = self.candidates.bit_count() 

88 if cand_size <= 1: 

89 return 0 

90 if cand_size == 2: 

91 return 1 

92 return (cand_size - 1) * 2 

93 

94 def __eq__(self, other: object) -> bool: 

95 return isinstance(other, Als) and self.indices == other.indices 

96 

97 def __hash__(self) -> int: 

98 return hash(self.indices) 

99 

100 

101# --------------------------------------------------------------------------- 

102# Restricted Common data structure 

103# --------------------------------------------------------------------------- 

104 

class RestrictedCommon:
    """Restricted Common (RC) linking two ALSes, referenced by list index.

    ``cand1`` is the first RC digit and ``cand2`` an optional second one
    (0 = none). ``actual_rc`` records which of them is usable at the current
    chain position: 0 = none, 1 = cand1 only, 2 = cand2 only, 3 = both.
    """

    __slots__ = ('als1', 'als2', 'cand1', 'cand2', 'actual_rc')

    def __init__(self, als1: int, als2: int, cand1: int,
                 cand2: int = 0, actual_rc: int = 0) -> None:
        self.als1, self.als2 = als1, als2
        self.cand1, self.cand2 = cand1, cand2
        self.actual_rc = actual_rc

    def check_rc(self, prev: RestrictedCommon | None, first_try: bool) -> bool:
        """Set ``actual_rc`` given the preceding link *prev*; return True if usable.

        With no predecessor (chain start) a doubly-linked RC is entered via
        cand1 when *first_try* is True and via cand2 otherwise. With a
        predecessor, RC digits already consumed by *prev* are removed.
        """
        self.actual_rc = 3 if self.cand2 != 0 else 1
        if prev is None:
            if self.cand2 != 0:
                self.actual_rc = 1 if first_try else 2
            return self.actual_rc != 0

        incoming = prev.actual_rc
        if incoming == 1:
            self.actual_rc = _check_rc_int(prev.cand1, 0, self.cand1, self.cand2)
        elif incoming == 2:
            self.actual_rc = _check_rc_int(prev.cand2, 0, self.cand1, self.cand2)
        elif incoming == 3:
            # NOTE: Java passes prev.cand1 for both ARC slots here; kept for parity.
            self.actual_rc = _check_rc_int(prev.cand1, prev.cand1, self.cand1, self.cand2)
        return self.actual_rc != 0

136 

137 

138def _check_rc_int(c11: int, c12: int, c21: int, c22: int) -> int: 

139 """Remove ARC candidates {c11, c12} from PRC candidates {c21, c22}. 

140 

141 Returns 0 (none left), 1 (c21), 2 (c22), or 3 (both). 

142 """ 

143 if c12 == 0: 

144 # one ARC 

145 if c22 == 0: 

146 return 0 if c11 == c21 else 1 

147 else: 

148 if c11 == c22: 

149 return 1 

150 elif c11 == c21: 

151 return 2 

152 else: 

153 return 3 

154 else: 

155 # two ARCs 

156 if c22 == 0: 

157 return 0 if (c11 == c21 or c12 == c21) else 1 

158 else: 

159 if (c11 == c21 and c12 == c22) or (c11 == c22 and c12 == c21): 

160 return 0 

161 elif c11 == c22 or c12 == c22: 

162 return 1 

163 elif c11 == c21 or c12 == c21: 

164 return 2 

165 else: 

166 return 3 

167 

168 

169# --------------------------------------------------------------------------- 

170# ALS collection 

171# --------------------------------------------------------------------------- 

172 

def _collect_alses(grid: Grid) -> list[Als]:
    """Enumerate every ALS in *grid*, including single bivalue cells.

    Mirrors SudokuStepFinder.doGetAlses(onlyLargerThanOne=false). Each house in
    ALL_UNITS is searched recursively for sets of N unsolved cells holding
    exactly N+1 candidates. A cell set found in several houses (e.g. a
    row/box intersection) is reported once, at its first discovery.

    Java restarts the recursion at every position of the house. A search
    started at position 0 already visits every subset of the house, and each
    later start only revisits subsets already recorded in ``seen``. A single
    search per house therefore yields the identical list, in the identical
    order, with roughly half the recursion work.

    Returns the ALSes with their per-digit fields already computed.
    """
    alses: list[Als] = []
    seen: set[int] = set()

    for unit in ALL_UNITS:
        _check_als_recursive(0, 0, unit, len(unit), grid, alses, seen, 0, 0)

    for als in alses:
        als.compute_fields(grid)

    return alses

191 

192 

193def _check_als_recursive( 

194 anzahl: int, 

195 start_idx: int, 

196 unit: tuple[int, ...], 

197 n: int, 

198 grid: Grid, 

199 alses: list[Als], 

200 seen: set[int], 

201 index_set: int, 

202 cand_acc: int, 

203) -> None: 

204 """Recursive ALS search over one house. 

205 

206 anzahl: number of cells already in the current set (0-based on entry). 

207 """ 

208 anzahl += 1 

209 if anzahl > n - 1: 

210 return 

211 for i in range(start_idx, n): 

212 cell = unit[i] 

213 if grid.values[cell] != 0: 

214 continue 

215 new_index_set = index_set | (1 << cell) 

216 new_cands = cand_acc | grid.candidates[cell] 

217 if new_cands.bit_count() - anzahl == 1: 

218 if new_index_set not in seen: 

219 seen.add(new_index_set) 

220 alses.append(Als(new_index_set, new_cands)) 

221 _check_als_recursive(anzahl, i + 1, unit, n, grid, alses, seen, 

222 new_index_set, new_cands) 

223 

224 

225# --------------------------------------------------------------------------- 

226# RC collection (forward-only, no overlap) 

227# --------------------------------------------------------------------------- 

228 

229def _collect_rcs( 

230 alses: list[Als], 

231 allow_overlap: bool = False, 

232) -> tuple[list[RestrictedCommon], list[int], list[int]]: 

233 """Find all Restricted Commons between ALS pairs. 

234 

235 Forward-only (als2 index > als1 index). 

236 When allow_overlap is False (default), overlapping ALS pairs are skipped. 

237 When allow_overlap is True, overlapping pairs are allowed provided the RC 

238 candidate does not appear in the overlap region (mirrors Java withOverlap). 

239 

240 Returns (rcs, start_indices, end_indices) where start_indices[i]..end_indices[i] 

241 is the slice of rcs whose als1 == i. 

242 """ 

243 rcs: list[RestrictedCommon] = [] 

244 n = len(alses) 

245 start_indices = [0] * n 

246 end_indices = [0] * n 

247 

248 for i in range(n): 

249 als1 = alses[i] 

250 start_indices[i] = len(rcs) 

251 for j in range(i + 1, n): 

252 als2 = alses[j] 

253 overlap = als1.indices & als2.indices 

254 if overlap and not allow_overlap: 

255 continue 

256 # Must share at least one candidate 

257 common = als1.candidates & als2.candidates 

258 if not common: 

259 continue 

260 rc_count = 0 

261 new_rc: RestrictedCommon | None = None 

262 tmp = common 

263 while tmp: 

264 lsb = tmp & -tmp 

265 cand = lsb.bit_length() # digit 

266 tmp ^= lsb 

267 all_cand_cells = als1.indices_per_cand[cand] | als2.indices_per_cand[cand] 

268 # RC candidate must not appear in the overlap region 

269 if overlap and all_cand_cells & overlap: 

270 continue 

271 # RC check: all instances of cand in both ALSes must see each other. 

272 common_buddies = (als1.buddies_als_per_cand[cand] 

273 & als2.buddies_als_per_cand[cand]) 

274 if all_cand_cells & ~common_buddies: 

275 continue # some cell doesn't see all others 

276 if rc_count == 0: 

277 new_rc = RestrictedCommon(i, j, cand) 

278 rcs.append(new_rc) 

279 else: 

280 assert new_rc is not None 

281 new_rc.cand2 = cand 

282 rc_count += 1 

283 if rc_count == 2: 

284 break # max 2 RCs per pair 

285 end_indices[i] = len(rcs) 

286 

287 return rcs, start_indices, end_indices 

288 

289 

290# --------------------------------------------------------------------------- 

291# Shared elimination helper 

292# --------------------------------------------------------------------------- 

293 

294def _check_candidates_to_delete( 

295 als1: Als, als2: Als, 

296 r1: int = 0, r2: int = 0, r3: int = 0, r4: int = 0, 

297) -> list[tuple[int, int]]: 

298 """Find (cell, digit) eliminations common to als1 and als2, minus RC digits. 

299 

300 r1..r4 are RC candidates to exclude (0 means unused). 

301 Returns list of (cell_index, digit) in ascending cell order. 

302 """ 

303 elim_mask = als1.candidates & als2.candidates 

304 for r in (r1, r2, r3, r4): 

305 if r: 

306 elim_mask &= ~(1 << (r - 1)) 

307 if not elim_mask: 

308 return [] 

309 # Quick pre-check: common buddies must exist 

310 if not (als1.buddies & als2.buddies): 

311 return [] 

312 result: list[tuple[int, int]] = [] 

313 tmp = elim_mask 

314 while tmp: 

315 lsb = tmp & -tmp 

316 cand = lsb.bit_length() 

317 tmp ^= lsb 

318 elim_cells = als1.buddies_per_cand[cand] & als2.buddies_per_cand[cand] 

319 c = elim_cells 

320 while c: 

321 cl = c & -c 

322 result.append((cl.bit_length() - 1, cand)) 

323 c ^= cl 

324 return result 

325 

326 

327def _check_doubly_linked_als( 

328 als1: Als, als2: Als, rc1: int, rc2: int, 

329) -> list[tuple[int, int]]: 

330 """Locked-set eliminations when als1 and als2 share two RCs. 

331 

332 als1 minus {rc1, rc2} forms a locked set; eliminate its remaining 

333 candidates from cells outside als2 that see all corresponding als1 cells. 

334 """ 

335 remaining = als1.candidates & ~(1 << (rc1 - 1)) & ~(1 << (rc2 - 1)) 

336 if not remaining: 

337 return [] 

338 result: list[tuple[int, int]] = [] 

339 tmp = remaining 

340 while tmp: 

341 lsb = tmp & -tmp 

342 cand = lsb.bit_length() 

343 tmp ^= lsb 

344 elim_cells = als1.buddies_per_cand[cand] & ~als2.indices 

345 c = elim_cells 

346 while c: 

347 cl = c & -c 

348 result.append((cl.bit_length() - 1, cand)) 

349 c ^= cl 

350 return result 

351 

352 

353# --------------------------------------------------------------------------- 

354# AlsComparator sort key (mirrors Java's AlsComparator) 

355# --------------------------------------------------------------------------- 

356 

357def _als_index_count(step: SolutionStep) -> int: 

358 return sum(bin(a[0]).count('1') for a in step.alses) 

359 

360 

361def _als_index_summe(candidates: list) -> int: 

362 """Weighted sum matching Java's getIndexSumme(). 

363 

364 Java sorts candidatesToDelete by (value, index) before computing: 

365 sum += index * offset + value, with offset starting at 1, incrementing by 80. 

366 """ 

367 total = 0 

368 offset = 1 

369 for c in sorted(candidates, key=lambda c: (c.value, c.index)): 

370 total += c.index * offset + c.value 

371 offset += 80 

372 return total 

373 

374 

375def _als_cmp(s1: SolutionStep, s2: SolutionStep) -> int: 

376 # 1. Most eliminations (descending) 

377 d = len(s2.candidates_to_delete) - len(s1.candidates_to_delete) 

378 if d: 

379 return d 

380 # 2. Equivalence check: same elimination set? 

381 k1 = tuple(sorted((c.index, c.value) for c in s1.candidates_to_delete)) 

382 k2 = tuple(sorted((c.index, c.value) for c in s2.candidates_to_delete)) 

383 if k1 != k2: 

384 # Not equivalent: sort by weighted index sum (ascending), matching Java's getIndexSumme 

385 return (_als_index_summe(s1.candidates_to_delete) 

386 - _als_index_summe(s2.candidates_to_delete)) 

387 # Equivalent: 3. Fewer ALSes 

388 d = len(s1.alses) - len(s2.alses) 

389 if d: 

390 return d 

391 # 4. Fewer total cells across all ALSes 

392 d = _als_index_count(s1) - _als_index_count(s2) 

393 if d: 

394 return d 

395 # 5. Type ordinal (ascending) 

396 return _SOL_TYPE_ORDINALS[s1.type] - _SOL_TYPE_ORDINALS[s2.type] 

397 

398 

399def _best_step(deletes_map: dict) -> SolutionStep | None: 

400 steps = [step for _, step in deletes_map.values()] 

401 if not steps: 

402 return None 

403 steps.sort(key=cmp_to_key(_als_cmp)) 

404 return steps[0] 

405 

406 

407def _record_step( 

408 step: SolutionStep, 

409 als_count: int, 

410 deletes_map: dict, 

411) -> None: 

412 """Store step if it's shortest (fewest ALS cells) for its elimination set.""" 

413 key = tuple(sorted((c.index, c.value) for c in step.candidates_to_delete)) 

414 old = deletes_map.get(key) 

415 if old is None or old[0] > als_count: 

416 deletes_map[key] = (als_count, step) 

417 

418 

419# --------------------------------------------------------------------------- 

420# Death Blossom RC-per-cell index 

421# --------------------------------------------------------------------------- 

422 

423class _RCForDeathBlossom: 

424 """Per-stem-cell index of which ALSes cover each candidate.""" 

425 

426 __slots__ = ('cand_mask', 'als_per_candidate') 

427 

428 def __init__(self) -> None: 

429 self.cand_mask: int = 0 

430 self.als_per_candidate: list[list[int]] = [[] for _ in range(10)] 

431 

432 def add_als_for_candidate(self, als_idx: int, cand: int) -> None: 

433 self.als_per_candidate[cand].append(als_idx) 

434 self.cand_mask |= 1 << (cand - 1) 

435 

436 

437# --------------------------------------------------------------------------- 

438# Solver 

439# --------------------------------------------------------------------------- 

440 

class AlsSolver:
    """Searches the grid for ALS-XZ, ALS-XY-Wing, ALS-XY-Chain and Death Blossom.

    ALSes (and RCs where needed) are re-collected from ``self.grid`` on every
    search. ``get_step`` returns a single step or None. ``find_all`` returns
    every step found. For every type except Death Blossom, steps are
    deduplicated by elimination set, keeping the variant with the fewest ALS
    cells.
    """

    def __init__(self, grid: Grid, search_config: StepSearchConfig | None = None) -> None:
        self.grid = grid
        # Overlap policy used when get_step/find_all receive allow_overlap=None.
        if search_config is not None:
            self._allow_overlap = search_config.als_allow_overlap
        else:
            self._allow_overlap = False

    # ------------------------------------------------------------------
    # Public entry points
    # ------------------------------------------------------------------

    def get_step(
        self, sol_type: SolutionType, allow_overlap: bool | None = None
    ) -> SolutionStep | None:
        """Return one step of *sol_type*, or None if none is found.

        *allow_overlap* overrides the configured default. It is not applied to
        ALS-XZ, which always uses non-overlapping ALS pairs. Unsupported types
        return None.
        """
        if allow_overlap is None:
            allow_overlap = self._allow_overlap
        if sol_type == SolutionType.ALS_XZ:
            return self._find_als_xz()
        if sol_type == SolutionType.ALS_XY_WING:
            return self._find_als_xy_wing(allow_overlap=allow_overlap)
        if sol_type == SolutionType.ALS_XY_CHAIN:
            return self._find_als_xy_chain(allow_overlap=allow_overlap)
        if sol_type == SolutionType.DEATH_BLOSSOM:
            return self._find_death_blossom(allow_overlap=allow_overlap)
        return None

    def find_all(
        self, sol_type: SolutionType, allow_overlap: bool | None = None
    ) -> list[SolutionStep]:
        """Return every step of *sol_type* (empty list for unsupported types).

        *allow_overlap* behaves as in :meth:`get_step`.
        """
        if allow_overlap is None:
            allow_overlap = self._allow_overlap
        if sol_type == SolutionType.ALS_XZ:
            return self._find_als_xz_all()
        if sol_type == SolutionType.ALS_XY_WING:
            return self._find_als_xy_wing_all(allow_overlap=allow_overlap)
        if sol_type == SolutionType.ALS_XY_CHAIN:
            return self._find_als_xy_chain_all(allow_overlap=allow_overlap)
        if sol_type == SolutionType.DEATH_BLOSSOM:
            return self._find_death_blossom_all(allow_overlap=allow_overlap)
        return []

    # ------------------------------------------------------------------
    # Shared step construction
    # ------------------------------------------------------------------

    @staticmethod
    def _build_step(
        sol_type: SolutionType,
        als_list: list[Als],
        elims: list[tuple[int, int]],
        stem: int | None = None,
    ) -> SolutionStep:
        """Create a step holding the optional *stem* index, the ALSes, then the eliminations."""
        step = SolutionStep(sol_type)
        if stem is not None:
            step.add_index(stem)
        for als in als_list:
            step.add_als(als.indices, als.candidates)
        for cell, cand in elims:
            step.add_candidate_to_delete(cell, cand)
        return step

    # ------------------------------------------------------------------
    # ALS-XZ
    # ------------------------------------------------------------------

    def _als_xz_step(self, alses: list[Als], rc: RestrictedCommon) -> SolutionStep | None:
        """Return the ALS-XZ step for *rc*, or None if it eliminates nothing."""
        if rc.als1 >= rc.als2:
            return None  # forward RCs only (always true for _collect_rcs output)
        als1 = alses[rc.als1]
        als2 = alses[rc.als2]

        # Singly linked: common digits other than cand1.
        elims = list(_check_candidates_to_delete(als1, als2, rc.cand1))
        if rc.cand2:
            # Doubly linked: the same excluding cand2, plus locked-set
            # eliminations from each ALS minus both RC digits.
            # NOTE(review): these checks can produce the same (cell, digit)
            # more than once; whether SolutionStep collapses duplicates is not
            # visible here — confirm.
            elims.extend(_check_candidates_to_delete(als1, als2, rc.cand2))
            elims.extend(_check_doubly_linked_als(als1, als2, rc.cand1, rc.cand2))
            elims.extend(_check_doubly_linked_als(als2, als1, rc.cand1, rc.cand2))
        if not elims:
            return None
        return self._build_step(SolutionType.ALS_XZ, [als1, als2], elims)

    def _find_als_xz(self) -> SolutionStep | None:
        """Return the FIRST ALS-XZ step found (mirrors Java's onlyOne=true mode)."""
        alses = _collect_alses(self.grid)
        rcs, _, _ = _collect_rcs(alses)
        for rc in rcs:
            step = self._als_xz_step(alses, rc)
            if step is not None:
                return step
        return None

    def _find_als_xz_all(self) -> list[SolutionStep]:
        """Return ALL ALS-XZ steps, deduplicated by elimination set."""
        alses = _collect_alses(self.grid)
        rcs, _, _ = _collect_rcs(alses)
        deletes_map: dict = {}
        for rc in rcs:
            step = self._als_xz_step(alses, rc)
            if step is not None:
                _record_step(step, _als_index_count(step), deletes_map)
        return [step for _, step in deletes_map.values()]

    # ------------------------------------------------------------------
    # ALS-XY-Wing
    # ------------------------------------------------------------------

    def _als_xy_wing_step(
        self,
        alses: list[Als],
        rc1: RestrictedCommon,
        rc2: RestrictedCommon,
        allow_overlap: bool,
    ) -> SolutionStep | None:
        """Return the ALS-XY-Wing step formed by *rc1* and *rc2*, or None."""
        # Two singly-linked RCs on the same digit cannot form a wing.
        if rc1.cand2 == 0 and rc2.cand2 == 0 and rc1.cand1 == rc2.cand1:
            return None
        # Pivot C is the ALS shared by both RCs; A and B are the wings.
        c_idx, a_idx, b_idx = _identify_pivot(rc1, rc2)
        if c_idx is None:
            return None
        als_a = alses[a_idx]
        als_b = alses[b_idx]
        if not allow_overlap and (als_a.indices & als_b.indices):
            return None
        # A must be neither a subset nor a superset of B.
        union_ab = als_a.indices | als_b.indices
        if union_ab == als_a.indices or union_ab == als_b.indices:
            return None
        elims = _check_candidates_to_delete(
            als_a, als_b,
            rc1.cand1, rc1.cand2, rc2.cand1, rc2.cand2,
        )
        if not elims:
            return None
        return self._build_step(
            SolutionType.ALS_XY_WING, [als_a, als_b, alses[c_idx]], elims,
        )

    def _find_als_xy_wing(self, allow_overlap: bool = False) -> SolutionStep | None:
        """Return the FIRST ALS-XY-Wing step found (mirrors Java's onlyOne=true mode)."""
        alses = _collect_alses(self.grid)
        rcs, _, _ = _collect_rcs(alses, allow_overlap=allow_overlap)
        n_rcs = len(rcs)
        for i in range(n_rcs):
            for j in range(i + 1, n_rcs):
                step = self._als_xy_wing_step(alses, rcs[i], rcs[j], allow_overlap)
                if step is not None:
                    return step
        return None

    def _find_als_xy_wing_all(self, allow_overlap: bool = False) -> list[SolutionStep]:
        """Return ALL ALS-XY-Wing steps, deduplicated by elimination set."""
        alses = _collect_alses(self.grid)
        rcs, _, _ = _collect_rcs(alses, allow_overlap=allow_overlap)
        n_rcs = len(rcs)
        deletes_map: dict = {}
        for i in range(n_rcs):
            for j in range(i + 1, n_rcs):
                step = self._als_xy_wing_step(alses, rcs[i], rcs[j], allow_overlap)
                if step is not None:
                    _record_step(step, _als_index_count(step), deletes_map)
        return [step for _, step in deletes_map.values()]

    # ------------------------------------------------------------------
    # ALS-XY-Chain
    # ------------------------------------------------------------------

    def _collect_xy_chain_steps(self, allow_overlap: bool) -> dict:
        """Search chains from every start ALS; return the deduplication map.

        The map is keyed by elimination set with ``(als_cell_count, step)``
        values, as maintained by :func:`_record_step`.
        """
        alses = _collect_alses(self.grid)
        rcs, start_indices, end_indices = _collect_rcs(alses, allow_overlap=allow_overlap)
        deletes_map: dict = {}
        als_in_chain = [False] * len(alses)
        chain: list[RestrictedCommon] = []

        for i, start_als in enumerate(alses):
            als_in_chain[i] = True
            self._chain_recursive(
                i, None, True,
                alses, rcs, start_indices, end_indices,
                als_in_chain, chain,
                start_als, deletes_map,
            )
            als_in_chain[i] = False

        return deletes_map

    def _find_als_xy_chain(self, allow_overlap: bool = False) -> SolutionStep | None:
        """Return the best ALS-XY-Chain step (ordered by :func:`_als_cmp`), or None."""
        return _best_step(self._collect_xy_chain_steps(allow_overlap))

    def _find_als_xy_chain_all(self, allow_overlap: bool = False) -> list[SolutionStep]:
        """Return ALL ALS-XY-Chain steps (all entries from the deduplication map)."""
        deletes_map = self._collect_xy_chain_steps(allow_overlap)
        return [step for _, step in deletes_map.values()]

    @staticmethod
    def _active_rc_digits(rc: RestrictedCommon) -> tuple[int, int]:
        """Return the RC digits *rc* currently uses per ``actual_rc``; 0 = unused slot."""
        if rc.actual_rc == 1:
            return rc.cand1, 0
        if rc.actual_rc == 2:
            return rc.cand2, 0
        if rc.actual_rc == 3:
            return rc.cand1, rc.cand2
        return 0, 0

    def _record_chain_step(
        self,
        start_als: Als,
        alses: list[Als],
        chain: list[RestrictedCommon],
        deletes_map: dict,
    ) -> None:
        """Record the eliminations of the chain from *start_als* to its last ALS, if any."""
        end_als = alses[chain[-1].als2]
        # The active RC digits at both chain ends cannot be eliminated.
        elims = _check_candidates_to_delete(
            start_als, end_als,
            *self._active_rc_digits(chain[0]),
            *self._active_rc_digits(chain[-1]),
        )
        if not elims:
            return
        chain_alses = [start_als] + [alses[link.als2] for link in chain]
        step = self._build_step(SolutionType.ALS_XY_CHAIN, chain_alses, elims)
        _record_step(step, _als_index_count(step), deletes_map)

    def _chain_recursive(
        self,
        als_idx: int,
        last_rc: RestrictedCommon | None,
        first_try: bool,
        alses: list[Als],
        rcs: list[RestrictedCommon],
        start_indices: list[int],
        end_indices: list[int],
        als_in_chain: list[bool],
        chain: list[RestrictedCommon],
        start_als: Als,
        deletes_map: dict,
    ) -> None:
        """Extend *chain* from ALS *als_idx* with every usable forward RC.

        *chain*, *als_in_chain* and each RC's ``actual_rc`` are mutated during
        the search and restored on backtrack. Chains of three or more RCs are
        checked for eliminations. At the chain start, a doubly-linked RC is
        tried twice: entering via cand1, then via cand2.

        *first_try* is accepted for signature parity with Java. It is not read:
        each RC is always attempted with its first digit first.
        """
        if len(chain) >= _MAX_RC:
            return

        try_first = True
        i = start_indices[als_idx]
        stop = end_indices[als_idx]
        while i < stop:
            rc = rcs[i]
            # check_rc also sets rc.actual_rc for this chain position.
            if not rc.check_rc(last_rc, try_first) or als_in_chain[rc.als2]:
                i += 1
                try_first = True
                continue

            chain.append(rc)
            als_in_chain[rc.als2] = True

            if len(chain) >= 3:
                self._record_chain_step(start_als, alses, chain, deletes_map)

            self._chain_recursive(
                rc.als2, rc, True,
                alses, rcs, start_indices, end_indices,
                als_in_chain, chain,
                start_als, deletes_map,
            )

            als_in_chain[rc.als2] = False
            chain.pop()

            if last_rc is None and rc.cand2 != 0 and try_first:
                # Doubly-linked first RC: retry the same RC entering via cand2.
                try_first = False
            else:
                i += 1
                try_first = True

    # ------------------------------------------------------------------
    # Death Blossom
    # ------------------------------------------------------------------

    def _search_death_blossoms(
        self, allow_overlap: bool, find_all: bool,
    ) -> list[SolutionStep]:
        """Try every qualifying stem cell; stop after the first step unless *find_all*."""
        grid = self.grid
        alses = _collect_alses(grid)
        rcdb = self._collect_rcs_for_death_blossom(alses)
        result: list[SolutionStep] = []

        for stem in range(81):
            if grid.values[stem] != 0:
                continue
            entry = rcdb[stem]
            # Every candidate of the stem needs at least one petal ALS.
            if entry is None or entry.cand_mask != grid.candidates[stem]:
                continue
            # Highest stem digit (candidate bit d-1 encodes digit d).
            max_cand = grid.candidates[stem].bit_length()
            self._db_recursive(
                1, max_cand, stem, entry, alses, grid, _DBState(stem), result,
                allow_overlap=allow_overlap, find_all=find_all,
            )
            if result and not find_all:
                break

        return result

    def _find_death_blossom_all(self, allow_overlap: bool = False) -> list[SolutionStep]:
        """Return ALL Death Blossom steps (not deduplicated)."""
        return self._search_death_blossoms(allow_overlap, find_all=True)

    def _find_death_blossom(self, allow_overlap: bool = False) -> SolutionStep | None:
        """Return the FIRST Death Blossom step found (mirrors Java's onlyOne=true mode)."""
        result = self._search_death_blossoms(allow_overlap, find_all=False)
        return result[0] if result else None

    def _collect_rcs_for_death_blossom(
        self, alses: list[Als],
    ) -> list[_RCForDeathBlossom | None]:
        """Build a per-cell index of which ALSes cover each candidate.

        Cell ``c`` gets ALS ``i`` under digit ``d`` when ``c`` is in
        ``alses[i].buddies_per_cand[d]``. Cells covered by no ALS stay None.
        """
        rcdb: list[_RCForDeathBlossom | None] = [None] * 81

        for als_idx, als in enumerate(alses):
            digits = als.candidates
            while digits:
                digit = (digits & -digits).bit_length()
                digits &= digits - 1
                cells = als.buddies_per_cand[digit]
                while cells:
                    cell = (cells & -cells).bit_length() - 1
                    cells &= cells - 1
                    entry = rcdb[cell]
                    if entry is None:
                        entry = rcdb[cell] = _RCForDeathBlossom()
                    entry.add_als_for_candidate(als_idx, digit)

        return rcdb

    def _db_recursive(
        self,
        cand: int,
        max_cand: int,
        stem: int,
        rcdb_entry: _RCForDeathBlossom,
        alses: list[Als],
        grid: Grid,
        state: _DBState,
        result: list,
        allow_overlap: bool = False,
        find_all: bool = False,
    ) -> None:
        """Choose one petal ALS per stem digit, from *cand* up to *max_cand*.

        Petals must not contain the stem and, unless *allow_overlap*, must not
        overlap earlier petals. All petals together must keep at least one
        common candidate. Once every digit is covered, eliminations are checked.
        *state* is restored to its entry value on return. When not *find_all*,
        the search stops as soon as *result* is non-empty.
        """
        if cand > max_cand:
            return

        petal_choices = rcdb_entry.als_per_candidate[cand]
        if not petal_choices:
            # No ALS covers this digit from the stem: skip to the next digit.
            state.akt_db_als[cand] = -1
            self._db_recursive(
                cand + 1, max_cand, stem, rcdb_entry,
                alses, grid, state, result,
                allow_overlap=allow_overlap, find_all=find_all,
            )
            return

        stem_bit = 1 << stem
        for als_idx in petal_choices:
            if result and not find_all:
                return  # early exit after finding one
            als = alses[als_idx]

            if als.indices & stem_bit:
                continue  # a petal never contains the stem cell
            if not allow_overlap and (als.indices & state.db_indices & ~stem_bit):
                continue  # petal-petal overlap
            if not (state.db_candidates & als.candidates):
                continue  # no candidate left common to all petals

            # Snapshot so backtracking is exact even when petals overlap.
            saved_candidates = state.db_candidates
            saved_indices = state.db_indices

            state.akt_db_als[cand] = als_idx
            state.inc_db_cand[cand] = saved_candidates & ~als.candidates
            state.db_candidates = saved_candidates & als.candidates
            state.db_indices = saved_indices | als.indices

            if cand < max_cand:
                self._db_recursive(
                    cand + 1, max_cand, stem, rcdb_entry,
                    alses, grid, state, result,
                    allow_overlap=allow_overlap, find_all=find_all,
                )
            else:
                self._db_check_eliminations(stem, alses, grid, state, result)

            state.db_candidates = saved_candidates
            state.db_indices = saved_indices
            state.akt_db_als[cand] = -1

    def _db_check_eliminations(
        self,
        stem: int,
        alses: list[Als],
        grid: Grid,
        state: _DBState,
        result: list,
    ) -> None:
        """Append a Death Blossom step to *result* if the current petals eliminate anything.

        Every digit common to all petals that has no petal of its own is removed
        from cells that see all petal cells holding it. Petal cells and the stem
        itself are excluded.
        """
        petals = [alses[idx] for idx in state.akt_db_als[1:] if idx != -1]
        elims: list[tuple[int, int]] = []

        remaining = state.db_candidates
        while remaining:
            digit = (remaining & -remaining).bit_length()
            remaining &= remaining - 1

            if state.akt_db_als[digit] != -1:
                continue  # digit has its own petal (a stem digit): not eliminated

            digit_cells = 0
            for petal in petals:
                digit_cells |= petal.indices_per_cand[digit]
            if not digit_cells:
                continue

            targets = (_get_buddies_of_set(digit_cells)
                       & ~state.db_indices
                       & ~(1 << stem)
                       & grid.candidate_sets[digit])
            while targets:
                elims.append(((targets & -targets).bit_length() - 1, digit))
                targets &= targets - 1

        if not elims:
            return
        result.append(
            self._build_step(SolutionType.DEATH_BLOSSOM, petals, elims, stem=stem)
        )

997 

998 

999class _DBState: 

1000 """Mutable state for the Death Blossom recursive search.""" 

1001 

1002 __slots__ = ('db_indices', 'db_candidates', 'akt_db_als', 'inc_db_cand') 

1003 

1004 def __init__(self, stem: int) -> None: 

1005 self.db_indices: int = 1 << stem # start with stem cell excluded 

1006 self.db_candidates: int = 0x1ff # all 9 candidates (MAX_MASK) 

1007 self.akt_db_als: list[int] = [-1] * 10 

1008 self.inc_db_cand: list[int] = [0] * 10 

1009 

1010 

1011# --------------------------------------------------------------------------- 

1012# Pivot identification for ALS-XY-Wing 

1013# --------------------------------------------------------------------------- 

1014 

1015def _identify_pivot( 

1016 rc1: RestrictedCommon, rc2: RestrictedCommon, 

1017) -> tuple[int | None, int | None, int | None]: 

1018 """Find the pivot ALS (C) shared by rc1 and rc2; return (c_idx, a_idx, b_idx). 

1019 

1020 Returns (None, None, None) if no valid pivot exists (not exactly 3 distinct ALSes). 

1021 """ 

1022 # Four possible shared-ALS configurations (mirrors Java's if-chain exactly) 

1023 c_idx = a_idx = b_idx = None 

1024 

1025 if rc1.als1 == rc2.als1 and rc1.als2 != rc2.als2: 

1026 c_idx, a_idx, b_idx = rc1.als1, rc1.als2, rc2.als2 

1027 elif rc1.als1 == rc2.als2 and rc1.als2 != rc2.als1: 

1028 c_idx, a_idx, b_idx = rc1.als1, rc1.als2, rc2.als1 

1029 elif rc1.als2 == rc2.als1 and rc1.als1 != rc2.als2: 

1030 c_idx, a_idx, b_idx = rc1.als2, rc1.als1, rc2.als2 

1031 elif rc1.als2 == rc2.als2 and rc1.als1 != rc2.als1: 

1032 c_idx, a_idx, b_idx = rc1.als2, rc1.als1, rc2.als1 

1033 

1034 return c_idx, a_idx, b_idx