summaryrefslogtreecommitdiff
path: root/analysis_2x2/analyze_early_connectivity.py
blob: c74023af84f336c4d6d453b27aea7e4239f6d1aa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""Pooled broken-vs-complete dynamics across EARLY maze checkpoints (large-n robust version).
Reads the maze_earlysave dumps (maze_followup/earlydump_step_*.npz, written via maze_pred_dump;
each holds preds, inputs, drift_zH, ans_drift_ans, exact).
Connectivity is the success criterion. Tests whether BROKEN (incomplete) predictions have higher
latent drift (more chaotic / non-settling) than CONNECTED (complete) ones; pools across checkpoints
to increase n.
"""
from __future__ import annotations
from pathlib import Path
from collections import deque
import glob
import numpy as np

# Directory containing this script; dump locations are resolved relative to it.
HERE = Path(__file__).resolve().parent
# Directory searched by main() for earlydump_step_*.npz files.
DUMPS = HERE / "maze_followup"
# Fixed-seed generator used by main()'s bootstrap so the reported CI is reproducible.
rng = np.random.default_rng(0)


def is_connected(inp, pred, side: int = 30):
    """Return True if the predicted path joins the maze's two endpoints without touching a wall.

    ``inp`` (the maze) and ``pred`` (the prediction) are reshaped to ``side x side`` grids.
    Cell codes used: 1 = wall and 3/4 = endpoint (read from ``inp``); 5 = predicted path
    cell (read from ``pred``). The first two endpoint cells in row-major order are taken as
    start and end. If fewer than two endpoints exist the check is vacuously True.

    The prediction fails if any path cell lies on a wall. Otherwise it succeeds when ``end``
    is reachable from ``start`` by 4-neighbour (no diagonal) steps through path cells.

    Args:
        inp: maze grid, flat or already square, with ``side * side`` entries.
        pred: predicted grid, same size as ``inp``.
        side: grid side length (defaults to the original 30x30 layout).
    """
    maze = np.asarray(inp).reshape(side, side)
    guess = np.asarray(pred).reshape(side, side)
    ends = np.argwhere((maze == 3) | (maze == 4))
    if len(ends) < 2:
        return True
    start, end = tuple(ends[0]), tuple(ends[1])
    path = set(map(tuple, np.argwhere(guess == 5))) | {start, end}
    # A path drawn through a wall is invalid regardless of connectivity.
    if any(maze[r, c] == 1 for r, c in path):
        return False
    seen = {start}
    queue = deque([start])
    while queue:
        r, c = queue.popleft()
        if (r, c) == end:
            return True
        for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            nxt = (r + dr, c + dc)
            # Every member of ``path`` came from argwhere on the grid, so membership
            # already implies the neighbour is in bounds.
            if nxt in path and nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return False


def auc(score, y):
    """Return the ROC AUC of ``score`` for separating ``y == 1`` (positive) from ``y == 0``.

    Computed as the normalised Mann-Whitney U statistic. Tied scores receive their average
    (mid-)rank, so a tie between a positive and a negative counts as half a win. An
    all-tied input therefore gives 0.5, not an order-dependent value. Entries of ``y`` that
    are neither 0 nor 1 are ignored.

    Returns:
        AUC in [0, 1], or NaN when either class is empty.
    """
    score = np.asarray(score, dtype=float)
    y = np.asarray(y)
    pos, neg = score[y == 1], score[y == 0]
    if len(pos) == 0 or len(neg) == 0:
        return float("nan")
    pooled = np.concatenate([pos, neg])
    # Mid-ranks: a tie group occupying 1-based ranks [upper-count+1, upper] gets their mean.
    _, inv, counts = np.unique(pooled, return_inverse=True, return_counts=True)
    upper = np.cumsum(counts)
    mid = upper - (counts - 1) / 2.0
    ranks = mid[np.ravel(inv)]
    u = ranks[:len(pos)].sum() - len(pos) * (len(pos) + 1) / 2
    return float(u / (len(pos) * len(neg)))


def _step_label(path: str) -> str:
    """Return the text between ``step_`` and the first ``.`` of *path*'s file name."""
    return Path(path).name.split("step_")[1].split(".")[0]


def _step_number(path: str) -> int:
    """Return the integer step in a dump's file name, for numeric (not lexicographic) sorting."""
    return int(_step_label(path).split("_")[0])


