Coverage for src / hodoku / generator / generator.py: 96%
375 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 08:35 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 08:35 +0000
"""Backtracking solver and puzzle generator.

Port of Java's ``SudokuGenerator`` — a bit-based backtracking solver that uses
the Grid's naked-single / hidden-single queues to propagate constraints during
search. Used by the generator to:

* Check uniqueness (``valid_solution`` / ``get_number_of_solutions``).
* Produce the solution array stored on the Grid.
* Generate full grids and derive puzzles by symmetric clue removal.

Reference: ``generator/SudokuGenerator.java`` lines 40–747.
"""
14from __future__ import annotations
16import os
17import random
19from hodoku.core.grid import (
20 ALL_UNIT_MASKS,
21 BUDDIES,
22 CELL_CONSTRAINTS,
23 DIGIT_MASKS,
24 Grid,
25 LENGTH,
26)
28# ---------------------------------------------------------------------------
29# Optional C accelerator
30# ---------------------------------------------------------------------------
# ``_gen_accel`` ends up bound to the optional C extension module, or to
# ``None`` when the extension cannot be imported or has been switched off by
# setting the HODOKU_NO_ACCEL environment variable to a non-empty value.
_gen_accel = None
if not os.environ.get("HODOKU_NO_ACCEL"):
    try:
        from hodoku.generator import _gen_accel

        # Hand the precomputed lookup tables to the extension once, at import.
        _gen_accel.init_tables(
            list(BUDDIES),
            list(CELL_CONSTRAINTS),
            list(ALL_UNIT_MASKS),
        )
    except ImportError:
        # Covers both a missing extension and an ImportError raised while
        # initialising it: fall back to the pure-Python implementation.
        _gen_accel = None
44# ---------------------------------------------------------------------------
45# Precomputed lookup: candidate-mask → list of digits present
46# ---------------------------------------------------------------------------
# Indexed by a 9-bit candidate mask (bit ``d - 1`` set means digit ``d`` is a
# candidate); each entry is the ascending tuple of digits present in the mask.
_POSSIBLE_VALUES: tuple[tuple[int, ...], ...] = tuple(
    tuple(digit for digit in range(1, 10) if (bits >> (digit - 1)) & 1)
    for bits in range(1 << 9)
)
54# ---------------------------------------------------------------------------
55# Stack entry for iterative backtracking
56# ---------------------------------------------------------------------------
class _StackEntry:
    """Saved search state for a single depth of the explicit backtracking stack."""

    __slots__ = (
        "grid",
        "index",
        "candidates",
        "cand_index",
    )

    def __init__(self) -> None:
        # Digits still to be tried for the branching cell, plus a cursor
        # pointing at the next untried one.
        self.candidates: tuple[int, ...] = ()
        self.cand_index: int = 0
        # Cell the solver branches on at this depth.
        self.index: int = 0
        # Private working grid for this depth.
        self.grid: Grid = Grid()
70# ---------------------------------------------------------------------------
71# Grid helpers – validity-returning cell placement for the backtracker
72# ---------------------------------------------------------------------------
def _del_cand_valid(grid: Grid, index: int, digit: int) -> bool:
    """Eliminate candidate *digit* from cell *index* and report validity.

    Returns ``True`` when the digit was not a candidate to begin with, or when
    the cell still has at least one candidate afterwards. Returns ``False`` as
    soon as the cell's candidate mask becomes empty; in that case the
    per-digit cell set and the per-unit counters are left untouched.

    On success the per-digit cell set and the ``free`` counter of every unit
    containing the cell are updated. A unit left with a single place for
    *digit* queues a hidden single, and a cell left with one candidate queues
    a naked single.
    """
    bit = DIGIT_MASKS[digit]
    cell_mask = grid.candidates[index]
    if (cell_mask & bit) == 0:
        return True  # nothing to remove

    cell_mask &= ~bit
    grid.candidates[index] = cell_mask
    if not cell_mask:
        return False  # the cell ran out of candidates

    digit_cells = grid.candidate_sets[digit] & ~(1 << index)
    grid.candidate_sets[digit] = digit_cells

    for unit in CELL_CONSTRAINTS[index]:
        counts = grid.free[unit]
        counts[digit] -= 1
        if counts[digit] != 1:
            continue
        # Stale queue entries are tolerated here: the consumer re-checks the
        # candidate before acting on a queued single.
        survivors = digit_cells & ALL_UNIT_MASKS[unit]
        if survivors:
            lowest = survivors & -survivors
            grid.hs_queue.append((lowest.bit_length() - 1, digit))

    # A mask with exactly one bit set is a naked single; the bit position
    # (1-based) is the digit itself.
    if (cell_mask & (cell_mask - 1)) == 0:
        grid.ns_queue.append((index, cell_mask.bit_length()))

    return True
