Coverage for src / hodoku / solver / templates.py: 97%
115 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 08:35 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 08:35 +0000
1"""TemplateSolver — Template Set and Template Delete.

A *template* is one valid placement of a single digit across the 81 cells:
exactly one cell per row, one per column, and one per box. There are
46 656 such templates in total.

Template Set — cells that appear in EVERY valid template for a digit must
    contain that digit.
Template Delete — cells that appear in NO valid template for a digit cannot
    contain that digit (the candidate can be eliminated).

The refinement pass (initLists=true in the Java original) iteratively removes
templates for digit j that overlap the mandatory cells of another digit k, then
recomputes the mandatory and possible-cell masks. This is the variant HoDoKu
always uses in getAllTemplates() / getStep().

Port notes:
- Templates are generated once at import time via recursive backtracking.
- CellSets are plain Python ints (81-bit bitsets).
- Each Template-Set SolutionStep pairs add_index(cell) with add_value(digit)
  for EVERY cell, so that zip(step.indices, step.values) is well-defined.
"""
24from __future__ import annotations
26from hodoku.core.grid import Grid
27from hodoku.core.solution_step import SolutionStep
28from hodoku.core.types import SolutionType
30_ALL_CELLS: int = (1 << 81) - 1
33# ---------------------------------------------------------------------------
34# Template generation — runs once at module import
35# ---------------------------------------------------------------------------
37def _generate_templates() -> list[int]:
38 """Return all 46 656 valid templates as 81-bit integers."""
39 templates: list[int] = []
41 def backtrack(row: int, used_cols: int, used_boxes: int, cellset: int) -> None:
42 if row == 9:
43 templates.append(cellset)
44 return
45 for col in range(9):
46 if used_cols >> col & 1:
47 continue
48 box = (row // 3) * 3 + (col // 3)
49 if used_boxes >> box & 1:
50 continue
51 backtrack(
52 row + 1,
53 used_cols | (1 << col),
54 used_boxes | (1 << box),
55 cellset | (1 << (row * 9 + col)),
56 )
58 backtrack(0, 0, 0, 0)
59 return templates
62_TEMPLATES: list[int] = _generate_templates()
65# ---------------------------------------------------------------------------
66# Core computation
67# ---------------------------------------------------------------------------
def _init_cand_templates(grid: Grid) -> tuple[list[int], list[int]]:
    """Compute the Template Set / Template Delete masks for digits 1..9.

    Port of HoDoKu's template analysis with the refinement pass enabled
    (``initLists=true`` in the Java original).

    Args:
        grid: puzzle state. Only ``grid.values`` (digit placed in each of
            the 81 cells, 0 for empty) and ``grid.candidate_sets``
            (per-digit candidate bitmask, index 0 unused) are read.

    Returns:
        ``(set_value, del_cand)``, two 10-element lists indexed by digit
        (index 0 unused):

        - ``set_value[d]``: cells common to every surviving template of d.
          ``set_value[d] & ~positions[d]`` are the cells where d must be
          placed.
        - ``del_cand[d]``: cells covered by NO surviving template of d.
          ``del_cand[d] & candidate_sets[d]`` are the eliminable candidates.

    Note:
        If every template of a digit is discarded (which should only
        happen for a grid without a solution), its ``set_value`` stays
        ``_ALL_CELLS``. The refinement pass then discards every template
        of every other digit as well.
    """
    values = grid.values
    candidate_sets = grid.candidate_sets  # list[int], index 0 unused

    # positions[d]: cells where digit d is already placed.
    positions = [0] * 10
    for cell in range(81):
        d = values[cell]
        if d:
            positions[d] |= 1 << cell

    set_value = [_ALL_CELLS] * 10
    del_cand = [0] * 10
    cand_lists: list[list[int]] = [[] for _ in range(10)]

    # --- Initial filtering pass ---
    # A template is valid for d if it covers every cell where d is placed
    # and touches no cell where d is neither placed nor a candidate.
    # set_value[d] = AND of the valid templates, del_cand[d] = their OR.
    # Looping digit-outer keeps the per-digit masks in locals; each digit's
    # candidate list still follows _TEMPLATES order.
    for d in range(1, 10):
        placed = positions[d]
        forbidden = ~(placed | candidate_sets[d]) & _ALL_CELLS
        common = _ALL_CELLS
        covered = 0
        valid = cand_lists[d]
        for t in _TEMPLATES:
            if (t & placed) != placed or (t & forbidden):
                continue
            common &= t
            covered |= t
            valid.append(t)
        set_value[d] = common
        del_cand[d] = covered

    # --- Refinement pass (mirrors Java initLists=true) ---
    # Drop templates of digit j that overlap a mandatory cell of any other
    # digit, and rebuild set_value[j] / del_cand[j] from the survivors.
    # Repeat until a full round removes nothing. Digits are processed in
    # order, so digit j sees this round's masks for digits < j and the
    # previous round's masks for digits > j.
    changed = True
    while changed:
        changed = False
        for j in range(1, 10):
            # Union of the other digits' mandatory cells. Those masks do not
            # change while digit j is filtered, so one OR replaces eight
            # separate overlap tests per template.
            blocked = 0
            for k in range(1, 10):
                if k != j:
                    blocked |= set_value[k]
            common = _ALL_CELLS
            covered = 0
            kept: list[int] = []
            for t in cand_lists[j]:
                if t & blocked:
                    changed = True
                    continue
                common &= t
                covered |= t
                kept.append(t)
            set_value[j] = common
            del_cand[j] = covered
            cand_lists[j] = kept

    # del_cand[d] currently holds cells covered by some surviving template;
    # its complement is the set of cells no template of d can occupy.
    for d in range(1, 10):
        del_cand[d] = ~del_cand[d] & _ALL_CELLS

    return set_value, del_cand
139# ---------------------------------------------------------------------------
140# Solver class
141# ---------------------------------------------------------------------------
class TemplateSolver:
    """Produce Template Set and Template Delete steps for a grid.

    The template analysis itself is done by ``_init_cand_templates``. This
    class turns its per-digit masks into ``SolutionStep`` objects, at most
    one step per digit.
    """

    def __init__(self, grid: Grid) -> None:
        self.grid = grid

    # --- Public interface (mirrors other solvers) ---

    def get_step(self, sol_type: SolutionType) -> SolutionStep | None:
        """Return the first step of *sol_type* found, or ``None``."""
        steps = self.find_all(sol_type)
        return steps[0] if steps else None

    def find_all(self, sol_type: SolutionType) -> list[SolutionStep]:
        """Return all steps of *sol_type*, ordered by digit 1..9.

        Only ``SolutionType.TEMPLATE_SET`` and ``SolutionType.TEMPLATE_DEL``
        are handled; any other type yields an empty list.
        """
        # The template analysis walks every template for every digit, so
        # dispatch on the type first and skip it for unsupported types.
        if sol_type is SolutionType.TEMPLATE_SET:
            set_value, _del_cand = _init_cand_templates(self.grid)
            return self._find_template_set(set_value, self._placed_positions())
        if sol_type is SolutionType.TEMPLATE_DEL:
            _set_value, del_cand = _init_cand_templates(self.grid)
            return self._find_template_del(del_cand)
        return []

    # --- Internal helpers ---

    @staticmethod
    def _cells_of(mask: int) -> list[int]:
        """Return the indices of the set bits of *mask*, lowest first."""
        cells: list[int] = []
        while mask:
            lsb = mask & -mask  # isolate the lowest set bit
            cells.append(lsb.bit_length() - 1)
            mask ^= lsb
        return cells

    def _placed_positions(self) -> list[int]:
        """Return per-digit bitmasks (index 0 unused) of cells holding that digit."""
        positions = [0] * 10
        for cell in range(81):
            d = self.grid.values[cell]
            if d:
                positions[d] |= 1 << cell
        return positions

    def _find_template_set(
        self, set_value: list[int], positions: list[int]
    ) -> list[SolutionStep]:
        """Build one TEMPLATE_SET step per digit with cells it must occupy.

        A cell qualifies when it is in ``set_value[d]`` but d is not already
        placed there. Every cell is added together with the digit, so
        ``zip(step.indices, step.values)`` pairs each cell with its value.
        """
        steps: list[SolutionStep] = []
        for d in range(1, 10):
            cells_mask = set_value[d] & ~positions[d] & _ALL_CELLS
            if not cells_mask:
                continue
            step = SolutionStep(SolutionType.TEMPLATE_SET)
            for cell in self._cells_of(cells_mask):
                step.add_index(cell)
                step.add_value(d)
            steps.append(step)
        return steps

    def _find_template_del(self, del_cand: list[int]) -> list[SolutionStep]:
        """Build one TEMPLATE_DEL step per digit with eliminable candidates.

        As produced by ``_init_cand_templates``, ``del_cand[d]`` marks the
        cells that no surviving template of d covers. Any current candidate
        d in those cells is deleted.
        """
        steps: list[SolutionStep] = []
        candidate_sets = self.grid.candidate_sets
        for d in range(1, 10):
            elim_mask = del_cand[d] & candidate_sets[d]
            if not elim_mask:
                continue
            step = SolutionStep(SolutionType.TEMPLATE_DEL)
            step.add_value(d)
            for cell in self._cells_of(elim_mask):
                step.add_candidate_to_delete(cell, d)
            steps.append(step)
        return steps