Coverage for src / hodoku / solver / chains.py: 98%

302 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 08:35 +0000

"""Chain-based techniques: Turbot Fish, X-Chain, XY-Chain and Remote Pair.

Only the parts of HoDoKu's Java ChainSolver that step_finder.py actually
dispatches to are ported here. Nice Loop, AIC and Grouped Nice Loop live in
tabling.py (TablingSolver) instead.
"""

7 

8from __future__ import annotations 

9 

10from typing import TYPE_CHECKING 

11 

12from hodoku.core.grid import ( 

13 ALL_UNITS, BUDDIES, CELL_CONSTRAINTS, 

14 CONSTRAINTS, Grid, 

15) 

16from hodoku.core.solution_step import SolutionStep 

17from hodoku.core.types import SolutionType 

18 

19if TYPE_CHECKING: 

20 from hodoku.config import StepSearchConfig 

21 

# Default upper bound on chain length, counted in nodes (not links).
_MAX_CHAIN: int = 20

23 

24 

# ---------------------------------------------------------------------------
# Link building
# ---------------------------------------------------------------------------

28 

def _build_x_links(grid: Grid, digit: int) -> list[list[tuple[int, bool]]]:
    """Build the single-digit link graph used by the X-Chain search.

    The result is indexed by cell (0..80). Entry ``cell`` lists
    ``(neighbor, is_strong)`` pairs for every other cell that shares a unit
    with ``cell`` and still holds ``digit`` as a candidate. Cells without the
    candidate get an empty list. ``is_strong`` is True when the unit that
    produced the link has exactly two candidates for ``digit``.

    Units are visited row, column, then block. A block neighbor that also
    shares the row or column of ``cell`` is skipped because the row/column
    pass has already emitted it (mirrors Java's constr==2 deduplication in
    getAllLinks). Emission order is significant: it fixes the DFS order and
    therefore which of several equally short chains is kept later.
    """
    digit_cells = grid.candidate_sets[digit]
    result: list[list[tuple[int, bool]]] = []

    for cell in range(81):
        cell_links: list[tuple[int, bool]] = []
        result.append(cell_links)
        if not (digit_cells >> cell) & 1:
            continue

        own_row, own_col, _ = CONSTRAINTS[cell]
        row_unit, col_unit, box_unit = CELL_CONSTRAINTS[cell]

        for unit in (row_unit, col_unit, box_unit):
            strong = grid.free[unit][digit] == 2
            in_block = unit >= 18  # unit indices 18.. are blocks
            for other in ALL_UNITS[unit]:
                if other == cell or not (digit_cells >> other) & 1:
                    continue
                if in_block:
                    other_row, other_col, _ = CONSTRAINTS[other]
                    if other_row == own_row or other_col == own_col:
                        continue  # already emitted by the row/col pass
                cell_links.append((other, strong))

    return result

61 

62 

# ---------------------------------------------------------------------------
# Sort key matching HoDoKu's SolutionStep.compareTo
# ---------------------------------------------------------------------------

66 

def _elim_sort_key(step: SolutionStep) -> int:
    """Return HoDoKu's weighted index sum used to break elimination-count ties.

    Java's getCandidateString() runs Collections.sort over candidatesToDelete,
    ordering by (value, index), before compareTo/getIndexSumme ever looks at
    them. The weighted sum is therefore taken over that sorted order, not
    insertion order.

    The k-th candidate (0-based) contributes ``index * (1 + 80 * k) + value``.
    This is equivalent to an offset starting at 1 that grows by 80 per
    candidate.
    """
    ordered = sorted(
        step.candidates_to_delete, key=lambda cand: (cand.value, cand.index)
    )
    return sum(
        cand.index * (1 + 80 * position) + cand.value
        for position, cand in enumerate(ordered)
    )

82 

83 

def _step_sort_key(step: SolutionStep) -> tuple:
    """Sort key: more eliminations first, then smaller weighted index sum."""
    elim_count = len(step.candidates_to_delete)
    return -elim_count, _elim_sort_key(step)

86 

87 

# ---------------------------------------------------------------------------
# Solver
# ---------------------------------------------------------------------------

91 

