1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
|
"""Pooled broken-vs-complete dynamics across EARLY maze checkpoints (large-n robust version).
Reads maze_earlysave dumps (each has preds, inputs, drift_zH, ans_drift, exact; produced via
maze_pred_dump). Connectivity is the success criterion. Question tested: do BROKEN (incomplete)
predictions have higher latent drift (more chaotic / non-settling) than CONNECTED (complete) ones?
Results are pooled across checkpoints to increase the sample size n.
"""
from __future__ import annotations
from pathlib import Path
from collections import deque
import glob
import numpy as np
# Absolute directory of this script; the dump folder is resolved relative to it.
HERE = Path(__file__).resolve().parent
# Folder searched by main() for earlydump_step_*.npz files.
DUMPS = HERE.joinpath("maze_followup")
# Fixed-seed generator so the bootstrap resampling in main() is reproducible.
rng = np.random.default_rng(seed=0)
def is_connected(inp, pred):
    """Decide whether a predicted path joins the maze's two endpoints.

    Both ``inp`` and ``pred`` are reshaped to 30x30 grids. The endpoints are
    the first two input cells (row-major order) whose value is 3 or 4. Path
    cells are prediction cells equal to 5, with the two endpoints added.

    Returns False if any of those cells sits on an input cell equal to 1
    (presumably a wall). Otherwise returns True exactly when the endpoints are
    joined by 4-neighbour steps over path cells. With fewer than two endpoint
    cells the check is vacuously True.
    """
    grid = inp.reshape(30, 30)
    drawn = pred.reshape(30, 30)
    endpoints = np.argwhere(np.logical_or(grid == 3, grid == 4))
    if endpoints.shape[0] < 2:
        return True
    start, goal = tuple(endpoints[0]), tuple(endpoints[1])

    walkable = {tuple(cell) for cell in np.argwhere(drawn == 5)}
    walkable.add(start)
    walkable.add(goal)
    for row, col in walkable:
        if grid[row, col] == 1:
            return False

    # Reachability search from start (depth-first, explicit stack); the
    # visiting order differs from a queue but the reachable set is the same.
    visited = {start}
    stack = [start]
    while stack:
        here = stack.pop()
        if here == goal:
            return True
        row, col = here
        for nxt in ((row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1)):
            in_bounds = 0 <= nxt[0] < 30 and 0 <= nxt[1] < 30
            if in_bounds and nxt in walkable and nxt not in visited:
                visited.add(nxt)
                stack.append(nxt)
    return False
def auc(score, y):
    """Return the ROC AUC of ``score`` for separating ``y == 1`` from ``y == 0``.

    Computed from the Mann-Whitney U statistic. It equals the probability that
    a random positive (``y == 1``) scores higher than a random negative
    (``y == 0``), with ties counted as one half. Entries whose label is neither
    0 nor 1 are ignored.

    Parameters
    ----------
    score : array-like of float
        One score per sample.
    y : array-like
        Labels aligned with ``score``.

    Returns
    -------
    float
        AUC in [0, 1], or NaN when either class is empty.
    """
    score = np.asarray(score, dtype=float)
    y = np.asarray(y)
    pos, neg = score[y == 1], score[y == 0]
    n_pos, n_neg = len(pos), len(neg)
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    pooled = np.concatenate([pos, neg])
    # Tied values must share their average (mid-)rank. Plain argsort ranks
    # break ties by position, which biases the AUC whenever scores repeat
    # (e.g. many drift values clipped to the same floor).
    _, inverse, counts = np.unique(pooled, return_inverse=True, return_counts=True)
    mid_rank = np.cumsum(counts) - (counts - 1) / 2.0
    ranks = mid_rank[inverse.reshape(-1)]
    u_stat = ranks[:n_pos].sum() - n_pos * (n_pos + 1) / 2.0
    return float(u_stat / (n_pos * n_neg))
def _step_token(path):
    """Return the step label of an ``earlydump_step_<step>[_...].npz`` path.

    Only the file name is parsed, so a ``step_`` elsewhere in the directory
    path cannot be picked up. The label ends at the first ``_`` or ``.`` after
    ``step_``. It is used both for sorting (as an int) and for display, so the
    two always agree.
    """
    tail = Path(path).name.split("step_", 1)[1]
    return tail.split("_")[0].split(".")[0]


