summaryrefslogtreecommitdiff
path: root/analysis_2x2/analyze_solution_space.py
blob: a1cd6d020d458529e03fed68b78b044416a57e2b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""Solution-space settling test (user's hypothesis): is the weak Maze separation an artifact of
analyzing the FULL latent space (88% trivial copy) instead of the SOLUTION/output space?

Measures per-step decoded-ANSWER Hamming drift over solution-space cells (label!=input) and asks:
(1) does the DECODED ANSWER settle (stop changing) differently for success vs failure?
(2) does solution-space settling separate outcome where full-latent drift did not (Maze)?
(3) Sudoku control: full-latent DID separate; does solution-space too?
Plus per-cell failure structure for Maze: are error cells a connected detour or scattered?
"""
from __future__ import annotations
from pathlib import Path
import numpy as np

HERE = Path(__file__).resolve().parent
FU = HERE / "maze_followup"


def auc(s, y):
    """Return the ROC AUC of scores *s* for binary labels *y* (1 = positive, 0 = negative).

    The AUC is computed from the Mann-Whitney U statistic of the positive class.
    Tied scores receive their average (mid) rank, so a positive/negative pair with
    equal scores contributes 0.5 rather than an arbitrary 0 or 1 that depends on
    the sort order.

    Args:
        s: 1-D array-like of scores; higher means "more likely positive".
        y: 1-D array-like of labels, same length as *s*.

    Returns:
        The AUC as a float, or NaN when either class is empty.
    """
    s = np.asarray(s)
    y = np.asarray(y)
    p, n = s[y == 1], s[y == 0]
    if len(p) == 0 or len(n) == 0:
        return float("nan")
    a = np.concatenate([p, n])
    # Midranks: a tie group occupying ranks (last - count + 1) .. last gets their mean.
    _, inv, counts = np.unique(a, return_inverse=True, return_counts=True)
    last = np.cumsum(counts)
    r = (last - (counts - 1) / 2.0)[inv]
    # Positives are the first len(p) entries of the concatenated array.
    u = r[:len(p)].sum() - len(p) * (len(p) + 1) / 2
    return float(u / (len(p) * len(n)))


def cohend(a, b):
    """Return Cohen's d of *a* relative to *b* using the pooled sample standard deviation.

    The pooled variance weights each group's ddof=1 variance by (size - 1) and
    divides by the combined degrees of freedom, floored at 1. NaN is returned
    when the pooled standard deviation is not strictly positive.
    """
    size_a, size_b = len(a), len(b)
    dof = max(size_a + size_b - 2, 1)
    weighted = (size_a - 1) * a.var(ddof=1) + (size_b - 1) * b.var(ddof=1)
    pooled_sd = np.sqrt(weighted / dof)
    if pooled_sd > 0:
        return (a.mean() - b.mean()) / pooled_sd
    # Zero (or NaN) spread: the effect size is undefined.
    return float("nan")


def connected_components(mask2d):
    """Count the 4-connected components of truthy cells in a 2-D mask.

    Returns ``(count, sizes)`` where ``sizes`` lists each component's cell count
    in the order the components are first met in a row-major scan.
    """
    n_rows, n_cols = mask2d.shape
    visited = np.zeros_like(mask2d, bool)
    neighbours = ((-1, 0), (1, 0), (0, -1), (0, 1))
    sizes = []
    for i, row in enumerate(mask2d):
        for j, cell in enumerate(row):
            if not cell or visited[i, j]:
                continue
            # New component: flood-fill from this seed with an explicit stack.
            visited[i, j] = True
            frontier = [(i, j)]
            size = 0
            while frontier:
                r, c = frontier.pop()
                size += 1
                for dr, dc in neighbours:
                    nr, nc = r + dr, c + dc
                    if not (0 <= nr < n_rows and 0 <= nc < n_cols):
                        continue
                    if mask2d[nr, nc] and not visited[nr, nc]:
                        # Mark on push so a cell is never queued twice.
                        visited[nr, nc] = True
                        frontier.append((nr, nc))
            sizes.append(size)
    return len(sizes), sizes


def analyze(tag, npz, grid=None):
    """Print solution-space settling statistics for one prediction dump.

    Args:
        tag: Label printed in the report header.
        npz: Path to an ``.npz`` archive holding the arrays ``exact_correct``,
            ``ans_drift_ans`` and ``ans_drift_full``; ``preds`` and ``labels``
            are read as well when *grid* is given.
        grid: Optional ``(rows, cols)`` shape. When given, each failed example's
            error mask (``preds != labels``) is reshaped to it and the structure
            of its connected error components is summarised.

    Returns:
        None. The report is written to stdout.
    """
    # np.load returns an NpzFile that keeps the archive open; read what is needed
    # inside the context manager so the file handle is released deterministically.
    with np.load(npz) as d:
        y = d["exact_correct"].astype(int)
        ad = d["ans_drift_ans"].astype(float)   # (N, steps) decoded-answer drift over solution cells
        adf = d["ans_drift_full"][:, -4:].mean(1).astype(float)  # full-grid late drift, for reference
        if grid is not None:
            preds = d["preds"]
            labels = d["labels"]

    late = ad[:, -4:].mean(1)               # late answer-drift = solution-space "still changing"
    print(f"\n=== {tag} (n={len(y)}, acc={y.mean():.3f}) — SOLUTION-SPACE settling ===")
    print(f"  late answer-drift (cells/step changing) | success median={np.median(late[y==1]):.2f}  "
          f"failure median={np.median(late[y==0]):.2f}")
    print(f"  AUC(-late answer-drift -> correct) = {auc(-late, y):.3f}  Cohen d(fail-succ)={cohend(late[y==0],late[y==1]):+.2f}")
    settled = late < 0.5
    print(f"  fraction with SETTLED answer (drift<0.5 cells/step): success={settled[y==1].mean():.3f} failure={settled[y==0].mean():.3f}")
    print(f"  [ref] AUC(-late FULL-grid answer-drift -> correct) = {auc(-adf, y):.3f}")

    if grid is None:
        return

    # per-cell failure structure: are error cells connected (detour) or scattered?
    comps, errs, largest = [], [], []
    for k in np.where(y == 0)[0]:
        err = (preds[k] != labels[k]).reshape(grid)
        ne = int(err.sum())
        if ne == 0:
            # Flagged as a failure but no differing cell: nothing to characterise.
            continue
        c, sizes = connected_components(err)
        comps.append(c)
        errs.append(ne)
        # Share of this example's error cells that sit in its biggest component.
        largest.append(max(sizes) / ne)
    if not comps:
        return
    comps = np.array(comps)
    errs = np.array(errs)
    print(f"  per-failure error cells: median={int(np.median(errs))}; connected components: median={int(np.median(comps))}; "
          f"largest-component fraction median={np.median(largest):.2f}")
    frac_few = (comps <= 2).mean()
    print(f"  -> {frac_few:.2f} of failures have <=2 error components (connected detour); "
          f"{(comps>=5).mean():.2f} have >=5 (scattered)")


if __name__ == "__main__":
    # One row per dump: (report tag, archive path, grid shape, message shown while the file is absent).
    runs = (
        ("MAZE", FU / "maze_preds_step130200.npz", (30, 30), "[pending] maze dump"),
        ("SUDOKU (control)", FU / "sudoku_preds_step58590.npz", (9, 9), "[pending] sudoku dump"),
    )
    for run_tag, dump_path, grid_shape, pending_msg in runs:
        if not dump_path.exists():
            print(pending_msg)
            continue
        analyze(run_tag, dump_path, grid=grid_shape)