1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
|
"""Solution-space settling test (user's hypothesis): is the weak Maze separation an artifact of
analyzing the FULL latent space (88% trivial copy) instead of the SOLUTION/output space?
Measures per-step decoded-ANSWER Hamming drift over solution-space cells (label!=input) and asks:
(1) does the DECODED ANSWER settle (stop changing) differently for success vs failure?
(2) does solution-space settling separate success from failure where full-latent drift did not (Maze)?
(3) Sudoku control: full-latent DID separate; does solution-space too?
Plus per-cell failure structure (Maze, and Sudoku as control): are the error cells one connected detour, or scattered?
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
# Directory containing this script; prediction dumps live in a sub-folder of it.
HERE = Path(__file__).resolve().parent
FU = HERE.joinpath("maze_followup")
def auc(s, y):
    """Return the ROC AUC of scores ``s`` for binary labels ``y`` (1 = positive, 0 = negative).

    Computed as the normalised Mann-Whitney U statistic: the fraction of
    (positive, negative) pairs in which the positive scores strictly higher,
    with tied pairs counted as 1/2. Ties matter here because drift scores are
    often exactly equal (e.g. many fully settled examples). A plain argsort
    ranking would break those ties arbitrarily and bias the result.

    Returns NaN when either class is empty.
    """
    s = np.asarray(s, dtype=float)
    y = np.asarray(y)
    pos, neg = s[y == 1], s[y == 0]
    if len(pos) == 0 or len(neg) == 0:
        return float("nan")
    neg_sorted = np.sort(neg)
    # For each positive: negatives strictly below it, and negatives tied with it.
    below = np.searchsorted(neg_sorted, pos, side="left")
    tied = np.searchsorted(neg_sorted, pos, side="right") - below
    return float((below + 0.5 * tied).sum() / (len(pos) * len(neg)))
def cohend(a, b):
    """Return Cohen's d = (mean(a) - mean(b)) / pooled standard deviation.

    The pooled variance is the combined sum of squared deviations divided by
    ``len(a) + len(b) - 2`` (floored at 1), i.e. the usual ddof=1 pooled
    estimate. Returns NaN when either group is empty or the pooled SD is 0.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if len(a) == 0 or len(b) == 0:
        return float("nan")
    # A size-1 group contributes 0 here. The old `(n-1) * var(ddof=1)` form gave
    # 0 * NaN = NaN for it and turned the whole result into NaN.
    ss = ((a - a.mean()) ** 2).sum() + ((b - b.mean()) ** 2).sum()
    s = np.sqrt(ss / max(len(a) + len(b) - 2, 1))
    return float((a.mean() - b.mean()) / s) if s > 0 else float("nan")
def connected_components(mask2d):
    """Count the 4-connected components of True cells in a 2-D mask.

    Returns ``(count, sizes)``. ``sizes`` holds each component's cell count,
    ordered by where a row-major scan first meets the component. Uses an
    explicit-stack flood fill, which is adequate for small grids.
    """
    n_rows, n_cols = mask2d.shape
    visited = np.zeros_like(mask2d, bool)
    neighbours = ((-1, 0), (1, 0), (0, -1), (0, 1))
    component_sizes = []

    def flood(start):
        # Mark every True cell reachable from `start`; return how many were marked.
        visited[start] = True
        frontier = [start]
        filled = 0
        while frontier:
            row, col = frontier.pop()
            filled += 1
            for d_row, d_col in neighbours:
                nr, nc = row + d_row, col + d_col
                inside = 0 <= nr < n_rows and 0 <= nc < n_cols
                if inside and mask2d[nr, nc] and not visited[nr, nc]:
                    visited[nr, nc] = True
                    frontier.append((nr, nc))
        return filled

    for row in range(n_rows):
        for col in range(n_cols):
            if mask2d[row, col] and not visited[row, col]:
                component_sizes.append(flood((row, col)))
    return len(component_sizes), component_sizes
def _or_nan(stat, x):
    """Apply ``stat`` to ``x``, or return NaN quietly (no RuntimeWarning) when ``x`` is empty."""
    return float(stat(x)) if len(x) else float("nan")