def _median_or_nan(values):
    """Median of ``values`` as a float, or NaN (without a warning) when empty."""
    return float(np.median(values)) if len(values) else float("nan")


def main():
    """Print per-checkpoint and pooled broken-vs-connected drift statistics.

    Reads every ``earlydump_step_*.npz`` in DUMPS, ordered by step. For each
    file it labels every prediction connected (1) or broken (0) with
    ``is_connected``. It computes two scores per prediction:

    * latent drift: log10 of the mean of the last 4 ``drift_zH`` columns,
      clipped at 1e-12;
    * answer drift: the mean of the last 4 ``ans_drift_ans`` columns.

    Per-checkpoint AUCs of the negated drifts are printed only when both
    classes have at least 3 members. All checkpoints are then pooled, and the
    script prints pooled AUCs, class medians of latent drift, and a
    5000-resample bootstrap 95% CI for the latent-drift AUC.
    """
    files = sorted(glob.glob(str(DUMPS / "earlydump_step_*.npz")),
                   key=lambda f: int(_step_token(f)))
    if not files:
        print("[no earlydump files yet]")
        return
    pool_ld, pool_ad, pool_conn = [], [], []
    print(f"{'ckpt':>10} {'connAcc':>8} {'broken':>7} {'AUC(ldrift)':>11} {'AUC(ansdrift)':>13}")
    for f in files:
        step = _step_token(f)
        # NpzFile re-reads an array from the archive on every d[key] access
        # and keeps the file open until closed: load each array once, then close.
        with np.load(f) as d:
            inputs = d["inputs"]
            preds = d["preds"]
            drift_zh = d["drift_zH"]
            ans_drift = d["ans_drift_ans"]
        conn = np.array([is_connected(inputs[k], preds[k]) for k in range(len(preds))],
                        dtype=int)
        ld = np.log10(np.clip(drift_zh[:, -4:].mean(1), 1e-12, None))
        ad = ans_drift[:, -4:].mean(1).astype(float)
        nb = int((conn == 0).sum())
        # Only score a checkpoint when each class has at least 3 members.
        scorable = 3 <= nb <= len(conn) - 3
        a_ld = auc(-ld, conn) if scorable else float("nan")
        a_ad = auc(-ad, conn) if scorable else float("nan")
        print(f"{step:>10} {conn.mean():>8.3f} {nb:>7} {a_ld:>11.3f} {a_ad:>13.3f}")
        pool_ld.append(ld)
        pool_ad.append(ad)
        pool_conn.append(conn)

    ld = np.concatenate(pool_ld)
    ad = np.concatenate(pool_ad)
    conn = np.concatenate(pool_conn)
    n = len(conn)
    nb = int((conn == 0).sum())
    a_ld = auc(-ld, conn)
    a_ad = auc(-ad, conn)
    # Bootstrap the pooled latent-drift AUC by resampling predictions with
    # replacement; resamples missing one class yield NaN and are dropped.
    boot = [auc(-ld[i], conn[i]) for i in (rng.integers(0, n, n) for _ in range(5000))]
    boot = [b for b in boot if not np.isnan(b)]
    # np.percentile fails on an empty list (e.g. no broken predictions at all);
    # report an undefined CI as NaN instead of crashing.
    ci_lo = float(np.percentile(boot, 2.5)) if boot else float("nan")
    ci_hi = float(np.percentile(boot, 97.5)) if boot else float("nan")
    med_broken = _median_or_nan(ld[conn == 0])
    med_connected = _median_or_nan(ld[conn == 1])
    print(f"\nPOOLED: n={n}, broken={nb}, connectivity-acc={conn.mean():.3f}")
    print(f" latent-drift: broken median={med_broken:.2f} vs connected={med_connected:.2f}")
    print(f" AUC(-latent_drift -> connected/complete) = {a_ld:.3f} bootstrap95%CI=[{ci_lo:.3f}, {ci_hi:.3f}]")
    print(f" AUC(-answer_drift -> connected) = {a_ad:.3f}")
    print(f" => robust 'broken/incomplete = more chaotic'? CI excludes 0.5: {ci_lo > 0.5}")


if __name__ == "__main__":
    main()
|