summaryrefslogtreecommitdiff
path: root/analysis_2x2/analyze_solution_space.py
diff options
context:
space:
mode:
Diffstat (limited to 'analysis_2x2/analyze_solution_space.py')
-rw-r--r--analysis_2x2/analyze_solution_space.py95
1 files changed, 95 insertions, 0 deletions
diff --git a/analysis_2x2/analyze_solution_space.py b/analysis_2x2/analyze_solution_space.py
new file mode 100644
index 0000000..a1cd6d0
--- /dev/null
+++ b/analysis_2x2/analyze_solution_space.py
@@ -0,0 +1,95 @@
+"""Solution-space settling test (user's hypothesis): is the weak Maze separation an artifact of
+analyzing the FULL latent space (88% trivial copy) instead of the SOLUTION/output space?
+
+Measures per-step decoded-ANSWER Hamming drift over solution-space cells (label!=input) and asks:
+(1) does the DECODED ANSWER settle (stop changing) differently for success vs failure?
+(2) does solution-space settling separate outcome where full-latent drift did not (Maze)?
+(3) Sudoku control: full-latent DID separate; does solution-space too?
+Plus per-cell failure structure for Maze: are error cells a connected detour or scattered?
+"""
+from __future__ import annotations
+from pathlib import Path
+import numpy as np
+
+HERE = Path(__file__).resolve().parent
+FU = HERE / "maze_followup"
+
+
+def auc(s, y):
+ p, n = s[y == 1], s[y == 0]
+ if len(p) == 0 or len(n) == 0:
+ return float("nan")
+ a = np.concatenate([p, n]); o = np.argsort(a); r = np.empty(len(a)); r[o] = np.arange(1, len(a) + 1)
+ return float((r[:len(p)].sum() - len(p) * (len(p) + 1) / 2) / (len(p) * len(n)))
+
+
+def cohend(a, b):
+ s = np.sqrt(((len(a) - 1) * a.var(ddof=1) + (len(b) - 1) * b.var(ddof=1)) / max(len(a) + len(b) - 2, 1))
+ return (a.mean() - b.mean()) / s if s > 0 else float("nan")
+
+
+def connected_components(mask2d):
+ # 4-connectivity component count of True cells (small grids, simple flood fill)
+ seen = np.zeros_like(mask2d, bool); comps = 0; sizes = []
+ H, W = mask2d.shape
+ for i in range(H):
+ for j in range(W):
+ if mask2d[i, j] and not seen[i, j]:
+ comps += 1; stack = [(i, j)]; seen[i, j] = True; sz = 0
+ while stack:
+ r, c = stack.pop(); sz += 1
+ for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
+ rr, cc = r + dr, c + dc
+ if 0 <= rr < H and 0 <= cc < W and mask2d[rr, cc] and not seen[rr, cc]:
+ seen[rr, cc] = True; stack.append((rr, cc))
+ sizes.append(sz)
+ return comps, sizes
+
+
+def analyze(tag, npz, grid=None):
+ d = np.load(npz)
+ y = d["exact_correct"].astype(int)
+ ad = d["ans_drift_ans"].astype(float) # (N, steps) decoded-answer drift over solution cells
+ late = ad[:, -4:].mean(1) # late answer-drift = solution-space "still changing"
+ print(f"\n=== {tag} (n={len(y)}, acc={y.mean():.3f}) — SOLUTION-SPACE settling ===")
+ print(f" late answer-drift (cells/step changing) | success median={np.median(late[y==1]):.2f} "
+ f"failure median={np.median(late[y==0]):.2f}")
+ print(f" AUC(-late answer-drift -> correct) = {auc(-late, y):.3f} Cohen d(fail-succ)={cohend(late[y==0],late[y==1]):+.2f}")
+ settled = late < 0.5
+ print(f" fraction with SETTLED answer (drift<0.5 cells/step): success={settled[y==1].mean():.3f} failure={settled[y==0].mean():.3f}")
+ # full-grid for reference
+ adf = d["ans_drift_full"][:, -4:].mean(1).astype(float)
+ print(f" [ref] AUC(-late FULL-grid answer-drift -> correct) = {auc(-adf, y):.3f}")
+
+ if grid is not None:
+ # per-cell failure structure: are error cells connected (detour) or scattered?
+ preds = d["preds"]; labels = d["labels"]; inputs = d["inputs"]
+ fail = np.where(y == 0)[0]
+ comps, errs = [], []
+ for k in fail:
+ err = (preds[k] != labels[k]).reshape(grid)
+ ne = int(err.sum())
+ if ne == 0:
+ continue
+ c, sizes = connected_components(err)
+ comps.append(c); errs.append(ne)
+ comps = np.array(comps); errs = np.array(errs)
+ if len(comps):
+ print(f" per-failure error cells: median={int(np.median(errs))}; connected components: median={int(np.median(comps))}; "
+ f"largest-component fraction median={np.median([1]):.2f}")
+ frac_one = (comps <= 2).mean()
+ print(f" -> {frac_one:.2f} of failures have <=2 error components (connected detour); "
+ f"{(comps>=5).mean():.2f} have >=5 (scattered)")
+
+
+if __name__ == "__main__":
+ mz = FU / "maze_preds_step130200.npz"
+ sd = FU / "sudoku_preds_step58590.npz"
+ if mz.exists():
+ analyze("MAZE", mz, grid=(30, 30))
+ else:
+ print("[pending] maze dump")
+ if sd.exists():
+ analyze("SUDOKU (control)", sd, grid=(9, 9))
+ else:
+ print("[pending] sudoku dump")