diff options
Diffstat (limited to 'research/flossing/analyze_diag.py')
| -rw-r--r-- | research/flossing/analyze_diag.py | 115 |
1 files changed, 115 insertions, 0 deletions
"""Analyze the HRM diagnostic output: stats + plots.

Reads an ``.npz`` archive of per-sample diagnostics, prints summary
statistics that compare the top Lyapunov exponent of exactly-solved samples
against failed ones, and writes four PNG figures into the output directory.

Usage::

    python analyze_diag.py --in diag.npz [--out-dir DIR]
"""
import argparse
import os

import numpy as np
from scipy import stats

DEFAULT_OUT_DIR = "/home/yurenh2/rrm/research/flossing/plots"
SUCCESS_COLOR = "C0"
FAILURE_COLOR = "C3"
_NAN = float("nan")


def cohens_d(group_a, group_b):
    """Return Cohen's d of *group_a* relative to *group_b*.

    The pooled variance is built from sums of squared deviations divided by
    ``n_a + n_b - 2``, i.e. from the unbiased sample variances.  (Weighting
    the population variances returned by ``ndarray.var()`` with ``n - 1``
    underestimates the pooled spread and inflates the effect size.)

    Returns NaN when either group is empty, when fewer than three values are
    available in total, or when the pooled variance is zero.
    """
    a = np.asarray(group_a, dtype=float)
    b = np.asarray(group_b, dtype=float)
    dof = a.size + b.size - 2
    if a.size == 0 or b.size == 0 or dof <= 0:
        return _NAN
    sum_sq = ((a - a.mean()) ** 2).sum() + ((b - b.mean()) ** 2).sum()
    pooled_var = sum_sq / dof
    if pooled_var == 0:
        return _NAN
    return float((a.mean() - b.mean()) / np.sqrt(pooled_var))


def auroc(positives, negatives):
    """Return the AUROC of "larger value => positive" for the two samples.

    This is the Mann-Whitney U statistic of *positives* divided by the number
    of (positive, negative) pairs.  Returns NaN when either sample is empty.
    """
    pos = np.asarray(positives, dtype=float)
    neg = np.asarray(negatives, dtype=float)
    if pos.size == 0 or neg.size == 0:
        return _NAN
    u_pos = stats.mannwhitneyu(pos, neg, alternative="greater").statistic
    return float(u_pos / (pos.size * neg.size))


def load_diagnostics(path):
    """Load the arrays this script uses from the ``.npz`` archive at *path*.

    Prints the archive keys, then returns a dict of arrays.  The shape notes
    below are carried over from the original script's comments and are not
    checked here.

    Raises:
        KeyError: if one of the required keys is missing from the archive.
    """
    # The context manager closes the underlying zip file once the arrays
    # have been materialized.
    with np.load(path) as archive:
        print("keys:", list(archive.keys()))
        return {
            "lyap": archive["lyap_max"],                        # (N,)
            "exact": archive["exact_correct"].astype(bool),     # (N,)
            "drift_zh": archive["drift_zH"],                    # (N, T) ACT-step level drift
            "drift_zl": archive["drift_zL"],                    # (N, T)
            "q_halt": archive["q_halt"],                        # (N, T)
            "q_continue": archive["q_continue"],                # (N, T)
            "token_acc": archive["token_acc"],
        }


def _mean_std(values):
    """Return ``(mean, std)`` of a 1-D array, or NaNs when it is empty."""
    if values.size == 0:
        return _NAN, _NAN
    return float(values.mean()), float(values.std())


def report_lyapunov(succ, fail):
    """Print the success/failure comparison and return ``(d, welch_p, auc)``.

    *succ* and *fail* hold the Lyapunov exponents of the exactly-correct and
    the failed samples.  When one group is empty the significance tests are
    skipped and NaNs are returned in their place.
    """
    print("\n=== Lyapunov max ===")
    for label, values in (("success", succ), ("failure", fail)):
        mean, std = _mean_std(values)
        print(f"  {label}: mean={mean:+.4f}  std={std:.4f}  n={values.size}")

    if succ.size == 0 or fail.size == 0:
        print("  (one group is empty: skipping significance tests)")
        return _NAN, _NAN, _NAN

    # Stats tests
    t_stat, p_welch = stats.ttest_ind(succ, fail, equal_var=False)
    print(f"  Welch t-test: t={t_stat:.3f}  p={p_welch:.2e}")
    u_stat, p_mwu = stats.mannwhitneyu(succ, fail, alternative="two-sided")
    print(f"  Mann-Whitney U: U={u_stat:.0f}  p={p_mwu:.2e}")

    # Effect size (Cohen's d)
    d_cohen = cohens_d(succ, fail)
    print(f"  Cohen's d: {d_cohen:.3f}")

    # AUC: lyap predicts failure.  The original analysis expects failures to
    # have the higher (less negative) exponent, so failure is the positive
    # class.
    auc = auroc(fail, succ)
    print(f"  AUROC (lyap predicting failure): {auc:.3f}")
    return d_cohen, p_welch, auc


