Coverage for src / hodoku / solver / table_entry.py: 88%

129 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 08:35 +0000

1"""TableEntry — one row in the tabling implication table. 

2 

3Mirrors Java's TableEntry class. Each TableEntry holds all transitive 

4implications of a single premise ("candidate N is set/deleted in cell X"). 

5 

6The parallel arrays ``entries[]`` and ``ret_indices[]`` are synchronised: 

7``entries[k]`` is a 32-bit packed chain entry (see chain_utils.py) and 

8``ret_indices[k]`` is a 64-bit backpointer used for chain reconstruction. 

9 

10``on_sets[d]`` / ``off_sets[d]`` (d 1-9) are 81-bit bitmasks summarising 

11which cells can be set to / deleted from digit d as a result of the premise. 

12""" 

13 

14from __future__ import annotations 

15 

from hodoku.solver.chain_utils import (
    NORMAL_NODE,
    get_candidate,
    get_cell_index,
    get_node_type,
    is_strong,
    make_entry,
    make_entry_simple,
)

24 

# ---------------------------------------------------------------------------
# retIndex flag bits (64-bit)
# ---------------------------------------------------------------------------

# The three flag bits live at the very top of the 64-bit retIndex word.
_EXPANDED: int = 1 << 61
_ON_TABLE: int = 1 << 62
_EXTENDED_TABLE: int = 1 << 63

# All 64 bits set except the 9-bit distance field (bits 52-60): ANDing with
# this keeps the flags (61-63) and the packed reverse indices (0-51) while
# zeroing the distance. Numerically this is 0xE00FFFFFFFFFFFFF.
_DIST_CLEAR_MASK: int = ((1 << 64) - 1) ^ (0x1FF << 52)

# Fixed capacity of the per-table ``entries`` / ``ret_indices`` lists.
MAX_TABLE_ENTRY_LENGTH: int = 1000

38 

39 

