diff options
Diffstat (limited to 'analysis_2x2/analyze_solution_space.py')
| -rw-r--r-- | analysis_2x2/analyze_solution_space.py | 95 |
1 files changed, 95 insertions, 0 deletions
diff --git a/analysis_2x2/analyze_solution_space.py b/analysis_2x2/analyze_solution_space.py new file mode 100644 index 0000000..a1cd6d0 --- /dev/null +++ b/analysis_2x2/analyze_solution_space.py @@ -0,0 +1,95 @@ +"""Solution-space settling test (user's hypothesis): is the weak Maze separation an artifact of +analyzing the FULL latent space (88% trivial copy) instead of the SOLUTION/output space? + +Measures per-step decoded-ANSWER Hamming drift over solution-space cells (label!=input) and asks: +(1) does the DECODED ANSWER settle (stop changing) differently for success vs failure? +(2) does solution-space settling separate outcome where full-latent drift did not (Maze)? +(3) Sudoku control: full-latent DID separate; does solution-space too? +Plus per-cell failure structure for Maze: are error cells a connected detour or scattered? +""" +from __future__ import annotations +from pathlib import Path +import numpy as np + +HERE = Path(__file__).resolve().parent +FU = HERE / "maze_followup" + + +def auc(s, y): + p, n = s[y == 1], s[y == 0] + if len(p) == 0 or len(n) == 0: + return float("nan") + a = np.concatenate([p, n]); o = np.argsort(a); r = np.empty(len(a)); r[o] = np.arange(1, len(a) + 1) + return float((r[:len(p)].sum() - len(p) * (len(p) + 1) / 2) / (len(p) * len(n))) + + +def cohend(a, b): + s = np.sqrt(((len(a) - 1) * a.var(ddof=1) + (len(b) - 1) * b.var(ddof=1)) / max(len(a) + len(b) - 2, 1)) + return (a.mean() - b.mean()) / s if s > 0 else float("nan") + + +def connected_components(mask2d): + # 4-connectivity component count of True cells (small grids, simple flood fill) + seen = np.zeros_like(mask2d, bool); comps = 0; sizes = [] + H, W = mask2d.shape + for i in range(H): + for j in range(W): + if mask2d[i, j] and not seen[i, j]: + comps += 1; stack = [(i, j)]; seen[i, j] = True; sz = 0 + while stack: + r, c = stack.pop(); sz += 1 + for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]: + rr, cc = r + dr, c + dc + if 0 <= rr < H and 0 <= cc < W and mask2d[rr, cc] and not seen[rr, cc]: + seen[rr, cc] = True; stack.append((rr, cc)) + sizes.append(sz) + return comps, sizes + + +def analyze(tag, npz, grid=None): + d = np.load(npz) + y = d["exact_correct"].astype(int) + ad = d["ans_drift_ans"].astype(float) # (N, steps) decoded-answer drift over solution cells + late = ad[:, -4:].mean(1) # late answer-drift = solution-space "still changing" + print(f"\n=== {tag} (n={len(y)}, acc={y.mean():.3f}) — SOLUTION-SPACE settling ===") + print(f" late answer-drift (cells/step changing) | success median={np.median(late[y==1]):.2f} " + f"failure median={np.median(late[y==0]):.2f}") + print(f" AUC(-late answer-drift -> correct) = {auc(-late, y):.3f} Cohen d(fail-succ)={cohend(late[y==0],late[y==1]):+.2f}") + settled = late < 0.5 + print(f" fraction with SETTLED answer (drift<0.5 cells/step): success={settled[y==1].mean():.3f} failure={settled[y==0].mean():.3f}") + # full-grid for reference + adf = d["ans_drift_full"][:, -4:].mean(1).astype(float) + print(f" [ref] AUC(-late FULL-grid answer-drift -> correct) = {auc(-adf, y):.3f}") + + if grid is not None: + # per-cell failure structure: are error cells connected (detour) or scattered? + preds = d["preds"]; labels = d["labels"]; inputs = d["inputs"] + fail = np.where(y == 0)[0] + comps, errs = [], [] + for k in fail: + err = (preds[k] != labels[k]).reshape(grid) + ne = int(err.sum()) + if ne == 0: + continue + c, sizes = connected_components(err) + comps.append(c); errs.append(ne) + comps = np.array(comps); errs = np.array(errs) + if len(comps): + print(f" per-failure error cells: median={int(np.median(errs))}; connected components: median={int(np.median(comps))}; " + f"largest-component fraction median={np.median([1]):.2f}") + frac_one = (comps <= 2).mean() + print(f" -> {frac_one:.2f} of failures have <=2 error components (connected detour); " + f"{(comps>=5).mean():.2f} have >=5 (scattered)") + + +if __name__ == "__main__": + mz = FU / "maze_preds_step130200.npz" + sd = FU / "sudoku_preds_step58590.npz" + if mz.exists(): + analyze("MAZE", mz, grid=(30, 30)) + else: + print("[pending] maze dump") + if sd.exists(): + analyze("SUDOKU (control)", sd, grid=(9, 9)) + else: + print("[pending] sudoku dump") |