def _load_pyplot():
    """Import pyplot with the non-interactive Agg backend selected.

    Imported lazily so the statistics helpers stay usable without a plotting
    backend installed.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    return plt


def _save(plt, fig, out_dir, filename):
    """Tighten the layout, write *fig* to *out_dir*/*filename*, close it."""
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, filename), dpi=130)
    plt.close(fig)


def _plot_group_bands(ax, values, exact):
    """Draw per-column mean lines with +/- 1 std bands for both groups.

    *values* is indexed by the boolean mask *exact* along axis 0 and averaged
    over that axis; the x axis is ``range(values.shape[1])``.
    """
    x = np.arange(values.shape[1])
    groups = (
        (exact, SUCCESS_COLOR, "success"),
        (~exact, FAILURE_COLOR, "failure"),
    )
    for mask, color, label in groups:
        mean = values[mask].mean(0)
        std = values[mask].std(0)
        ax.fill_between(x, mean - std, mean + std, color=color, alpha=0.25)
        ax.plot(x, mean, color, label=label)


def plot_lyap_hist(plt, lyap, succ, fail, d_cohen, p_welch, auc, out_dir):
    """Plot 1: histogram of lyap by success/failure -> ``lyap_hist.png``."""
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    bins = np.linspace(min(lyap.min(), -1.5), max(lyap.max(), -0.1), 40)
    groups = (
        (succ, SUCCESS_COLOR, "success"),
        (fail, FAILURE_COLOR, "failure"),
    )
    # Histograms first, then the mean markers, matching the original draw
    # order.  Empty groups are skipped: a density histogram and a mean line
    # are undefined for them.
    for values, color, label in groups:
        if values.size:
            ax.hist(values, bins=bins, alpha=0.55,
                    label=f"{label} (n={values.size})", color=color, density=True)
    for values, color, _ in groups:
        if values.size:
            ax.axvline(values.mean(), color=color, ls="--", lw=1)
    ax.set_xlabel(r"$\lambda_{max}$ (top Lyapunov exponent, per inner cycle)")
    ax.set_ylabel("density")
    ax.set_title(f"HRM Sudoku 1k @ step_26040: λ_max distribution\nCohen's d={d_cohen:.2f}, Welch p={p_welch:.1e}, AUROC={auc:.2f}")
    ax.legend()
    _save(plt, fig, out_dir, "lyap_hist.png")


def plot_drift(plt, drift_zh, drift_zl, exact, out_dir):
    """Plot 2: drift trajectory means -> ``drift_per_act.png``."""
    fig, axes = plt.subplots(1, 2, figsize=(11, 4), sharey=False)
    panels = (
        (axes[0], drift_zh, r"$\|z_H^{(t+1)} - z_H^{(t)}\|$"),
        (axes[1], drift_zl, r"$\|z_L^{(t+1)} - z_L^{(t)}\|$"),
    )
    for ax, drift, name in panels:
        _plot_group_bands(ax, drift, exact)
        ax.set_xlabel("ACT step")
        ax.set_title(f"drift per ACT step: {name}")
        ax.legend()
        # NOTE(review): mean - std can be <= 0, which a log axis cannot
        # represent; confirm the resulting clipped band is acceptable.
        ax.set_yscale("log")
    _save(plt, fig, out_dir, "drift_per_act.png")


def plot_q_gating(plt, q_halt, q_continue, exact, out_dir):
    """Plot 3: Q-halt gating dynamics -> ``q_gating.png``."""
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    _plot_group_bands(ax, q_halt - q_continue, exact)
    ax.axhline(0, color="k", ls=":", lw=0.6)
    ax.set_xlabel("ACT step")
    ax.set_ylabel("q_halt − q_continue")
    ax.set_title("ACT Q-head gating signal over ACT steps")
    ax.legend()
    _save(plt, fig, out_dir, "q_gating.png")


def plot_lyap_vs_token_acc(plt, lyap, token_acc, exact, out_dir):
    """Plot 4: lyap vs token_acc scatter -> ``lyap_vs_tokenacc.png``."""
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    ax.scatter(lyap[exact], token_acc[exact], s=8, alpha=0.3,
               color=SUCCESS_COLOR, label="exact-correct")
    ax.scatter(lyap[~exact], token_acc[~exact], s=8, alpha=0.3,
               color=FAILURE_COLOR, label="exact-fail")
    ax.set_xlabel(r"$\lambda_{max}$")
    ax.set_ylabel("token-level accuracy")
    ax.legend()
    ax.set_title("λ_max vs token accuracy")
    _save(plt, fig, out_dir, "lyap_vs_tokenacc.png")


def main(argv=None):
    """Parse the command line, print the statistics and write the plots.

    *argv* defaults to ``sys.argv[1:]`` (argparse's own default).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in", dest="inp", required=True)
    ap.add_argument("--out-dir", default=DEFAULT_OUT_DIR)
    args = ap.parse_args(argv)

    os.makedirs(args.out_dir, exist_ok=True)

    data = load_diagnostics(args.inp)
    lyap = data["lyap"]
    exact = data["exact"]

    n_samples = len(lyap)
    if n_samples == 0:
        # Nothing to summarize or plot; min()/max() below would raise.
        raise SystemExit(f"{args.inp}: no samples to analyze")
    print(f"N={n_samples}, exact_acc={exact.mean():.3f}")

    succ = lyap[exact]
    fail = lyap[~exact]
    d_cohen, p_welch, auc = report_lyapunov(succ, fail)

    plt = _load_pyplot()
    plot_lyap_hist(plt, lyap, succ, fail, d_cohen, p_welch, auc, args.out_dir)
    plot_drift(plt, data["drift_zh"], data["drift_zl"], exact, args.out_dir)
    plot_q_gating(plt, data["q_halt"], data["q_continue"], exact, args.out_dir)
    plot_lyap_vs_token_acc(plt, lyap, data["token_acc"], exact, args.out_dir)

    print(f"\nplots saved to {args.out_dir}/")


if __name__ == "__main__":
    main()