class TableEntry:
    """One premise's implication table row.

    Implications are stored in the parallel fixed-size lists ``entries``
    (packed 32-bit chain entries) and ``ret_indices`` (64-bit backpointer
    words). Both are filled from slot 0 up to, but excluding, ``index``.
    ``on_sets`` / ``off_sets`` summarise normal-node implications per digit
    as cell bitmasks, and ``indices`` maps a packed entry value back to its
    slot.
    """

    __slots__ = (
        "index",
        "entries",
        "ret_indices",
        "on_sets",
        "off_sets",
        "indices",
    )

    def __init__(self) -> None:
        # Number of slots in use; the next entry is written at this slot.
        self.index: int = 0
        self.entries: list[int] = [0] * MAX_TABLE_ENTRY_LENGTH
        self.ret_indices: list[int] = [0] * MAX_TABLE_ENTRY_LENGTH
        # on_sets[d] = bitmask of cells that can be SET to d; index 0 unused
        self.on_sets: list[int] = [0] * 10
        # off_sets[d] = bitmask of cells from which d can be DELETED
        self.off_sets: list[int] = [0] * 10
        # Reverse lookup: entry value -> index in entries[]
        self.indices: dict[int, int] = {}

    def reset(self) -> None:
        """Clear the table for reuse.

        Every list is zeroed in place with slice assignment rather than a
        Python-level loop. The list objects themselves are kept, so any
        outside references to them stay valid.
        """
        self.index = 0
        self.indices.clear()
        self.on_sets[:] = [0] * len(self.on_sets)
        self.off_sets[:] = [0] * len(self.off_sets)
        self.entries[:] = [0] * len(self.entries)
        self.ret_indices[:] = [0] * len(self.ret_indices)

    # ------------------------------------------------------------------
    # Adding entries
    # ------------------------------------------------------------------

    def add_entry(
        self,
        cell_index1: int,
        cand: int,
        is_set: bool,
        cell_index2: int = -1,
        cell_index3: int = -1,
        node_type: int = NORMAL_NODE,
        ri1: int = 0,
        ri2: int = 0,
        ri3: int = 0,
        ri4: int = 0,
        ri5: int = 0,
        penalty: int = 0,
    ) -> None:
        """Add an implication to this table.

        Convenience wrapper that mirrors all Java addEntry overloads via
        keyword arguments.

        Args:
            cell_index1: Primary cell of the implication. Its bit is used for
                the duplicate check and for the ``on_sets``/``off_sets``
                update.
            cand: Candidate digit; indexes ``on_sets``/``off_sets``.
            is_set: True for "cand is set in the cell", False for "cand is
                deleted from the cell".
            cell_index2: Extra cell passed to ``make_entry`` (presumably for
                multi-cell node types; confirm in chain_utils). -1 if unused.
            cell_index3: As ``cell_index2``.
            node_type: Node type packed into the entry. Only ``NORMAL_NODE``
                entries are de-duplicated and recorded in the summary
                bitmasks.
            ri1, ri2, ri3, ri4, ri5: Reverse indices packed with
                ``_make_ret_index``. ``ri1`` is also treated as the
                predecessor slot when computing this entry's distance.
            penalty: Extra distance added on top of the computed distance
                (ALS penalty).

        The call is silently ignored when the table is full, or when a normal
        node with the same cell/candidate/is_set is already recorded.
        """
        idx = self.index
        if idx >= len(self.entries):
            return

        # Dedup: for normal nodes, skip if already recorded in on/off sets
        if node_type == NORMAL_NODE:
            if is_set:
                if self.on_sets[cand] & (1 << cell_index1):
                    return
            else:
                if self.off_sets[cand] & (1 << cell_index1):
                    return

        # Construct the 32-bit packed entry
        entry = make_entry(cell_index1, cell_index2, cell_index3,
                           cand, is_set, node_type)
        self.entries[idx] = entry
        self.ret_indices[idx] = _make_ret_index(ri1, ri2, ri3, ri4, ri5)

        # Set distance — for initial entries ri1 is the predecessor index
        # in THIS table. For expanded entries the caller overrides immediately.
        if ri1 < len(self.ret_indices):
            self._set_distance(idx, self.get_distance(ri1) + 1)

        # Update summary bitmasks (only for normal nodes)
        if node_type == NORMAL_NODE:
            if is_set:
                self.on_sets[cand] |= 1 << cell_index1
            else:
                self.off_sets[cand] |= 1 << cell_index1

        # ALS penalty
        if penalty:
            self._set_distance(idx, self.get_distance(idx) + penalty)

        self.indices[entry] = idx
        self.index = idx + 1

    # ------------------------------------------------------------------
    # Convenience wrappers matching Java's overload patterns
    # ------------------------------------------------------------------

    def add_entry_simple(self, cell_index: int, cand: int, is_set: bool) -> None:
        """Simple node, no reverse index. Used for initial table filling."""
        self.add_entry(cell_index, cand, is_set)

    def add_entry_with_ri(self, cell_index: int, cand: int, is_set: bool,
                          reverse_index: int) -> None:
        """Simple node with one reverse index. Used during expansion."""
        self.add_entry(cell_index, cand, is_set, ri1=reverse_index)

    # ------------------------------------------------------------------
    # Entry accessors (delegate to chain_utils on entries[idx])
    # ------------------------------------------------------------------

    def get_cell_index(self, idx: int) -> int:
        """Cell index decoded from the packed entry in slot ``idx``."""
        return get_cell_index(self.entries[idx])

    def is_strong(self, idx: int) -> bool:
        """``chain_utils.is_strong`` applied to the entry in slot ``idx``."""
        return is_strong(self.entries[idx])

    def get_candidate(self, idx: int) -> int:
        """Candidate digit decoded from the packed entry in slot ``idx``."""
        return get_candidate(self.entries[idx])

    def get_node_type(self, idx: int) -> int:
        """Node type decoded from the packed entry in slot ``idx``."""
        return get_node_type(self.entries[idx])

    def is_full(self) -> bool:
        """True once every slot of ``entries`` is in use."""
        return self.index >= len(self.entries)

    # ------------------------------------------------------------------
    # Entry lookup
    # ------------------------------------------------------------------

    def get_entry_index_by_value(self, entry: int) -> int:
        """Find the slot index for a given 32-bit entry value.

        Returns the index, or raises KeyError if not found.
        """
        return self.indices[entry]

    def get_entry_index(self, cell_index: int, is_set: bool, cand: int) -> int:
        """Find the slot index for cell/cand/set combination.

        Returns 0 if not found (matching Java behaviour). Note that 0 is also
        a valid slot, so callers cannot tell "not found" from "slot 0".
        """
        entry = make_entry_simple(cell_index, cand, is_set)
        return self.indices.get(entry, 0)

    # ------------------------------------------------------------------
    # retIndex bit manipulation
    # ------------------------------------------------------------------

    def get_distance(self, idx: int) -> int:
        """Distance from root (bits 52-60, 9 bits)."""
        return _get_ret_index(self.ret_indices[idx], 5) & 0x1FF

    def _set_distance(self, idx: int, distance: int) -> None:
        """Set distance in ret_indices[idx].

        Only the low 9 bits of ``distance`` are stored; larger values wrap.
        """
        d = distance & 0x1FF
        self.ret_indices[idx] &= _DIST_CLEAR_MASK
        self.ret_indices[idx] |= d << 52

    def is_expanded(self, idx: int) -> bool:
        """Entry was merged from another table during expansion."""
        return (self.ret_indices[idx] & _EXPANDED) != 0

    def set_expanded(self, idx: int) -> None:
        """Set the expanded flag (bit 61) on slot ``idx``."""
        self.ret_indices[idx] |= _EXPANDED

    def is_on_table(self, idx: int) -> bool:
        """Source was on_table (vs off_table)."""
        return (self.ret_indices[idx] & _ON_TABLE) != 0

    def set_on_table(self, idx: int) -> None:
        """Set the on_table flag (bit 62) on slot ``idx``."""
        self.ret_indices[idx] |= _ON_TABLE

    def is_extended_table(self, idx: int) -> bool:
        """Source was extended_table."""
        return (self.ret_indices[idx] & _EXTENDED_TABLE) != 0

    def set_extended_table(self, idx: int) -> None:
        """Set the extended_table flag (bit 63) on slot ``idx``."""
        self.ret_indices[idx] |= _EXTENDED_TABLE

    def set_extended_table_last(self) -> None:
        """Mark the most recently added entry as from extended_table.

        Raises:
            IndexError: If the table is empty. Without this guard,
                ``self.index - 1`` would be -1 and Python would silently flag
                the LAST slot of the fixed-size list instead of failing.
        """
        last = self.index - 1
        if last < 0:
            raise IndexError(
                "set_extended_table_last() called on an empty TableEntry"
            )
        self.ret_indices[last] |= _EXTENDED_TABLE

    def get_ret_index(self, idx: int, which: int) -> int:
        """Get reverse index ``which`` (0-4) or distance (5) from slot ``idx``."""
        return _get_ret_index(self.ret_indices[idx], which)

    def get_ret_index_count(self, idx: int) -> int:
        """Number of reverse indices in slot ``idx`` (1-5)."""
        return _get_ret_index_count(self.ret_indices[idx])

