summaryrefslogtreecommitdiff
path: root/research/flossing/analyze_diag.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/analyze_diag.py')
-rw-r--r--research/flossing/analyze_diag.py115
1 files changed, 115 insertions, 0 deletions
diff --git a/research/flossing/analyze_diag.py b/research/flossing/analyze_diag.py
new file mode 100644
index 0000000..0382d3c
--- /dev/null
+++ b/research/flossing/analyze_diag.py
@@ -0,0 +1,115 @@
+"""Analyze the HRM diagnostic output: stats + plots."""
+import numpy as np
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+from scipy import stats
+import argparse
+
+ap = argparse.ArgumentParser()
+ap.add_argument("--in", dest="inp", required=True)
+ap.add_argument("--out-dir", default="/home/yurenh2/rrm/research/flossing/plots")
+args = ap.parse_args()
+
+import os
+os.makedirs(args.out_dir, exist_ok=True)
+
+d = np.load(args.inp)
+print("keys:", list(d.keys()))
+lyap = d["lyap_max"] # (N,)
+exact = d["exact_correct"].astype(bool) # (N,)
+drift_zH = d["drift_zH"] # (N, T) ACT-step level drift
+drift_zL = d["drift_zL"] # (N, T)
+q_halt = d["q_halt"] # (N, T)
+q_continue = d["q_continue"] # (N, T)
+halted_at = d["halted_at"] # (N,) — ACT step at which q_halt>q_continue first
+
+N = len(lyap)
+print(f"N={N}, exact_acc={exact.mean():.3f}")
+
+succ = lyap[exact]; fail = lyap[~exact]
+print(f"\n=== Lyapunov max ===")
+print(f" success: mean={succ.mean():+.4f} std={succ.std():.4f} n={succ.size}")
+print(f" failure: mean={fail.mean():+.4f} std={fail.std():.4f} n={fail.size}")
+
+# Stats tests
+t, p = stats.ttest_ind(succ, fail, equal_var=False)
+print(f" Welch t-test: t={t:.3f} p={p:.2e}")
+u, p_u = stats.mannwhitneyu(succ, fail, alternative="two-sided")
+print(f" Mann-Whitney U: U={u:.0f} p={p_u:.2e}")
+# Effect size (Cohen's d)
+pooled_std = np.sqrt(((succ.size-1)*succ.var() + (fail.size-1)*fail.var()) / (succ.size+fail.size-2))
+d_cohen = (succ.mean() - fail.mean()) / pooled_std
+print(f" Cohen's d: {d_cohen:.3f}")
+# AUC: lyap predicts failure
+# Failure tends to have higher (less negative) lyap.
+auc = stats.mannwhitneyu(fail, succ, alternative="greater").statistic / (fail.size * succ.size)
+print(f" AUROC (lyap predicting failure): {auc:.3f}")
+
+# ----- plot 1: histogram of lyap by success/failure -----
+fig, ax = plt.subplots(1, 1, figsize=(6,4))
+bins = np.linspace(min(lyap.min(),-1.5), max(lyap.max(), -0.1), 40)
+ax.hist(succ, bins=bins, alpha=0.55, label=f"success (n={succ.size})", color="C0", density=True)
+ax.hist(fail, bins=bins, alpha=0.55, label=f"failure (n={fail.size})", color="C3", density=True)
+ax.axvline(succ.mean(), color="C0", ls="--", lw=1)
+ax.axvline(fail.mean(), color="C3", ls="--", lw=1)
+ax.set_xlabel(r"$\lambda_{max}$ (top Lyapunov exponent, per inner cycle)")
+ax.set_ylabel("density")
+ax.set_title(f"HRM Sudoku 1k @ step_26040: λ_max distribution\nCohen's d={d_cohen:.2f}, Welch p={p:.1e}, AUROC={auc:.2f}")
+ax.legend()
+fig.tight_layout()
+fig.savefig(f"{args.out_dir}/lyap_hist.png", dpi=130)
+plt.close()
+
+# ----- plot 2: drift trajectory means -----
+fig, axes = plt.subplots(1, 2, figsize=(11,4), sharey=False)
+for ax, drift, name in [(axes[0], drift_zH, r"$\|z_H^{(t+1)} - z_H^{(t)}\|$"),
+ (axes[1], drift_zL, r"$\|z_L^{(t+1)} - z_L^{(t)}\|$")]:
+ mean_s = drift[exact].mean(0); std_s = drift[exact].std(0)
+ mean_f = drift[~exact].mean(0); std_f = drift[~exact].std(0)
+ x = np.arange(drift.shape[1])
+ ax.fill_between(x, mean_s-std_s, mean_s+std_s, color="C0", alpha=0.25)
+ ax.plot(x, mean_s, "C0", label="success")
+ ax.fill_between(x, mean_f-std_f, mean_f+std_f, color="C3", alpha=0.25)
+ ax.plot(x, mean_f, "C3", label="failure")
+ ax.set_xlabel("ACT step")
+ ax.set_title(f"drift per ACT step: {name}")
+ ax.legend()
+ ax.set_yscale("log")
+fig.tight_layout()
+fig.savefig(f"{args.out_dir}/drift_per_act.png", dpi=130)
+plt.close()
+
+# ----- plot 3: Q-halt gating dynamics -----
+fig, ax = plt.subplots(1, 1, figsize=(6,4))
+diff = q_halt - q_continue
+mean_s = diff[exact].mean(0); std_s = diff[exact].std(0)
+mean_f = diff[~exact].mean(0); std_f = diff[~exact].std(0)
+x = np.arange(diff.shape[1])
+ax.fill_between(x, mean_s-std_s, mean_s+std_s, color="C0", alpha=0.25)
+ax.plot(x, mean_s, "C0", label="success")
+ax.fill_between(x, mean_f-std_f, mean_f+std_f, color="C3", alpha=0.25)
+ax.plot(x, mean_f, "C3", label="failure")
+ax.axhline(0, color="k", ls=":", lw=0.6)
+ax.set_xlabel("ACT step")
+ax.set_ylabel("q_halt − q_continue")
+ax.set_title("ACT Q-head gating signal over ACT steps")
+ax.legend()
+fig.tight_layout()
+fig.savefig(f"{args.out_dir}/q_gating.png", dpi=130)
+plt.close()
+
+# ----- plot 4: lyap vs token_acc scatter -----
+token_acc = d["token_acc"]
+fig, ax = plt.subplots(1, 1, figsize=(6,4))
+ax.scatter(lyap[exact], token_acc[exact], s=8, alpha=0.3, color="C0", label="exact-correct")
+ax.scatter(lyap[~exact], token_acc[~exact], s=8, alpha=0.3, color="C3", label="exact-fail")
+ax.set_xlabel(r"$\lambda_{max}$")
+ax.set_ylabel("token-level accuracy")
+ax.legend()
+ax.set_title("λ_max vs token accuracy")
+fig.tight_layout()
+fig.savefig(f"{args.out_dir}/lyap_vs_tokenacc.png", dpi=130)
+plt.close()
+
+print(f"\nplots saved to {args.out_dir}/")