summaryrefslogtreecommitdiff
path: root/analysis_2x2/analyze_early_connectivity.py
diff options
context:
space:
mode:
Diffstat (limited to 'analysis_2x2/analyze_early_connectivity.py')
-rw-r--r--analysis_2x2/analyze_early_connectivity.py77
1 files changed, 77 insertions, 0 deletions
diff --git a/analysis_2x2/analyze_early_connectivity.py b/analysis_2x2/analyze_early_connectivity.py
new file mode 100644
index 0000000..c74023a
--- /dev/null
+++ b/analysis_2x2/analyze_early_connectivity.py
@@ -0,0 +1,77 @@
+"""Pooled broken-vs-complete dynamics across EARLY maze checkpoints (large-n robust version).
+Reads maze_earlysave dumps (each has preds, inputs, drift_zH, ans_drift, exact via maze_pred_dump).
+Connectivity = success criterion. Tests: do BROKEN (incomplete) predictions have higher latent
+drift (more chaotic / non-settling) than CONNECTED (complete) ones? Pools across checkpoints for n.
+"""
+from __future__ import annotations
+from pathlib import Path
+from collections import deque
+import glob
+import numpy as np
+
+HERE = Path(__file__).resolve().parent  # absolute directory containing this script
+DUMPS = HERE / "maze_followup"  # directory that main() searches for earlydump_step_*.npz files
+rng = np.random.default_rng(0)  # fixed seed so the bootstrap resampling in main() is reproducible
+
+
def is_connected(inp, pred, size=None):
    """Return True if the predicted path links the two endpoint cells of a maze.

    The two endpoints are the first two cells (row-major order) of ``inp``
    whose value is 3 or 4.  The walkable set is every cell of ``pred`` equal
    to 5, plus the two endpoints.  The prediction counts as connected when a
    4-neighbour walk over walkable cells reaches the second endpoint from the
    first, and no walkable cell sits on an ``inp`` cell equal to 1.

    NOTE(review): the token meanings (1 = wall, 3/4 = start/goal, 5 = path)
    are presumed from how the values are used here -- confirm against the
    dataset's vocabulary.

    Args:
        inp: Maze input tokens; any array-like with ``size * size`` elements.
        pred: Predicted tokens; same number of elements as ``inp``.
        size: Side length of the square grid.  When omitted it is inferred
            from the number of elements in ``inp`` (900 elements -> 30, which
            is the value that used to be hard-coded).

    Returns:
        True when the endpoints are connected, or when fewer than two
        endpoint cells exist (nothing to connect); False otherwise.

    Raises:
        ValueError: if ``inp`` or ``pred`` cannot be reshaped to
            ``(size, size)``.
    """
    grid = np.asarray(inp)
    if size is None:
        # Infer the side of the square grid; a non-square length makes the
        # reshape below raise ValueError, as a non-900 input did before.
        size = int(round(float(np.sqrt(grid.size))))
    grid = grid.reshape(size, size)
    path = np.asarray(pred).reshape(size, size)

    endpoints = np.argwhere((grid == 3) | (grid == 4)).tolist()
    if len(endpoints) < 2:
        return True
    start, goal = tuple(endpoints[0]), tuple(endpoints[1])

    # `path == 5` allocates a fresh boolean array, so marking the endpoints
    # as walkable does not touch the caller's data.
    walkable = path == 5
    walkable[start] = True
    walkable[goal] = True

    # Any walkable cell lying on a blocked input cell invalidates the path.
    if np.any(walkable & (grid == 1)):
        return False

    # Breadth-first search from start over 4-connected walkable cells.
    visited = np.zeros_like(walkable)
    visited[start] = True
    frontier = deque([start])
    while frontier:
        r, c = frontier.popleft()
        if (r, c) == goal:
            return True
        for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            nr, nc = r + dr, c + dc
            if not (0 <= nr < size and 0 <= nc < size):
                continue
            if walkable[nr, nc] and not visited[nr, nc]:
                visited[nr, nc] = True
                frontier.append((nr, nc))
    return False
+
+
def auc(score, y):
    """Return the ROC AUC of ``score`` for binary labels ``y`` (rank formulation).

    Computed as the Mann-Whitney U statistic of the positive group divided by
    ``n_pos * n_neg``.  Tied scores receive their average rank (midrank), so a
    tied positive/negative pair contributes 0.5 and the result does not depend
    on the order of the inputs.  (A plain ``argsort`` ranking hands ties
    arbitrary consecutive ranks, which biases the estimate whenever scores
    repeat.)

    Args:
        score: 1-D array of scores; higher should indicate the positive class.
        y: 1-D array of labels, 1 for positive and 0 for negative.  Entries
            with any other value are ignored.

    Returns:
        The AUC as a float in [0, 1], or NaN when either class is empty.
    """
    score = np.asarray(score)
    y = np.asarray(y)
    pos, neg = score[y == 1], score[y == 0]
    n_pos, n_neg = len(pos), len(neg)
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    # Positives first, so their ranks are the leading n_pos entries.
    pooled = np.concatenate([pos, neg])
    _, inverse, counts = np.unique(pooled, return_inverse=True, return_counts=True)
    # A group of tied values occupying 1-based ranks (hi - k + 1)..hi has
    # average rank hi - (k - 1) / 2.
    last_rank = np.cumsum(counts)
    midrank = last_rank - (counts - 1) / 2.0
    ranks = midrank[inverse.reshape(-1)]

    u_stat = ranks[:n_pos].sum() - n_pos * (n_pos + 1) / 2
    return float(u_stat / (n_pos * n_neg))
