diff options
Diffstat (limited to 'analysis_2x2/analyze_early_connectivity.py')
| -rw-r--r-- | analysis_2x2/analyze_early_connectivity.py | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/analysis_2x2/analyze_early_connectivity.py b/analysis_2x2/analyze_early_connectivity.py new file mode 100644 index 0000000..c74023a --- /dev/null +++ b/analysis_2x2/analyze_early_connectivity.py @@ -0,0 +1,77 @@ +"""Pooled broken-vs-complete dynamics across EARLY maze checkpoints (large-n robust version). +Reads maze_earlysave dumps (each has preds, inputs, drift_zH, ans_drift, exact via maze_pred_dump). +Connectivity = success criterion. Tests: do BROKEN (incomplete) predictions have higher latent +drift (more chaotic / non-settling) than CONNECTED (complete) ones? Pools across checkpoints for n. +""" +from __future__ import annotations +from pathlib import Path +from collections import deque +import glob +import numpy as np + +HERE = Path(__file__).resolve().parent +DUMPS = HERE / "maze_followup" +rng = np.random.default_rng(0) + + +def is_connected(inp, pred): + g = inp.reshape(30, 30); pr = pred.reshape(30, 30) + se = np.argwhere((g == 3) | (g == 4)) + if len(se) < 2: + return True + s, e = tuple(se[0]), tuple(se[1]) + ps = set(map(tuple, np.argwhere(pr == 5))) | {s, e} + if any(g[r, c] == 1 for r, c in ps): + return False + seen = {s}; q = deque([s]) + while q: + r, c = q.popleft() + if (r, c) == e: + return True + for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]: + nr, nc = r + dr, c + dc + if 0 <= nr < 30 and 0 <= nc < 30 and (nr, nc) in ps and (nr, nc) not in seen: + seen.add((nr, nc)); q.append((nr, nc)) + return False + + +def auc(score, y): + p, n = score[y == 1], score[y == 0] + if len(p) == 0 or len(n) == 0: + return float("nan") + a = np.concatenate([p, n]); o = np.argsort(a); r = np.empty(len(a)); r[o] = np.arange(1, len(a) + 1) + return float((r[:len(p)].sum() - len(p) * (len(p) + 1) / 2) / (len(p) * len(n))) + + +def main(): + files = sorted(glob.glob(str(DUMPS / "earlydump_step_*.npz")), + key=lambda f: int(f.split("step_")[1].split("_")[0].split(".")[0])) + if not files: + print("[no earlydump files yet]"); return + pool_ld, pool_ad, pool_conn = [], [], [] + print(f"{'ckpt':>10} {'connAcc':>8} {'broken':>7} {'AUC(ldrift)':>11} {'AUC(ansdrift)':>13}") + for f in files: + d = np.load(f) + step = f.split("step_")[1].split(".")[0] + conn = np.array([is_connected(d["inputs"][k], d["preds"][k]) for k in range(len(d["preds"]))]).astype(int) + ld = np.log10(np.clip(d["drift_zH"][:, -4:].mean(1), 1e-12, None)) + ad = d["ans_drift_ans"][:, -4:].mean(1).astype(float) + nb = int((conn == 0).sum()) + a_ld = auc(-ld, conn) if 3 <= nb <= len(conn) - 3 else float("nan") + a_ad = auc(-ad, conn) if 3 <= nb <= len(conn) - 3 else float("nan") + print(f"{step:>10} {conn.mean():>8.3f} {nb:>7} {a_ld:>11.3f} {a_ad:>13.3f}") + pool_ld.append(ld); pool_ad.append(ad); pool_conn.append(conn) + ld = np.concatenate(pool_ld); ad = np.concatenate(pool_ad); conn = np.concatenate(pool_conn) + nb = int((conn == 0).sum()) + a_ld = auc(-ld, conn); a_ad = auc(-ad, conn) + boot = [auc(-ld[i], conn[i]) for i in (rng.integers(0, len(conn), len(conn)) for _ in range(5000))] + boot = [b for b in boot if not np.isnan(b)] + print(f"\nPOOLED: n={len(conn)}, broken={nb}, connectivity-acc={conn.mean():.3f}") + print(f" latent-drift: broken median={np.median(ld[conn==0]):.2f} vs connected={np.median(ld[conn==1]):.2f}") + print(f" AUC(-latent_drift -> connected/complete) = {a_ld:.3f} bootstrap95%CI=[{np.percentile(boot,2.5):.3f}, {np.percentile(boot,97.5):.3f}]") + print(f" AUC(-answer_drift -> connected) = {a_ad:.3f}") + print(f" => robust 'broken/incomplete = more chaotic'? CI excludes 0.5: {np.percentile(boot,2.5) > 0.5}") + + +if __name__ == "__main__": + main() |