def _set_cell_valid(grid: Grid, index: int, value: int) -> bool:
    """Write *value* into cell *index* and propagate the consequences.

    Returns ``True`` immediately if the cell already holds *value*. Otherwise
    returns ``False`` when the placement exposes a contradiction: either a
    buddy cell loses its last candidate, or a unit is left with no place for
    one of the other digits the cell used to allow.

    Mirrors ``Sudoku2.setCell(index, value, false, false)`` —
    ``isFixed=false``, ``user=false``.
    """
    if grid.values[index] == value:
        return True

    grid.values[index] = value
    ok = True

    # Phase 1: strip *value* from every buddy. Like the Java original this
    # deliberately keeps going after the first contradiction instead of
    # bailing out early, at the price of some wasted work.
    pending = BUDDIES[index]
    while pending:
        lowest = pending & -pending
        pending ^= lowest
        if not _del_cand_valid(grid, lowest.bit_length() - 1, value):
            ok = False

    # Phase 2: drop whatever candidates the cell itself still carried and
    # keep the per-digit sets and per-unit counters in step.
    leftover = grid.candidates[index]
    grid.candidates[index] = 0
    clear_cell = ~(1 << index)
    units = CELL_CONSTRAINTS[index]
    for digit in range(1, 10):
        if not (leftover & DIGIT_MASKS[digit]):
            continue
        grid.candidate_sets[digit] &= clear_cell
        placed = digit == value
        for unit in units:
            counts = grid.free[unit]
            counts[digit] -= 1
            if placed:
                # The placed digit needs no single/contradiction handling.
                continue
            places = counts[digit]
            if places == 1:
                survivors = grid.candidate_sets[digit] & ALL_UNIT_MASKS[unit]
                if survivors:
                    lowest = survivors & -survivors
                    grid.hs_queue.append((lowest.bit_length() - 1, digit))
            elif places == 0:
                ok = False

    return ok
def _copy_state(dst: Grid, src: Grid) -> None:
    """Overwrite *dst*'s search state with *src*'s (mirrors ``Sudoku2.setBS``).

    ``values``, ``candidates`` and ``candidate_sets`` are copied in place;
    ``free`` is replaced by a fresh list of row copies; both single queues of
    *dst* are emptied. Nothing else on *dst* is touched — in particular the
    solution and givens are not copied, as the search does not need them.
    """
    for queue in (dst.ns_queue, dst.hs_queue):
        queue.clear()
    dst.free = list(map(list, src.free))
    dst.candidate_sets[:] = src.candidate_sets
    dst.candidates[:] = src.candidates
    dst.values[:] = src.values
def _set_all_exposed_singles(grid: Grid) -> bool:
    """Place every queued naked and hidden single until both queues are empty.

    Naked singles are drained first, then hidden singles, and the cycle
    repeats while either queue still holds entries. A queued entry whose digit
    is no longer a candidate of its cell is stale and is skipped.

    Returns ``False`` as soon as a placement makes the grid invalid (remaining
    queue entries are left in place), ``True`` otherwise.
    Mirrors ``SudokuGenerator.setAllExposedSingles()``.
    """
    queues = (grid.ns_queue, grid.hs_queue)
    while queues[0] or queues[1]:
        for queue in queues:
            while queue:
                cell, digit = queue.popleft()
                if not (grid.candidates[cell] & DIGIT_MASKS[digit]):
                    continue  # stale entry
                if not _set_cell_valid(grid, cell, digit):
                    return False
    return True
