summaryrefslogtreecommitdiff
path: root/analysis_2x2/analyze_maze_connectivity.py
blob: f2fb01b6eec20cddab26b28d9a631294f9f6ab8e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Maze with CONNECTIVITY as the success criterion (not exact-match).
Genuine failure = predicted path does NOT connect start->goal (broken/incomplete answer).
Valid alternative paths (connected but != labeled) count as 'complete answer'.

Joins per-cell preds (my dump, seed 20260616) with the friend's FTLE/drift npz (same seed/idx),
and asks: do BROKEN (disconnected) predictions WANDER while CONNECTED ones SETTLE?
If yes, the dynamical signal tracks answer-COMPLETENESS, and exact-match was the wrong lens for Maze.
"""
from __future__ import annotations
from pathlib import Path
from collections import deque
import glob
import numpy as np

HERE = Path(__file__).resolve().parent
FU = HERE / "maze_followup"
FRIEND = "/tmp/friend_maze/maze_all_ckpts_lyap"


def is_connected(inp, pred, side: int = 30) -> bool:
    """Return True when the predicted path cells link the maze's two endpoints.

    Cell codes used by this check: 1 is a wall, 3 and 4 mark the two
    endpoints, and 5 marks a predicted path cell.

    Args:
        inp: Maze input with ``side * side`` cells (any shape reshapeable to
            ``(side, side)``).
        pred: Model prediction with the same number of cells as ``inp``.
        side: Edge length of the square grid. Defaults to 30, the size the
            original implementation hard-coded.

    Returns:
        False if any path cell sits on a wall, or if no 4-neighbour walk over
        path cells joins the two endpoints; True otherwise. A grid with fewer
        than two endpoint cells cannot be judged and is reported as True.
    """
    grid = np.asarray(inp).reshape(side, side)
    marked = np.asarray(pred).reshape(side, side)

    endpoints = np.argwhere((grid == 3) | (grid == 4))
    if len(endpoints) < 2:
        return True  # can't judge -> treat as connected

    # Only the first two endpoint cells (row-major order) are used; which one
    # is the start does not matter because connectivity is symmetric.
    start = (int(endpoints[0][0]), int(endpoints[0][1]))
    goal = (int(endpoints[1][0]), int(endpoints[1][1]))

    walkable = {(int(r), int(c)) for r, c in np.argwhere(marked == 5)}
    walkable |= {start, goal}
    if any(grid[cell] == 1 for cell in walkable):
        return False  # a predicted path cell lies on a wall -> invalid

    # Breadth-first search restricted to walkable cells. No explicit bounds
    # test is needed: every walkable cell came from in-grid coordinates, so an
    # out-of-grid neighbour can never be a member of the set.
    visited = {start}
    frontier = deque([start])
    while frontier:
        row, col = frontier.popleft()
        if (row, col) == goal:
            return True
        for d_row, d_col in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            nxt = (row + d_row, col + d_col)
            if nxt in walkable and nxt not in visited:
                visited.add(nxt)
                frontier.append(nxt)
    return False


def auc(score, y) -> float:
    """Return the ROC AUC of ``score`` for binary labels ``y``.

    Computed through the rank-sum (Mann-Whitney U) identity: the AUC is the
    fraction of (positive, negative) pairs in which the positive has the
    higher score, with tied pairs counting one half.

    Args:
        score: 1-D array of scores; higher should indicate the positive class.
        y: 1-D array of labels aligned with ``score``; 1 marks positives and
            0 marks negatives (other values are ignored).

    Returns:
        The AUC as a float in [0, 1], or NaN when either class is empty.
    """
    pos, neg = score[y == 1], score[y == 0]
    n_pos, n_neg = len(pos), len(neg)
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    pooled = np.concatenate([pos, neg])
    # Tied scores must share their average rank. Ordinal ranks from a plain
    # argsort would order ties arbitrarily and bias the result (a constant
    # score would come out as 0 or 1 instead of 0.5).
    _, inverse, counts = np.unique(pooled, return_inverse=True, return_counts=True)
    last_rank = np.cumsum(counts)               # highest 1-based rank in each tie group
    mid_rank = last_rank - (counts - 1) / 2.0   # average rank of each tie group
    ranks = mid_rank[inverse.ravel()]

    # Positives occupy the first n_pos slots of `pooled`.
    u_stat = ranks[:n_pos].sum() - n_pos * (n_pos + 1) / 2
    return float(u_stat / (n_pos * n_neg))


def cohend(a, b):
    """Return Cohen's d for samples ``a`` and ``b`` using the pooled standard deviation.

    The effect size is ``(mean(a) - mean(b)) / s``, where ``s`` pools the two
    unbiased (ddof=1) sample variances weighted by their degrees of freedom.
    NaN is returned when either sample has fewer than two values or when the
    pooled standard deviation is not strictly positive.
    """
    n_a, n_b = len(a), len(b)
    if min(n_a, n_b) < 2:
        return float("nan")

    weighted_var = (n_a - 1) * a.var(ddof=1) + (n_b - 1) * b.var(ddof=1)
    pooled_sd = np.sqrt(weighted_var / (n_a + n_b - 2))
    if pooled_sd > 0:
        return (a.mean() - b.mean()) / pooled_sd
    return float("nan")


# For each checkpoint step: join the per-cell prediction dump with the
# FTLE/drift archive on their shared sample indices, label every prediction as
# connected or broken, and report how well the dynamics separate the two.
for step in [13020, 52080, 130200]:
    pred_f = FU / f"mazepreds_step_{step}_seed20260616.npz"
    # glob returns matches in arbitrary order; sort so that the archive picked
    # below is reproducible when more than one file matches.
    fr = sorted(glob.glob(f"{FRIEND}/maze_step_{step}_*.npz"))
    if not pred_f.exists() or not fr:
        print(f"[pending] step {step}")
        continue
    # np.load on an .npz keeps the underlying file open until closed; pull out
    # the arrays that are needed and let the context manager release both.
    with np.load(pred_f) as P, np.load(fr[0]) as F:
        # pi / fi are the positions of the shared idx values in each archive.
        common, pi, fi = np.intersect1d(P["idx"], F["idx"], return_indices=True)
        preds = P["preds"][pi]
        inputs = P["inputs"][pi]
        exact = P["exact_correct"][pi].astype(int)
        # First column of the Lyapunov spectrum array, reported as lambda1.
        l1 = F["lyap_spec"][fi, 0].astype(float)
        # log10 of the mean of the last four drift columns, floored at 1e-12
        # so that a zero drift does not produce -inf.
        late_drift = np.log10(np.clip(F["drift_zH"][fi, -4:].mean(1), 1e-12, None))
    conn = np.array([is_connected(inp, pred) for inp, pred in zip(inputs, preds)]).astype(int)
    nb = int((conn == 0).sum())
    print(f"\n=== step {step} (joined n={len(common)}) ===")
    print(f"  exact-match acc={exact.mean():.3f} | CONNECTIVITY acc (valid complete path)={conn.mean():.3f} | broken={nb}")
    # Need at least three samples in each group for the comparisons below.
    if nb < 3 or nb > len(common) - 3:
        print(f"  too few broken/connected to condition dynamics (broken={nb})")
        continue
    # dynamics conditioned on CONNECTIVITY (broken=0 vs connected=1)
    print(f"  late-drift (settling): connected median={np.median(late_drift[conn==1]):.2f}  broken median={np.median(late_drift[conn==0]):.2f}")
    print(f"    AUC(-late-drift -> connected) = {auc(-late_drift, conn):.3f}  Cohen d(broken-conn)={cohend(late_drift[conn==0], late_drift[conn==1]):+.2f}")
    print(f"  lambda1: connected median={np.median(l1[conn==1]):+.4f}  broken median={np.median(l1[conn==0]):+.4f}")
    print(f"    AUC(-lambda1 -> connected) = {auc(-l1, conn):.3f}")
    # compare: does connectivity separate dynamics BETTER than exact-match?
    print(f"  [vs exact-match] AUC(-late-drift -> exact_correct) = {auc(-late_drift, exact):.3f}, "
          f"AUC(-lambda1 -> exact) = {auc(-l1, exact):.3f}")
    # within CONNECTED, does exact-match still separate? (should NOT, if dynamics track completeness)
    m = conn == 1
    # Only meaningful when the connected group contains both exact and non-exact answers.
    if 0 < exact[m].mean() < 1:
        print(f"  within CONNECTED (n={m.sum()}): AUC(-late-drift -> exact) = {auc(-late_drift[m], exact[m]):.3f} "
              f"(near 0.5 => dynamics track completeness, not correctness)")