class ChainSolver:
    """X-Chain, XY-Chain, Remote Pair and Turbot Fish finder.

    Every search collects candidate steps into a ``deletes_map`` keyed by the
    sorted ``(index, value)`` pairs of their eliminations. Only the shortest
    chain per elimination set is kept; on equal length the first one found
    wins. The surviving steps are ordered best-first by ``_step_sort_key``.
    """

    def __init__(self, grid: Grid, search_config: StepSearchConfig | None = None) -> None:
        """Bind the solver to ``grid``.

        Args:
            grid: Puzzle state to search (this class only reads it).
            search_config: When given, ``chain_max_length`` caps the number of
                chain nodes for X-Chain, XY-Chain and Remote Pair. Otherwise
                ``_MAX_CHAIN`` is used. Turbot Fish is always capped at 4 nodes.
        """
        self.grid = grid
        if search_config is not None:
            self._max_chain = search_config.chain_max_length
        else:
            self._max_chain = _MAX_CHAIN

    def get_step(self, sol_type: SolutionType) -> SolutionStep | None:
        """Return the best-ranked step of ``sol_type``.

        Returns None when no such step exists, or when ``sol_type`` is not
        TURBOT_FISH, X_CHAIN, XY_CHAIN or REMOTE_PAIR.
        """
        if sol_type == SolutionType.TURBOT_FISH:
            return self._find_turbot_fish()
        if sol_type == SolutionType.X_CHAIN:
            return self._find_x_chain()
        if sol_type == SolutionType.XY_CHAIN:
            return self._find_xy_chain()
        if sol_type == SolutionType.REMOTE_PAIR:
            return self._find_remote_pair()
        return None

    def find_all(self, sol_type: SolutionType) -> list[SolutionStep]:
        """Return every distinct step of ``sol_type``, best-ranked first.

        Unsupported types yield an empty list.
        """
        if sol_type == SolutionType.TURBOT_FISH:
            return self._find_x_chain_impl_all(SolutionType.TURBOT_FISH, max_nodes=4)
        if sol_type == SolutionType.X_CHAIN:
            return self._find_x_chain_impl_all(SolutionType.X_CHAIN, max_nodes=self._max_chain)
        if sol_type == SolutionType.XY_CHAIN:
            return self._find_xy_type_all(SolutionType.XY_CHAIN)
        if sol_type == SolutionType.REMOTE_PAIR:
            return self._find_xy_type_all(SolutionType.REMOTE_PAIR)
        return []

    # ------------------------------------------------------------------
    # Result bookkeeping
    # ------------------------------------------------------------------

    @staticmethod
    def _store_step(step: SolutionStep, chain_len: int, deletes_map: dict) -> None:
        """Store ``step`` unless an equally short or shorter chain already
        produced the same elimination set."""
        key = tuple(sorted((c.index, c.value) for c in step.candidates_to_delete))
        old = deletes_map.get(key)
        if old is None or old[0] > chain_len:
            deletes_map[key] = (chain_len, step)

    @staticmethod
    def _ranked_steps(deletes_map: dict) -> list[SolutionStep]:
        """Return the stored steps ordered best-first by ``_step_sort_key``."""
        steps = [step for _, step in deletes_map.values()]
        steps.sort(key=_step_sort_key)
        return steps

    # ------------------------------------------------------------------
    # Turbot Fish (X-Chain restricted to 3 links / 4 nodes)
    # ------------------------------------------------------------------

    def _find_turbot_fish(self) -> SolutionStep | None:
        """Find the best Turbot Fish (X-Chain with at most 3 links)."""
        return self._find_x_chain_impl(SolutionType.TURBOT_FISH, max_nodes=4)

    # ------------------------------------------------------------------
    # X-Chain
    # ------------------------------------------------------------------

    def _find_x_chain(self) -> SolutionStep | None:
        """Find the best X-Chain elimination.

        Collects all valid chains and deduplicates them by elimination set,
        keeping the shortest per set. Returns the one ranked first by
        HoDoKu's comparator: most eliminations, then lowest weighted index sum.
        """
        return self._find_x_chain_impl(SolutionType.X_CHAIN, max_nodes=self._max_chain)

    def _find_x_chain_impl(self, sol_type: SolutionType, max_nodes: int) -> SolutionStep | None:
        """Return the best-ranked X-Chain-style step, or None if there is none."""
        steps = self._find_x_chain_impl_all(sol_type, max_nodes)
        return steps[0] if steps else None

    def _find_x_chain_impl_all(self, sol_type: SolutionType, max_nodes: int) -> list[SolutionStep]:
        """Collect every distinct X-Chain-style step, best-ranked first.

        This is the single search behind both X-Chain
        (``max_nodes=self._max_chain``) and Turbot Fish (``max_nodes=4``).
        For each digit, every candidate cell with at least one same-digit
        buddy is tried as a start, and each strong link out of it seeds a DFS.
        """
        grid = self.grid
        # elim_key -> (chain_length, step): shortest chain per elimination set
        deletes_map: dict[tuple, tuple[int, SolutionStep]] = {}

        for digit in range(1, 10):
            cand_set = grid.candidate_sets[digit]
            if not cand_set:
                continue
            links = _build_x_links(grid, digit)

            tmp = cand_set
            while tmp:
                lsb = tmp & -tmp
                start = lsb.bit_length() - 1
                tmp ^= lsb

                # Every elimination must see the start cell, so a start with
                # no same-digit buddies can never yield a step.
                start_buddies = cand_set & BUDDIES[start]
                if not start_buddies:
                    continue

                for nb, is_strong in links[start]:
                    if not is_strong:
                        continue  # X-Chain must start with a strong link

                    chain: list[int] = [start, nb]
                    chain_set: set[int] = {nb}  # tracks chain[1:] for lasso check

                    self._dfs_x(
                        digit, links, chain, chain_set,
                        strong_only=False,  # just placed a strong link; next: weak
                        start=start,
                        start_buddies=start_buddies,
                        deletes_map=deletes_map,
                        sol_type=sol_type,
                        max_nodes=max_nodes,
                    )

        return self._ranked_steps(deletes_map)

    def _dfs_x(
        self,
        digit: int,
        links: list[list[tuple[int, bool]]],
        chain: list[int],
        chain_set: set[int],
        strong_only: bool,
        start: int,
        start_buddies: int,
        deletes_map: dict,
        sol_type: SolutionType = SolutionType.X_CHAIN,
        max_nodes: int = _MAX_CHAIN,
    ) -> None:
        """Extend ``chain`` depth-first, alternating weak and strong links.

        ``strong_only`` True means the next link must be strong. When it is
        False, any link is accepted but treated as weak. Each time the chain
        ends on a strong link with more than 2 nodes, the cells holding
        ``digit`` that see both the start and the end are recorded as
        eliminations. Reaching ``start`` again is a loop: it may still record
        a step but is not extended. Revisiting any other chain cell (a lasso)
        is rejected. ``chain`` and ``chain_set`` are restored before returning.
        """
        if len(chain) >= max_nodes:
            return

        current = chain[-1]

        for nb, link_is_strong in links[current]:
            if strong_only and not link_is_strong:
                continue  # must use strong link here

            is_loop = False
            if nb == start:
                # Loop back to start — still valid as an X-Chain ending
                # (Java: isLoop = true, falls through to chain check).
                is_loop = True
            elif nb in chain_set:
                continue  # lasso: revisiting a middle cell

            # When !strong_only: a strong link is treated as weak (no chain check).
            effective_strong = link_is_strong and strong_only

            chain.append(nb)
            if not is_loop:
                chain_set.add(nb)

            # Valid chain end: ended on a strong link, at least 3 nodes total.
            if effective_strong and len(chain) > 2:
                elim = start_buddies & BUDDIES[nb]
                if elim:
                    self._record(digit, sol_type, chain, elim, deletes_map)

            # Don't recurse further if this was a loop (Java: isLoop → stop)
            if not is_loop:
                self._dfs_x(
                    digit, links, chain, chain_set,
                    strong_only=not strong_only,
                    start=start,
                    start_buddies=start_buddies,
                    deletes_map=deletes_map,
                    sol_type=sol_type,
                    max_nodes=max_nodes,
                )

            chain.pop()
            if not is_loop:
                chain_set.discard(nb)

    # ------------------------------------------------------------------
    # XY-Chain
    # ------------------------------------------------------------------

    def _find_xy_chain(self) -> SolutionStep | None:
        """Find the best XY-Chain elimination, or None."""
        return self._find_xy_type(SolutionType.XY_CHAIN)

    # ------------------------------------------------------------------
    # Remote Pair
    # ------------------------------------------------------------------

    def _find_remote_pair(self) -> SolutionStep | None:
        """Find the best Remote Pair elimination, or None."""
        return self._find_xy_type(SolutionType.REMOTE_PAIR)

    # ------------------------------------------------------------------
    # Shared XY/RP search
    # ------------------------------------------------------------------

    def _find_xy_type(self, sol_type: SolutionType) -> SolutionStep | None:
        """Return the best-ranked XY-Chain or Remote Pair step, or None."""
        steps = self._find_xy_type_all(sol_type)
        return steps[0] if steps else None

    def _find_xy_type_all(self, sol_type: SolutionType) -> list[SolutionStep]:
        """Collect every distinct XY-Chain or Remote Pair step, best-ranked first.

        XY-Chain: every cell is bivalue. The strong link is within a cell and
        the weak link is between cells. Remote Pair: like XY-Chain, but every
        cell must carry the start cell's exact candidate pair, with a minimum
        of 4 cells (8 chain nodes).
        """
        grid = self.grid
        is_rp = sol_type == SolutionType.REMOTE_PAIR
        deletes_map: dict[tuple, tuple[int, SolutionStep]] = {}

        for start in range(81):
            if grid.values[start] != 0:
                continue
            # Both XY-Chain and RP require bivalue start cell
            cands = [d for d in range(1, 10) if grid.candidate_sets[d] >> start & 1]
            if len(cands) != 2:
                continue
            start_mask = grid.candidates[start]  # candidate mask for RP matching
            d1, d2 = cands

            for start_cand, other_cand in ((d1, d2), (d2, d1)):
                start_buddies = grid.candidate_sets[start_cand] & BUDDIES[start]
                if not start_buddies:
                    continue
                # Second-digit buddies of the start cell; only computed for RP.
                start_buddies2 = (
                    grid.candidate_sets[other_cand] & BUDDIES[start] if is_rp else 0
                )

                # chain stores (cell, candidate) pairs
                chain: list[tuple[int, int]] = [(start, start_cand), (start, other_cand)]
                visited: set[int] = {start}

                self._dfs_xy(
                    grid, chain, visited,
                    strong_only=False,  # just used within-cell strong; next: inter-cell
                    start=start,
                    start_cand=start_cand,
                    start_buddies=start_buddies,
                    start_buddies2=start_buddies2,
                    start_mask=start_mask,
                    sol_type=sol_type,
                    deletes_map=deletes_map,
                )

        return self._ranked_steps(deletes_map)

    def _dfs_xy(
        self,
        grid: Grid,
        chain: list[tuple[int, int]],
        visited: set[int],
        strong_only: bool,
        start: int,
        start_cand: int,
        start_buddies: int,
        start_buddies2: int,
        start_mask: int,
        sol_type: SolutionType,
        deletes_map: dict,
    ) -> None:
        """Extend an XY-Chain / Remote Pair depth-first.

        ``chain`` holds ``(cell, candidate)`` nodes, and each cell contributes
        two consecutive nodes (entry candidate, exit candidate).

        With ``strong_only`` True, the next link is the within-cell strong
        link to the current cell's other candidate. After that link, a chain
        ending on ``start_cand`` is checked for eliminations.

        Otherwise, the next link is a weak link to another bivalue cell that
        shares a unit and the current candidate. For Remote Pair, that cell's
        mask must equal ``start_mask``. The start cell and already-visited
        cells are excluded.

        ``chain`` and ``visited`` are restored before returning. The chain
        length is capped by ``self._max_chain`` nodes.
        """
        if len(chain) >= self._max_chain:
            return

        current_cell, current_cand = chain[-1]
        is_rp = sol_type == SolutionType.REMOTE_PAIR

        if strong_only:
            # Within-cell strong link: move to the other candidate in current_cell.
            # (bivalue cell guaranteed — current_cell is always bivalue for XY/RP)
            other_cand = next(
                d for d in range(1, 10)
                if d != current_cand and grid.candidate_sets[d] >> current_cell & 1
            )

            chain.append((current_cell, other_cand))

            # Check: valid chain end when other_cand == start_cand and len > 2
            # (mirrors: stackLevel > 1 && newLinkIsStrong && newLinkCandidate == startCandidate)
            if other_cand == start_cand and len(chain) > 2:
                if is_rp and len(chain) >= 8:
                    # Java only enters checkRemotePairs when startCandidate
                    # eliminations exist (m1/m2 != 0). Mirror that guard so
                    # the DFS exploration order—and thus which startCandidate
                    # wins the dedup race—matches HoDoKu's.
                    elim_rp = start_buddies & BUDDIES[current_cell]
                    if elim_rp:
                        self._check_rp(chain, start_cand, start_buddies, start_buddies2, deletes_map)
                elif not is_rp:
                    elim = start_buddies & BUDDIES[current_cell]
                    if elim:
                        self._record_xy(start_cand, chain, elim, deletes_map)

            self._dfs_xy(
                grid, chain, visited, strong_only=False,
                start=start, start_cand=start_cand,
                start_buddies=start_buddies, start_buddies2=start_buddies2,
                start_mask=start_mask, sol_type=sol_type, deletes_map=deletes_map,
            )
            chain.pop()

        else:
            # Inter-cell weak link: move to another bivalue cell with current_cand.
            r_c, c_c, _ = CONSTRAINTS[current_cell]
            row_u, col_u, blk_u = CELL_CONSTRAINTS[current_cell]
            cand_set = grid.candidate_sets[current_cand]

            for unit_idx in (row_u, col_u, blk_u):
                is_block = unit_idx >= 18
                for nb in ALL_UNITS[unit_idx]:
                    if nb == current_cell:
                        continue
                    if not (cand_set >> nb & 1):
                        continue
                    # Must be bivalue
                    if grid.candidates[nb].bit_count() != 2:
                        continue
                    # Remote Pair: must have same two candidates
                    if is_rp and grid.candidates[nb] != start_mask:
                        continue
                    # Block dedup: skip if also connected via row/col
                    if is_block:
                        r_nb, c_nb, _ = CONSTRAINTS[nb]
                        if r_nb == r_c or c_nb == c_c:
                            continue
                    if nb == start:
                        continue  # nice loop
                    if nb in visited:
                        continue  # lasso

                    chain.append((nb, current_cand))
                    visited.add(nb)

                    self._dfs_xy(
                        grid, chain, visited, strong_only=True,
                        start=start, start_cand=start_cand,
                        start_buddies=start_buddies, start_buddies2=start_buddies2,
                        start_mask=start_mask, sol_type=sol_type, deletes_map=deletes_map,
                    )

                    chain.pop()
                    visited.discard(nb)

    def _check_rp(
        self,
        chain: list[tuple[int, int]],
        start_cand: int,
        start_buddies: int,
        start_buddies2: int,
        deletes_map: dict,
    ) -> None:
        """Record a Remote Pair step that eliminates both digits from shared buddies.

        Any two chain cells with opposite polarity (an odd number of cells
        apart, at least 3) force both digits of the pair out of every cell
        that sees them both.

        ``start_buddies`` and ``start_buddies2`` are accepted for call-site
        parity but are not read here. Nothing is stored when no elimination
        results.
        """
        # For a 4-cell chain (len=8): check start cell vs end cell only.
        # For longer chains: also check all pairs of cells with opposite polarity.
        # Positions in chain: 0,1 = cell0; 2,3 = cell1; 4,5 = cell2; ...
        # Even node-pair indices (0,2,4,...) are the "entry" candidates.
        # Pairs with opposite polarity: i and j where (j-i) ≡ 2 (mod 4) and j-i >= 6.
        step = SolutionStep(SolutionType.REMOTE_PAIR)
        other_cand = chain[1][1]  # the other start candidate

        step.add_value(start_cand)
        step.add_value(other_cand)

        n = len(chain)
        elim1_mask = 0
        elim2_mask = 0
        cand_set1 = self.grid.candidate_sets[start_cand]
        cand_set2 = self.grid.candidate_sets[other_cand]
        for i in range(0, n, 2):
            cell_i = chain[i][0]
            for j in range(i + 6, n, 4):
                cell_j = chain[j][0]
                shared = BUDDIES[cell_i] & BUDDIES[cell_j]
                elim1_mask |= shared & cand_set1
                elim2_mask |= shared & cand_set2

        if not elim1_mask and not elim2_mask:
            return

        tmp = elim1_mask
        while tmp:
            lsb = tmp & -tmp
            step.add_candidate_to_delete(lsb.bit_length() - 1, start_cand)
            tmp ^= lsb
        tmp = elim2_mask
        while tmp:
            lsb = tmp & -tmp
            step.add_candidate_to_delete(lsb.bit_length() - 1, other_cand)
            tmp ^= lsb

        self._store_step(step, len(chain), deletes_map)

    def _record_xy(
        self,
        digit: int,
        chain: list[tuple[int, int]],
        elim_mask: int,
        deletes_map: dict,
    ) -> None:
        """Build an XY-Chain step and store it if it is the shortest for its elimination set."""
        self._record(digit, SolutionType.XY_CHAIN, chain, elim_mask, deletes_map)

    # ------------------------------------------------------------------
    # Shared helper
    # ------------------------------------------------------------------

    def _record(
        self,
        digit: int,
        sol_type: SolutionType,
        chain: list,
        elim_mask: int,
        deletes_map: dict,
    ) -> None:
        """Build a single-digit step and store it if it is the shortest for its elimination set.

        ``elim_mask`` is a cell bitmask; each set bit becomes a deletion of
        ``digit`` in that cell, added in ascending cell order.
        """
        step = SolutionStep(sol_type)
        step.add_value(digit)
        tmp = elim_mask
        while tmp:
            lsb = tmp & -tmp
            step.add_candidate_to_delete(lsb.bit_length() - 1, digit)
            tmp ^= lsb
        self._store_step(step, len(chain), deletes_map)