200# ---------------------------------------------------------------------------
201# SudokuGenerator
202# ---------------------------------------------------------------------------
class SudokuGenerator:
    """Bit-based backtracking solver for generation and uniqueness checking.

    The instance owns a fixed stack of 82 working grids (one per search depth
    plus the root), the first solution found by the most recent solve, and the
    number of solutions that solve counted (the search stops at two).
    """

    def __init__(self, rng: random.Random | None = None) -> None:
        """Create a generator; *rng* defaults to a fresh ``random.Random()``."""
        self._stack: list[_StackEntry] = [_StackEntry() for _ in range(82)]
        self._solution: list[int] = [0] * LENGTH
        self._solution_count: int = 0
        self._rand: random.Random = rng if rng is not None else random.Random()
        self._generate_indices: list[int] = list(range(LENGTH))
        self._new_full_sudoku: list[int] = [0] * LENGTH
        self._new_valid_sudoku: list[int] = [0] * LENGTH

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def valid_solution(self, grid: Grid) -> bool:
        """Return ``True`` if *grid* has exactly one solution.

        If unique, a copy of the solution is stored on the grid via
        ``grid.set_solution()``.
        """
        self._solve_grid(grid)
        unique = self._solution_count == 1
        if unique:
            grid.set_solution(list(self._solution))
        return unique

    def get_number_of_solutions(self, grid: Grid) -> int:
        """Return 0 (invalid), 1 (unique), or 2 (multiple solutions).

        If unique, a copy of the solution is stored on the grid.
        """
        self._solve_grid(grid)
        if self._solution_count == 1:
            grid.set_solution(list(self._solution))
        return self._solution_count

    def get_solution(self) -> list[int]:
        """Return a copy of the first solution found by the last solve call."""
        return list(self._solution)

    def get_solution_count(self) -> int:
        """Return the solution count from the last solve call."""
        return self._solution_count

    def get_solution_as_string(self) -> str:
        """Return the stored solution as a string of its digits."""
        return "".join(map(str, self._solution))

    # ------------------------------------------------------------------
    # Puzzle generation
    # ------------------------------------------------------------------

    def generate_sudoku(
        self,
        symmetric: bool = True,
        pattern: list[bool] | None = None,
    ) -> str | None:
        """Generate a new puzzle and return it as a digit string (0 = empty).

        Without *pattern*, a full grid is generated and clues are removed at
        random (in symmetric pairs when *symmetric* is true).

        With *pattern*, cells whose entry is falsy are blanked; fresh full
        grids are tried until one is accepted by
        ``_generate_init_pos_pattern``. Returns ``None`` if none of
        ``MAX_PATTERN_TRIES`` attempts succeeds.

        Mirrors ``SudokuGenerator.generateSudoku(boolean, boolean[])``.
        """
        MAX_PATTERN_TRIES = 1_000_000

        if pattern is None:
            self._generate_full_grid()
            self._generate_init_pos(symmetric)
        else:
            for _ in range(MAX_PATTERN_TRIES):
                self._generate_full_grid()
                if self._generate_init_pos_pattern(pattern):
                    break
            else:
                # Every attempt was rejected.
                return None

        return "".join(map(str, self._new_valid_sudoku))

    def _generate_full_grid(self) -> None:
        """Retry full-grid generation until an attempt succeeds.

        Mirrors ``SudokuGenerator.generateFullGrid()``.
        """
        while not self._do_generate_full_grid():
            pass

    def _do_generate_full_grid(self) -> bool:
        """Try once to build a random full grid; return ``True`` on success.

        Mirrors ``SudokuGenerator.doGenerateFullGrid()``. Cells are visited in
        a shuffled order. The attempt is abandoned (``False``) after more than
        100 branching steps, or when the search space is exhausted; the caller
        then retries with a new shuffle. On success the grid is stored in
        ``self._new_full_sudoku``.
        """
        act_tries = 0
        rand = self._rand
        indices = self._generate_indices

        # Shuffle by random pairwise swaps (the Java original's scheme). The
        # sequence of RNG calls is kept exactly as-is so seeded runs stay
        # reproducible.
        max_len = len(indices)
        indices[:] = range(max_len)
        for _ in range(max_len):
            idx1 = rand.randrange(max_len)
            idx2 = rand.randrange(max_len)
            while idx1 == idx2:
                idx2 = rand.randrange(max_len)
            indices[idx1], indices[idx2] = indices[idx2], indices[idx1]

        # Start with an empty grid at the root of the stack.
        stack = self._stack
        stack[0].grid.__init__()
        stack[0].index = -1
        level = 0

        while True:
            vals = stack[level].grid.values
            if 0 not in vals:
                # Full grid generated.
                self._new_full_sudoku[:] = vals
                return True

            # First unsolved cell in shuffled order (one exists: see above).
            index = next((cell for cell in indices if vals[cell] == 0), -1)

            level += 1
            entry = stack[level]
            entry.index = index
            entry.candidates = _POSSIBLE_VALUES[
                stack[level - 1].grid.candidates[index]
            ]
            entry.cand_index = 0

            # Cap the amount of branching per attempt.
            act_tries += 1
            if act_tries > 100:
                return False

            level = self._next_viable_level(level)
            if level == 0:
                return False

    def _generate_init_pos(self, is_symmetric: bool) -> None:
        """Remove clues from the full grid to create a puzzle.

        Mirrors ``SudokuGenerator.generateInitPos(boolean)``. Cells are picked
        by scanning forward (with wrap-around) from a random position to the
        next cell not yet tried. With *is_symmetric*, every cell except the
        centre is removed together with its 180-degree partner. A removal is
        undone when the solver then reports more than one solution. The result
        is left in ``self._new_valid_sudoku``.
        """
        min_givens = 17  # stop removing once this few clues remain
        tried = [False] * 81
        untried = 81
        rand = self._rand

        full = self._new_full_sudoku
        valid = self._new_valid_sudoku
        valid[:] = full

        remaining_clues = 81

        while remaining_clues > min_givens and untried > 1:
            # Scan forward from a random position to find an untried cell.
            i = rand.randrange(81)
            while True:
                i = i + 1 if i < 80 else 0
                if not tried[i]:
                    break
            tried[i] = True
            untried -= 1

            if valid[i] == 0:
                # Already deleted (as somebody's symmetric partner).
                continue

            # 180-degree rotation maps (r, c) to (8 - r, 8 - c), i.e. 80 - i;
            # the centre cell (index 40) is its own partner.
            symm = 80 - i
            paired = is_symmetric and i != 40

            if paired and valid[symm] == 0:
                continue

            # Tentatively delete the cell (and its partner).
            valid[i] = 0
            remaining_clues -= 1
            if paired:
                valid[symm] = 0
                tried[symm] = True
                untried -= 1
                remaining_clues -= 1

            # Check uniqueness.
            self.solve_values(valid)

            if self._solution_count > 1:
                # Restore — the deletion would break uniqueness.
                valid[i] = full[i]
                remaining_clues += 1
                if paired:
                    valid[symm] = full[symm]
                    remaining_clues += 1

    def _generate_init_pos_pattern(self, pattern: list[bool]) -> bool:
        """Blank the cells *pattern* marks as falsy and check the result.

        Mirrors ``SudokuGenerator.generateInitPos(boolean[])``. Returns
        ``True`` unless the solver reports more than one solution for the
        resulting puzzle (left in ``self._new_valid_sudoku``).
        """
        valid = self._new_valid_sudoku
        valid[:] = self._new_full_sudoku

        for i, keep in enumerate(pattern):
            if not keep:
                valid[i] = 0

        self.solve_values(valid)
        return self._solution_count <= 1

    # ------------------------------------------------------------------
    # Solve entry points
    # ------------------------------------------------------------------

    def _solve_grid(self, grid: Grid) -> None:
        """Load *grid* into the root stack entry and solve it."""
        root = self._stack[0]
        root.grid.set(grid)
        root.index = 0
        root.candidates = ()
        root.cand_index = 0
        self._solve()

    def solve_string(self, sudoku_string: str) -> None:
        """Solve a puzzle given as a string of cell characters.

        Results are available through ``get_solution_count()`` and
        ``get_solution()``. In the pure-Python path only the first ``LENGTH``
        characters are read; ``'1'``–``'9'`` are placed as givens and every
        other character leaves its cell empty.
        """
        if _gen_accel is not None:
            sol_count, sol = _gen_accel.solve_string(sudoku_string)
            self._solution_count = sol_count
            if sol_count >= 1:
                self._solution[:] = sol
            return

        s0 = self._stack[0]
        s0.grid.__init__()  # reset to empty
        s0.candidates = ()
        s0.cand_index = 0

        for i, ch in enumerate(sudoku_string[:LENGTH]):
            value = ord(ch) - ord("0")
            if 1 <= value <= 9:
                s0.grid.set_cell(i, value)
                # Propagate after every given so contradictions surface early.
                if not _set_all_exposed_singles(s0.grid):
                    self._solution_count = 0
                    return

        self._solve()

    def solve_values(self, cell_values: list[int]) -> None:
        """Solve a puzzle given as a list of cell values (0 = empty).

        Uses the fast bulk-set path (mirrors Java's ``solve(int[])``): set
        all values without propagation, rebuild the derived data, then
        propagate singles once before searching.
        """
        if _gen_accel is not None:
            sol_count, sol = _gen_accel.solve_values(cell_values)
            self._solution_count = sol_count
            if sol_count >= 1:
                self._solution[:] = sol
            return

        s0 = self._stack[0]
        s0.grid.__init__()  # reset to empty
        s0.candidates = ()
        s0.cand_index = 0

        # Bulk set: place values directly, strip candidates from buddies.
        grid = s0.grid
        for i, value in enumerate(cell_values):
            if 1 <= value <= 9:
                grid.values[i] = value
                grid.candidates[i] = 0
                keep = ~DIGIT_MASKS[value]
                buddies = BUDDIES[i]
                while buddies:
                    lsb = buddies & -buddies
                    buddies ^= lsb
                    grid.candidates[lsb.bit_length() - 1] &= keep

        # Rebuild free counts, candidate_sets, and queues from scratch.
        _rebuild_internal(grid)

        if not _set_all_exposed_singles(grid):
            self._solution_count = 0
            return

        self._solve()

    # ------------------------------------------------------------------
    # Core backtracking solver
    # ------------------------------------------------------------------

    def _solve(self) -> None:
        """Solve the grid held in the root stack entry.

        Mirrors ``SudokuGenerator.solve()`` (the private no-arg version).
        Uses the C accelerator if available, otherwise pure Python.
        """
        if _gen_accel is not None:
            self._solve_c()
        else:
            self._solve_py()

    def _solve_c(self) -> None:
        """C-accelerated solve path.

        Passes copies of the root grid's state and pending queues to the
        extension and stores the returned count and first solution.
        """
        grid = self._stack[0].grid
        ns_q = list(grid.ns_queue)
        hs_q = list(grid.hs_queue)
        sol_count, sol = _gen_accel.solve(
            list(grid.values),
            list(grid.candidates),
            list(grid.candidate_sets),
            [list(row) for row in grid.free],
            ns_q,
            hs_q,
        )
        self._solution_count = sol_count
        if sol_count >= 1:
            self._solution[:] = sol

    def _solve_py(self) -> None:
        """Pure Python iterative backtracking solver.

        Mirrors ``SudokuGenerator.solve()`` (the private no-arg version).
        Leaves the number of solutions found (the search stops at two) in
        ``self._solution_count`` and the first solution in ``self._solution``.
        """
        self._solution_count = 0
        stack = self._stack
        root = stack[0].grid

        # Propagate any queued singles from setup.
        if not _set_all_exposed_singles(root):
            return

        if 0 not in root.values:
            # Already solved.
            self._solution[:] = root.values
            self._solution_count = 1
            return

        level = 0
        while True:
            grid = stack[level].grid
            if 0 not in grid.values:
                # Found a solution.
                self._solution_count += 1
                if self._solution_count > 1:
                    return  # more than one → done
                self._solution[:] = grid.values
            else:
                # Pick the unsolved cell with the fewest candidates (MRV).
                # NOTE: kept for parity with Java (anzCand=9): a cell that
                # still has all nine candidates is never selected.
                index = -1
                best_count = 9
                candidates = grid.candidates
                for i in range(LENGTH):
                    cands = candidates[i]
                    if cands != 0:
                        cnt = cands.bit_count()
                        if cnt < best_count:
                            best_count = cnt
                            index = i

                if index < 0:
                    # No cell qualified → reported as invalid.
                    self._solution_count = 0
                    return

                level += 1
                entry = stack[level]
                entry.index = index
                entry.candidates = _POSSIBLE_VALUES[candidates[index]]
                entry.cand_index = 0

            level = self._next_viable_level(level)
            if level == 0:
                return

    def _next_viable_level(self, level: int) -> int:
        """Place the next workable candidate at *level* or a shallower level.

        Starting at *level*, takes the next untried candidate of that stack
        entry, copies the parent grid into the entry, places the candidate and
        propagates singles. Entries whose candidates are used up are abandoned
        in favour of their parent.

        Returns the level whose entry now holds a consistent grid, or ``0``
        once every level above the root is exhausted.
        """
        stack = self._stack
        while True:
            # Fall back through levels with no remaining candidates.
            while stack[level].cand_index >= len(stack[level].candidates):
                level -= 1
                if level <= 0:
                    return 0

            entry = stack[level]
            next_cand = entry.candidates[entry.cand_index]
            entry.cand_index += 1

            # Every attempt starts from a fresh copy of the parent state.
            _copy_state(entry.grid, stack[level - 1].grid)

            if not _set_cell_valid(entry.grid, entry.index, next_cand):
                continue  # invalid → try next candidate

            if _set_all_exposed_singles(entry.grid):
                return level  # valid move → caller advances from here