def _failure_structure(preds, labels, y, grid):
    """Print how the wrong cells of each failed example are spatially grouped.

    For every example with ``y == 0``, the cells where ``preds`` differs from
    ``labels`` are reshaped to ``grid`` and split into 4-connected components.
    Failures with no differing cell are skipped. Prints nothing when no
    failure has a differing cell.
    """
    n_errs, n_comps, largest_frac = [], [], []
    for k in np.flatnonzero(y == 0):
        err = (preds[k] != labels[k]).reshape(grid)
        ne = int(err.sum())
        if ne == 0:
            continue  # no differing cell -> no error structure to measure
        n_comp, sizes = connected_components(err)
        n_errs.append(ne)
        n_comps.append(n_comp)
        largest_frac.append(max(sizes) / ne)  # share of error cells in the biggest blob
    if not n_comps:
        return
    n_comps = np.array(n_comps)
    print(f" per-failure error cells: median={int(np.median(n_errs))}; connected components: median={int(np.median(n_comps))}; "
          f"largest-component fraction median={np.median(largest_frac):.2f}")
    print(f" -> {(n_comps <= 2).mean():.2f} of failures have <=2 error components (connected detour); "
          f"{(n_comps >= 5).mean():.2f} have >=5 (scattered)")


def analyze(tag, npz, grid=None, *, late_steps=4, settle_thresh=0.5):
    """Print solution-space settling statistics for one prediction dump.

    Parameters
    ----------
    tag : str
        Label used in the printed section header.
    npz : path-like
        ``.npz`` archive with ``exact_correct`` and the per-step drift arrays
        ``ans_drift_ans`` (solution cells) and ``ans_drift_full`` (full grid).
        Both drift arrays are indexed as (N, steps). ``preds`` and ``labels``
        are also read when ``grid`` is given.
    grid : tuple[int, int] | None
        Board shape each flattened ``preds[k]``/``labels[k]`` is reshaped to
        for the per-failure error-structure report. The report is skipped
        when None.
    late_steps : int
        Number of final steps averaged into the "late drift" score (must be >= 1).
    settle_thresh : float
        Late drift below this value (cells/step) counts as a settled answer.
    """
    # Read everything up front so the archive's file handle is closed.
    with np.load(npz) as d:
        y = d["exact_correct"].astype(int)
        ans_drift = d["ans_drift_ans"].astype(float)
        full_drift = d["ans_drift_full"].astype(float)
        if grid is not None:
            preds, labels = d["preds"], d["labels"]

    succ, fail = y == 1, y == 0
    # Mean drift over the last `late_steps` steps: how much the answer is still changing.
    late = ans_drift[:, -late_steps:].mean(1)
    print(f"\n=== {tag} (n={len(y)}, acc={_or_nan(np.mean, y):.3f}) — SOLUTION-SPACE settling ===")
    print(f" late answer-drift (cells/step changing) | success median={_or_nan(np.median, late[succ]):.2f} "
          f"failure median={_or_nan(np.median, late[fail]):.2f}")
    print(f" AUC(-late answer-drift -> correct) = {auc(-late, y):.3f} Cohen d(fail-succ)={cohend(late[fail], late[succ]):+.2f}")
    settled = late < settle_thresh
    print(f" fraction with SETTLED answer (drift<{settle_thresh:g} cells/step): "
          f"success={_or_nan(np.mean, settled[succ]):.3f} failure={_or_nan(np.mean, settled[fail]):.3f}")
    # Same late-window score over the full grid, for reference.
    late_full = full_drift[:, -late_steps:].mean(1)
    print(f" [ref] AUC(-late FULL-grid answer-drift -> correct) = {auc(-late_full, y):.3f}")
    if grid is not None:
        _failure_structure(preds, labels, y, grid)
if __name__ == "__main__":
    # Each entry: (report tag, dump path, board shape, message shown when the dump is absent).
    jobs = (
        ("MAZE", FU / "maze_preds_step130200.npz", (30, 30), "[pending] maze dump"),
        ("SUDOKU (control)", FU / "sudoku_preds_step58590.npz", (9, 9), "[pending] sudoku dump"),
    )
    for job_tag, dump_path, board_shape, missing_msg in jobs:
        if not dump_path.exists():
            print(missing_msg)
            continue
        analyze(job_tag, dump_path, grid=board_shape)
|