Coverage for src / hodoku / solver / coloring.py: 92%
300 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 08:35 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 08:35 +0000
"""Coloring solver: Simple Colors (Trap/Wrap), Multi-Colors.

Mirrors Java's ColoringSolver.
"""
from __future__ import annotations

from collections.abc import Iterable, Iterator
from itertools import combinations

from hodoku.core.grid import ALL_UNITS, BUDDIES, CONSTRAINTS, Grid
from hodoku.core.solution_step import SolutionStep
from hodoku.core.types import SolutionType
class ColoringSolver:
    """Simple Colors (Trap/Wrap) and Multi-Colors (types 1 and 2).

    For a candidate digit, the cells holding it are linked through conjugate
    pairs (rows, columns or boxes in which the digit has exactly two places).
    Each connected component is two-colored so that the colors alternate
    along every link, and all eliminations are derived from these colors.

    Cell sets are handled both as ``frozenset[int]`` and as 81-bit integer
    masks (bit ``i`` set <=> cell ``i`` present), like ``grid.candidate_sets``.
    """

    def __init__(self, grid: Grid) -> None:
        self.grid = grid

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_step(self, sol_type: SolutionType) -> SolutionStep | None:
        """Return the first coloring step of sol_type's family, or None.

        Either Simple Colors type yields the first Simple Colors step found,
        wrap or trap alike; likewise either Multi-Colors type yields the
        first Multi-Colors step of type 2 or 1 (see ``_find_simple_colors``
        and ``_find_multi_colors``).  Unsupported types return None.
        """
        if sol_type in (SolutionType.SIMPLE_COLORS_TRAP, SolutionType.SIMPLE_COLORS_WRAP):
            return self._find_simple_colors()
        if sol_type in (SolutionType.MULTI_COLORS_1, SolutionType.MULTI_COLORS_2):
            return self._find_multi_colors()
        return None

    def find_all(self, sol_type: SolutionType) -> list[SolutionStep]:
        """Return all steps of exactly sol_type, deduplicated by eliminations.

        Unsupported types return an empty list.
        """
        if sol_type == SolutionType.SIMPLE_COLORS_TRAP:
            return self._find_simple_colors_all(wrap=False)
        if sol_type == SolutionType.SIMPLE_COLORS_WRAP:
            return self._find_simple_colors_all(trap=False)
        if sol_type == SolutionType.MULTI_COLORS_1:
            return self._find_multi_colors_all(mc2=False)
        if sol_type == SolutionType.MULTI_COLORS_2:
            return self._find_multi_colors_all(mc1=False)
        return []

    def _find_simple_colors_all(
        self, trap: bool = True, wrap: bool = True
    ) -> list[SolutionStep]:
        """Return every Simple Colors step, deduplicated by eliminations.

        Args:
            trap: include Simple Colors Trap steps.
            wrap: include Simple Colors Wrap steps.
        """
        return self._unique_by_elims(self._iter_simple_colors(trap=trap, wrap=wrap))

    def _find_multi_colors_all(
        self, mc1: bool = True, mc2: bool = True
    ) -> list[SolutionStep]:
        """Return every Multi-Colors step, deduplicated by eliminations.

        Args:
            mc1: include Multi-Colors type 1 steps.
            mc2: include Multi-Colors type 2 steps.
        """
        return self._unique_by_elims(self._iter_multi_colors(mc1=mc1, mc2=mc2))

    @staticmethod
    def _unique_by_elims(steps: Iterable[SolutionStep]) -> list[SolutionStep]:
        """Keep the first step for each distinct set of eliminations, in order."""
        seen: set[tuple[tuple[int, int], ...]] = set()
        unique: list[SolutionStep] = []
        for step in steps:
            key = tuple(sorted((c.index, c.value) for c in step.candidates_to_delete))
            if key not in seen:
                seen.add(key)
                unique.append(step)
        return unique

    # ------------------------------------------------------------------
    # Coloring graph builder
    # ------------------------------------------------------------------

    def _do_coloring(self, cand: int) -> list[tuple[frozenset[int], frozenset[int]]]:
        """Partition the cells holding cand into color pairs via conjugate links.

        Returns a list of (C1, C2) frozensets.  Each pair is one connected
        component of the conjugate-pair graph; C1 and C2 alternate along links.
        Single-cell components (no conjugate link) are discarded.
        """
        free = self.grid.free
        # Cells that carry cand and belong to at least one conjugate pair:
        # their row (0-8), column (9-17) or box (18-26) has exactly two
        # places for cand.
        start: set[int] = set()
        for cell in self._bits(self.grid.candidate_sets[cand]):
            r, c, b = CONSTRAINTS[cell]
            if (
                free[r][cand] == 2
                or free[9 + c][cand] == 2
                or free[18 + b][cand] == 2
            ):
                start.add(cell)

        color_pairs: list[tuple[frozenset[int], frozenset[int]]] = []
        while start:
            c1: set[int] = set()
            c2: set[int] = set()
            # Match Java: always seed from the lowest-index uncolored cell.
            # _color_dfs removes every cell it colors from `start`, so the
            # loop terminates.
            self._color_dfs(min(start), cand, True, start, c1, c2)
            if c1 and c2:
                color_pairs.append((frozenset(c1), frozenset(c2)))
        return color_pairs

    def _color_dfs(
        self,
        cell: int,
        cand: int,
        on: bool,
        remaining: set[int],
        c1: set[int],
        c2: set[int],
    ) -> None:
        """Depth-first two-coloring starting at cell.

        Puts cell into c1 (on=True) or c2 (on=False), removes it from
        remaining, then recurses into its conjugate partners with the
        opposite color.  Cells no longer in remaining are skipped, so each
        cell keeps the color of its first visit.  Each recursion level
        colors a new cell, so depth is bounded by len(remaining).
        """
        if cell not in remaining:
            return
        remaining.discard(cell)
        (c1 if on else c2).add(cell)
        r, c, b = CONSTRAINTS[cell]
        for unit_idx in (r, 9 + c, 18 + b):
            partner = self._conjugate(cell, cand, unit_idx)
            if partner != -1:
                self._color_dfs(partner, cand, not on, remaining, c1, c2)

    def _conjugate(self, cell: int, cand: int, unit_idx: int) -> int:
        """Return the other cell of cand's conjugate pair in unit_idx, or -1.

        Returns -1 when cand does not have exactly two places in the unit.
        """
        grid = self.grid
        if grid.free[unit_idx][cand] != 2:
            return -1
        pair = grid.candidate_sets[cand] & self._unit_mask(unit_idx)
        other = pair & ~(1 << cell)
        if not other:
            return -1
        return other.bit_length() - 1

    def _unit_mask(self, unit_idx: int) -> int:
        """Return the 81-bit mask of all cells in unit unit_idx (0-26)."""
        return self._mask_of(ALL_UNITS[unit_idx])

    # ------------------------------------------------------------------
    # Simple Colors
    # ------------------------------------------------------------------

    def _find_simple_colors(self) -> SolutionStep | None:
        """Return the first Simple Colors step found, or None.

        Mirrors Java's findSimpleColorSteps(onlyOne=true): for each candidate
        1-9 and each coloring pair, wrap is checked before trap, and the very
        first step found (wrap or trap) is returned.
        """
        return next(self._iter_simple_colors(), None)

    def _iter_simple_colors(
        self, trap: bool = True, wrap: bool = True
    ) -> Iterator[SolutionStep]:
        """Lazily yield Simple Colors steps in search order.

        Candidates 1-9 are scanned in order; for each coloring pair the
        wrap step (if requested and found) is yielded before the trap step.
        """
        for cand in range(1, 10):
            for c1, c2 in self._do_coloring(cand):
                if wrap:
                    step = self._check_wrap(cand, c1, c2)
                    if step:
                        yield step
                if trap:
                    step = self._check_trap(cand, c1, c2)
                    if step:
                        yield step

    def _check_wrap(
        self, cand: int, c1: frozenset[int], c2: frozenset[int]
    ) -> SolutionStep | None:
        """Color Wrap: a color containing two cells that see each other is false.

        Every cell of such a color loses cand.  Returns None when neither
        color contains two mutually visible cells.
        """
        cand_mask = self.grid.candidate_sets[cand]
        elim: set[int] = set()
        for color in (c1, c2):
            if any(BUDDIES[x] >> y & 1 for x, y in combinations(color, 2)):
                # This color is wrong: eliminate cand from all of its cells.
                elim.update(cell for cell in color if cand_mask >> cell & 1)
        if not elim:
            return None
        return self._make_simple_step(
            SolutionType.SIMPLE_COLORS_WRAP, cand, c1, c2, elim
        )

    def _check_trap(
        self, cand: int, c1: frozenset[int], c2: frozenset[int]
    ) -> SolutionStep | None:
        """Color Trap: a cell seeing both colors loses cand.

        One of the two colors must be true, so any uncolored cell holding
        cand that sees a cell of each color can be eliminated.  Returns None
        when there is no such cell.
        """
        elim = self._trap_elim(cand, c1, c2)
        if not elim:
            return None
        return self._make_simple_step(
            SolutionType.SIMPLE_COLORS_TRAP, cand, c1, c2, elim
        )

    def _make_simple_step(
        self,
        sol_type: SolutionType,
        cand: int,
        c1: frozenset[int],
        c2: frozenset[int],
        elim: set[int],
    ) -> SolutionStep:
        """Build a Simple Colors step (colors 0/1, deletions in cell order)."""
        step = SolutionStep(sol_type)
        step.add_value(cand)
        self._add_color_candidates(step, c1, cand, 0)
        self._add_color_candidates(step, c2, cand, 1)
        for cell in sorted(elim):
            step.add_candidate_to_delete(cell, cand)
        return step

    # ------------------------------------------------------------------
    # Multi-Colors
    # ------------------------------------------------------------------

    def _find_multi_colors(self) -> SolutionStep | None:
        """Return the first Multi-Colors step, or None.

        Mirrors Java: for each ordered (i, j) pair of color clusters, MC2 is
        tried first (accumulating eliminations from both a1 and a2), then
        MC1 (accumulating from all 4 color combinations).  The first
        non-empty result is returned.
        """
        return next(self._iter_multi_colors(), None)

    def _iter_multi_colors(
        self, mc1: bool = True, mc2: bool = True
    ) -> Iterator[SolutionStep]:
        """Lazily yield Multi-Colors steps in search order.

        For each candidate 1-9 and each ordered pair of distinct clusters
        (A, B), the MC2 step (if requested and non-empty) is yielded before
        the MC1 step.
        """
        for cand in range(1, 10):
            pairs = self._do_coloring(cand)
            for i, (a1, a2) in enumerate(pairs):
                for j, (b1, b2) in enumerate(pairs):
                    if i == j:
                        continue
                    if mc2:
                        elim = self._mc2_elim(cand, a1, a2, b1, b2)
                        if elim:
                            yield self._make_mc_step(
                                SolutionType.MULTI_COLORS_2,
                                cand, a1, a2, b1, b2, elim,
                            )
                    if mc1:
                        elim = self._mc1_elim(cand, a1, a2, b1, b2)
                        if elim:
                            yield self._make_mc_step(
                                SolutionType.MULTI_COLORS_1,
                                cand, a1, a2, b1, b2, elim,
                            )

    def _mc2_elim(
        self,
        cand: int,
        a1: frozenset[int],
        a2: frozenset[int],
        b1: frozenset[int],
        b2: frozenset[int],
    ) -> set[int]:
        """Multi-Colors type 2 eliminations of cluster A against cluster B.

        A color of A whose cells collectively see both colors of B must be
        false, so cand is removed from every cell of that color.
        """
        cand_mask = self.grid.candidate_sets[cand]
        elim: set[int] = set()
        for color in (a1, a2):
            if self._set_sees_both(color, b1, b2):
                elim.update(cell for cell in color if cand_mask >> cell & 1)
        return elim

    def _mc1_elim(
        self,
        cand: int,
        a1: frozenset[int],
        a2: frozenset[int],
        b1: frozenset[int],
        b2: frozenset[int],
    ) -> set[int]:
        """Multi-Colors type 1 eliminations for clusters A and B.

        If color x of A sees color y of B, x and y cannot both be true, so
        the opposite color of A or the opposite color of B is true; every
        cell seeing both of those opposite colors loses cand.  Results from
        all four (x, y) combinations are accumulated.
        """
        elim: set[int] = set()
        for x, x_other, y, y_other in (
            (a1, a2, b1, b2),
            (a1, a2, b2, b1),
            (a2, a1, b1, b2),
            (a2, a1, b2, b1),
        ):
            if self._sets_intersect_buddies(x, y):
                elim |= self._trap_elim(cand, x_other, y_other)
        return elim

    def _set_sees_both(
        self, color: frozenset[int], b1: frozenset[int], b2: frozenset[int]
    ) -> bool:
        """True if the cells in color collectively see at least one cell in b1
        AND at least one cell in b2 (not necessarily from the same cell)."""
        reach = self._buddy_mask(color)
        return bool(reach & self._mask_of(b1)) and bool(reach & self._mask_of(b2))

    def _sets_intersect_buddies(
        self, s1: frozenset[int], s2: frozenset[int]
    ) -> bool:
        """True if any cell in s1 sees any cell in s2."""
        return bool(self._buddy_mask(s1) & self._mask_of(s2))

    def _trap_elim(
        self, cand: int, color_a: frozenset[int], color_b: frozenset[int]
    ) -> set[int]:
        """Cells holding cand that see both color_a and color_b.

        Cells that belong to either color are excluded.  Used for Color
        Trap and Multi-Colors type 1.
        """
        elim_mask = (
            self.grid.candidate_sets[cand]
            & self._buddy_mask(color_a)
            & self._buddy_mask(color_b)
            & ~self._mask_of(color_a | color_b)
        )
        return set(self._bits(elim_mask))

    def _cells_with_cand(self, cand: int) -> set[int]:
        """Return the set of cells that still hold cand."""
        return set(self._bits(self.grid.candidate_sets[cand]))

    def _make_mc_step(
        self,
        sol_type: SolutionType,
        cand: int,
        a1: frozenset[int],
        a2: frozenset[int],
        b1: frozenset[int],
        b2: frozenset[int],
        elim: set[int],
    ) -> SolutionStep:
        """Build a Multi-Colors step (colors 0-3, deletions in cell order)."""
        step = SolutionStep(sol_type)
        step.add_value(cand)
        self._add_color_candidates(step, a1, cand, 0)
        self._add_color_candidates(step, a2, cand, 1)
        self._add_color_candidates(step, b1, cand, 2)
        self._add_color_candidates(step, b2, cand, 3)
        for cell in sorted(elim):
            step.add_candidate_to_delete(cell, cand)
        return step

    # ------------------------------------------------------------------
    # Utility
    # ------------------------------------------------------------------

    def _add_color_candidates(
        self, step: SolutionStep, color: frozenset[int], cand: int, color_idx: int
    ) -> None:
        """Map every cell in color to color_idx in step.color_candidates."""
        for cell in sorted(color):
            step.color_candidates[cell] = color_idx

    @staticmethod
    def _bits(mask: int) -> Iterator[int]:
        """Yield the indices of the set bits of a non-negative mask, lowest first."""
        while mask:
            lsb = mask & -mask
            yield lsb.bit_length() - 1
            mask ^= lsb

    @staticmethod
    def _mask_of(cells: Iterable[int]) -> int:
        """Return the bit mask with one bit set per cell index in cells."""
        mask = 0
        for cell in cells:
            mask |= 1 << cell
        return mask

    @staticmethod
    def _buddy_mask(cells: Iterable[int]) -> int:
        """Return the union of BUDDIES over cells (every cell any of them sees)."""
        mask = 0
        for cell in cells:
            mask |= BUDDIES[cell]
        return mask