651# ---------------------------------------------------------------------------
652# Internal helpers
653# ---------------------------------------------------------------------------
def _rebuild_internal(grid: Grid) -> None:
    """Recompute ``free``, ``candidate_sets`` and both queues from candidates.

    Called after values were bulk-set without propagation (mirrors Java's
    ``Sudoku2.rebuildInternalData()``). Only ``grid.candidates`` is read;
    the counters and per-digit cell sets are zeroed and rebuilt in place, and
    the naked-single and hidden-single queues are refilled from scratch.
    """
    free = grid.free
    digit_cells = grid.candidate_sets

    # Wipe the derived data.
    for unit in range(27):
        counts = free[unit]
        for digit in range(10):
            counts[digit] = 0
    for digit in range(10):
        digit_cells[digit] = 0
    grid.ns_queue.clear()
    grid.hs_queue.clear()

    # One pass over the cells: register each candidate with its digit set and
    # unit counters, and queue cells that are down to a single candidate.
    for cell in range(LENGTH):
        mask = grid.candidates[cell]
        if not mask:
            continue  # no candidates: the cell is already filled
        cell_bit = 1 << cell
        for digit in range(1, 10):
            if mask & DIGIT_MASKS[digit]:
                digit_cells[digit] |= cell_bit
                for unit in CELL_CONSTRAINTS[cell]:
                    free[unit][digit] += 1
        if (mask & (mask - 1)) == 0:
            grid.ns_queue.append((cell, mask.bit_length()))

    # A digit with exactly one place left in a unit is a hidden single.
    for unit in range(27):
        unit_mask = ALL_UNIT_MASKS[unit]
        counts = free[unit]
        for digit in range(1, 10):
            if counts[digit] != 1:
                continue
            only = digit_cells[digit] & unit_mask
            if only:
                grid.hs_queue.append(((only & -only).bit_length() - 1, digit))