236 

237 

# ---------------------------------------------------------------------------
# Module-level helpers (matching Java's static methods)
# ---------------------------------------------------------------------------

def _make_ret_index(i1: int, i2: int, i3: int, i4: int, i5: int) -> int:
    """Pack up to 5 reverse indices into a 64-bit retIndex value.

    Layout: slot 0 is 12 bits wide (bits 0-11). Slots 1-4 are 10 bits wide
    and start at bits 12, 22, 32 and 42. The distance (bits 52-60) and flag
    bits (61-63) are left at 0.

    A value that does not fit its slot is replaced by 0, not saturated.
    After that the largest value is swapped into the first (12-bit) slot.
    Modelled on TableEntry.makeSRetIndex in Java.
    """
    # Drop values too wide for their field. Slot 0 holds at most 0xFFF (4095).
    # A `> 4096` bound would let 4096 through, which sets bit 12 and
    # corrupts slot 1.
    limits = (0xFFF, 0x3FF, 0x3FF, 0x3FF, 0x3FF)
    slots = [
        value if value <= limit else 0
        for value, limit in zip((i1, i2, i3, i4, i5), limits)
    ]

    # Ensure the largest value is first. This uses the same sequential
    # pairwise swaps as the Java original, so the remaining slots keep
    # identical positions.
    for k in range(1, 5):
        if slots[k] > slots[0]:
            slots[0], slots[k] = slots[k], slots[0]

    return (
        (slots[4] << 42)
        + (slots[3] << 32)
        + (slots[2] << 22)
        + (slots[1] << 12)
        + slots[0]
    )

271 

272 

def _get_ret_index(ret_index: int, which: int) -> int:
    """Return field ``which`` of a packed retIndex word.

    ``which`` 0 is the 12-bit first reverse index (bits 0-11). ``which`` 1-4
    are the 10-bit reverse indices starting at bits 12, 22, 32 and 42.
    ``which`` 5 is the 9-bit distance starting at bit 52.

    Matches TableEntry.getSRetIndex in Java.
    """
    if not which:
        return ret_index & 0xFFF
    # The distance field is only 9 bits wide; the bit above it is a flag.
    field_mask = 0x1FF if which == 5 else 0x3FF
    return (ret_index >> (10 * which + 2)) & field_mask

284 

285 

def _get_ret_index_count(ret_index: int) -> int:
    """Return how many reverse indices a retIndex word holds (1-5).

    Slot 0 is always counted, even when it is 0. Each 10-bit slot starting at
    bit 12, 22, 32 or 42 adds one if it is non-zero. The distance and flag
    bits above bit 51 are ignored.
    Matches TableEntry.getSRetIndexAnz in Java.
    """
    occupied = sum(
        1 for shift in (12, 22, 32, 42) if (ret_index >> shift) & 0x3FF
    )
    return 1 + occupied
298 return count