"""Pooled broken-vs-complete dynamics across EARLY maze checkpoints (large-n robust version). Reads maze_earlysave dumps (each has preds, inputs, drift_zH, ans_drift, exact via maze_pred_dump). Connectivity = success criterion. Tests: do BROKEN (incomplete) predictions have higher latent drift (more chaotic / non-settling) than CONNECTED (complete) ones? Pools across checkpoints for n. """ from __future__ import annotations from pathlib import Path from collections import deque import glob import numpy as np HERE = Path(__file__).resolve().parent DUMPS = HERE / "maze_followup" rng = np.random.default_rng(0) def is_connected(inp, pred): g = inp.reshape(30, 30); pr = pred.reshape(30, 30) se = np.argwhere((g == 3) | (g == 4)) if len(se) < 2: return True s, e = tuple(se[0]), tuple(se[1]) ps = set(map(tuple, np.argwhere(pr == 5))) | {s, e} if any(g[r, c] == 1 for r, c in ps): return False seen = {s}; q = deque([s]) while q: r, c = q.popleft() if (r, c) == e: return True for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]: nr, nc = r + dr, c + dc if 0 <= nr < 30 and 0 <= nc < 30 and (nr, nc) in ps and (nr, nc) not in seen: seen.add((nr, nc)); q.append((nr, nc)) return False def auc(score, y): p, n = score[y == 1], score[y == 0] if len(p) == 0 or len(n) == 0: return float("nan") a = np.concatenate([p, n]); o = np.argsort(a); r = np.empty(len(a)); r[o] = np.arange(1, len(a) + 1) return float((r[:len(p)].sum() - len(p) * (len(p) + 1) / 2) / (len(p) * len(n))) def main(): files = sorted(glob.glob(str(DUMPS / "earlydump_step_*.npz")), key=lambda f: int(f.split("step_")[1].split("_")[0].split(".")[0])) if not files: print("[no earlydump files yet]"); return pool_ld, pool_ad, pool_conn = [], [], [] print(f"{'ckpt':>10} {'connAcc':>8} {'broken':>7} {'AUC(ldrift)':>11} {'AUC(ansdrift)':>13}") for f in files: d = np.load(f) step = f.split("step_")[1].split(".")[0] conn = np.array([is_connected(d["inputs"][k], d["preds"][k]) for k in range(len(d["preds"]))]).astype(int) ld = np.log10(np.clip(d["drift_zH"][:, -4:].mean(1), 1e-12, None)) ad = d["ans_drift_ans"][:, -4:].mean(1).astype(float) nb = int((conn == 0).sum()) a_ld = auc(-ld, conn) if 3 <= nb <= len(conn) - 3 else float("nan") a_ad = auc(-ad, conn) if 3 <= nb <= len(conn) - 3 else float("nan") print(f"{step:>10} {conn.mean():>8.3f} {nb:>7} {a_ld:>11.3f} {a_ad:>13.3f}") pool_ld.append(ld); pool_ad.append(ad); pool_conn.append(conn) ld = np.concatenate(pool_ld); ad = np.concatenate(pool_ad); conn = np.concatenate(pool_conn) nb = int((conn == 0).sum()) a_ld = auc(-ld, conn); a_ad = auc(-ad, conn) boot = [auc(-ld[i], conn[i]) for i in (rng.integers(0, len(conn), len(conn)) for _ in range(5000))] boot = [b for b in boot if not np.isnan(b)] print(f"\nPOOLED: n={len(conn)}, broken={nb}, connectivity-acc={conn.mean():.3f}") print(f" latent-drift: broken median={np.median(ld[conn==0]):.2f} vs connected={np.median(ld[conn==1]):.2f}") print(f" AUC(-latent_drift -> connected/complete) = {a_ld:.3f} bootstrap95%CI=[{np.percentile(boot,2.5):.3f}, {np.percentile(boot,97.5):.3f}]") print(f" AUC(-answer_drift -> connected) = {a_ad:.3f}") print(f" => robust 'broken/incomplete = more chaotic'? CI excludes 0.5: {np.percentile(boot,2.5) > 0.5}") if __name__ == "__main__": main()