+
+
def main():
    """Print per-checkpoint and pooled AUCs of drift as a predictor of connectivity.

    For every ``earlydump_step_*.npz`` file in ``DUMPS`` (ordered by step
    number) this labels each prediction with ``is_connected`` and reports how
    well low drift over the last four recorded steps separates connected from
    broken predictions.  All checkpoints are then pooled and the latent-drift
    AUC is reported with a 5000-resample bootstrap 95% interval.

    Prints a short notice and returns when no dump files exist.
    """

    def step_label(path):
        # Parse from the file name only, so a "step_" somewhere in a parent
        # directory name cannot be mistaken for the checkpoint step.
        return path.name.split("step_")[1].split(".")[0]

    def median_or_nan(values):
        # np.median of an empty selection warns and yields NaN; return NaN quietly.
        return float(np.median(values)) if len(values) else float("nan")

    files = sorted(DUMPS.glob("earlydump_step_*.npz"),
                   key=lambda p: int(step_label(p).split("_")[0]))
    if not files:
        print("[no earlydump files yet]")
        return

    pool_ld, pool_ad, pool_conn = [], [], []
    print(f"{'ckpt':>10} {'connAcc':>8} {'broken':>7} {'AUC(ldrift)':>11} {'AUC(ansdrift)':>13}")
    for path in files:
        step = step_label(path)
        # NpzFile re-reads an array from the archive on every d[key] access,
        # so pull each array out exactly once and close the file afterwards.
        with np.load(path) as d:
            inputs = d["inputs"]
            preds = d["preds"]
            drift_zh = d["drift_zH"]
            # NOTE(review): the module docstring calls this field "ans_drift";
            # confirm that "ans_drift_ans" is the key the dump really uses.
            ans_drift = d["ans_drift_ans"]

        conn = np.array([is_connected(i, p) for i, p in zip(inputs, preds)]).astype(int)
        # Mean over the last four recorded steps; latent drift is compared on
        # a log10 scale, clipped away from zero so the log stays finite.
        ld = np.log10(np.clip(drift_zh[:, -4:].mean(1), 1e-12, None))
        ad = ans_drift[:, -4:].mean(1).astype(float)

        nb = int((conn == 0).sum())
        # Report a per-checkpoint AUC only with at least 3 examples per class.
        enough = 3 <= nb <= len(conn) - 3
        a_ld = auc(-ld, conn) if enough else float("nan")
        a_ad = auc(-ad, conn) if enough else float("nan")
        print(f"{step:>10} {conn.mean():>8.3f} {nb:>7} {a_ld:>11.3f} {a_ad:>13.3f}")

        pool_ld.append(ld)
        pool_ad.append(ad)
        pool_conn.append(conn)

    ld = np.concatenate(pool_ld)
    ad = np.concatenate(pool_ad)
    conn = np.concatenate(pool_conn)
    n = len(conn)
    nb = int((conn == 0).sum())
    a_ld = auc(-ld, conn)
    a_ad = auc(-ad, conn)

    # Bootstrap the pooled latent-drift AUC; resamples that happen to contain
    # a single class give NaN and are dropped.
    boot = []
    if n:
        for _ in range(5000):
            idx = rng.integers(0, n, n)
            b = auc(-ld[idx], conn[idx])
            if not np.isnan(b):
                boot.append(b)
    if boot:
        ci_lo, ci_hi = (float(v) for v in np.percentile(boot, [2.5, 97.5]))
    else:
        # Only one class in the pooled data: no interval can be estimated.
        ci_lo = ci_hi = float("nan")

    print(f"\nPOOLED: n={n}, broken={nb}, connectivity-acc={conn.mean():.3f}")
    print(f" latent-drift: broken median={median_or_nan(ld[conn == 0]):.2f} vs connected={median_or_nan(ld[conn == 1]):.2f}")
    print(f" AUC(-latent_drift -> connected/complete) = {a_ld:.3f} bootstrap95%CI=[{ci_lo:.3f}, {ci_hi:.3f}]")
    print(f" AUC(-answer_drift -> connected) = {a_ad:.3f}")
    # One-sided check: True only when the whole interval lies above 0.5.
    print(f" => robust 'broken/incomplete = more chaotic'? CI excludes 0.5: {ci_lo > 0.5}")
+
+
+if __name__ == "__main__":  # run the analysis only when executed as a script, not on import
+ main()