def _load_checkpoint(path: str):
    """Load one dump and return the per-example arrays ``(conn, ld, ad)``.

    conn: 1 where ``is_connected(inputs[k], preds[k])`` holds, else 0.
    ld:   log10 of the mean over the last 4 entries (axis 1) of ``drift_zH``, clipped at 1e-12.
    ad:   mean over the last 4 entries (axis 1) of ``ans_drift_ans``, as float.
    """
    # Indexing an NpzFile re-reads and decompresses the array on every access, and the archive
    # stays open until closed. So read each array once, inside a with-block.
    with np.load(path) as d:
        inputs = d["inputs"]
        preds = d["preds"]
        drift = d["drift_zH"]
        ans_drift = d["ans_drift_ans"]
    conn = np.array([is_connected(i, p) for i, p in zip(inputs, preds)], dtype=int)
    ld = np.log10(np.clip(drift[:, -4:].mean(1), 1e-12, None))
    ad = ans_drift[:, -4:].mean(1).astype(float)
    return conn, ld, ad


def _bootstrap_auc_ci(score, y, n_boot: int = 5000):
    """Percentile-bootstrap 95% CI of ``auc(score, y)`` using the module ``rng``.

    Resamples whose AUC is undefined (a single class drawn) are discarded. Returns
    ``(nan, nan)`` when no resample yields a defined AUC.
    """
    n = len(y)
    if n == 0:
        return float("nan"), float("nan")
    boot = [auc(score[i], y[i]) for i in (rng.integers(0, n, n) for _ in range(n_boot))]
    boot = np.array([b for b in boot if not np.isnan(b)])
    if boot.size == 0:
        # np.percentile would fail or warn on an empty sample; report an undefined CI instead.
        return float("nan"), float("nan")
    lo, hi = np.percentile(boot, [2.5, 97.5])
    return float(lo), float(hi)


def main():
    """Report per-checkpoint and pooled drift-vs-connectivity AUCs.

    For every ``earlydump_step_*.npz`` in ``DUMPS`` (ordered by step number), print:
    - the connectivity accuracy;
    - the number of broken predictions;
    - the AUCs of ``-latent_drift`` and ``-answer_drift`` for predicting a connected path.
    AUC > 0.5 means connected predictions drift less. Per-checkpoint AUCs are NaN unless
    both classes have at least 3 members.

    Then pool all checkpoints and print the group medians of latent drift, the pooled AUCs,
    and a 5000-resample bootstrap 95% CI for the latent-drift AUC.
    """
    files = sorted(glob.glob(str(DUMPS / "earlydump_step_*.npz")), key=_step_number)
    if not files:
        print("[no earlydump files yet]")
        return
    pool_ld, pool_ad, pool_conn = [], [], []
    print(f"{'ckpt':>10} {'connAcc':>8} {'broken':>7} {'AUC(ldrift)':>11} {'AUC(ansdrift)':>13}")
    for f in files:
        conn, ld, ad = _load_checkpoint(f)
        nb = int((conn == 0).sum())
        # Per-checkpoint AUCs are too noisy with fewer than 3 examples in either class.
        enough = 3 <= nb <= len(conn) - 3
        a_ld = auc(-ld, conn) if enough else float("nan")
        a_ad = auc(-ad, conn) if enough else float("nan")
        print(f"{_step_label(f):>10} {conn.mean():>8.3f} {nb:>7} {a_ld:>11.3f} {a_ad:>13.3f}")
        pool_ld.append(ld)
        pool_ad.append(ad)
        pool_conn.append(conn)
    ld = np.concatenate(pool_ld)
    ad = np.concatenate(pool_ad)
    conn = np.concatenate(pool_conn)
    nb = int((conn == 0).sum())
    a_ld = auc(-ld, conn)
    a_ad = auc(-ad, conn)
    lo, hi = _bootstrap_auc_ci(-ld, conn)
    print(f"\nPOOLED: n={len(conn)}, broken={nb}, connectivity-acc={conn.mean():.3f}")
    print(f"  latent-drift: broken median={np.median(ld[conn==0]):.2f} vs connected={np.median(ld[conn==1]):.2f}")
    print(f"  AUC(-latent_drift -> connected/complete) = {a_ld:.3f}  bootstrap95%CI=[{lo:.3f}, {hi:.3f}]")
    print(f"  AUC(-answer_drift -> connected) = {a_ad:.3f}")
    print(f"  => robust 'broken/incomplete = more chaotic'? CI excludes 0.5: {lo > 0.5}")


if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()