summaryrefslogtreecommitdiff
path: root/research/flossing/plot_trm_lyap_hist.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/plot_trm_lyap_hist.py')
-rw-r--r--research/flossing/plot_trm_lyap_hist.py69
1 files changed, 69 insertions, 0 deletions
diff --git a/research/flossing/plot_trm_lyap_hist.py b/research/flossing/plot_trm_lyap_hist.py
new file mode 100644
index 0000000..ffa14b2
--- /dev/null
+++ b/research/flossing/plot_trm_lyap_hist.py
@@ -0,0 +1,69 @@
+"""Per-sample λ_1 histogram: success vs failure on TRM checkpoints."""
+import numpy as np
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
def roc_auc_score(y_true, y_score):
    """Return the AUROC of *y_score* against binary *y_true* by pair counting.

    The result is the fraction of (positive, negative) pairs in which the
    positive sample has the strictly larger score, plus one half of the
    fraction of pairs whose scores are exactly equal.

    Args:
        y_true: Array-like of labels. Entries equal to 1 are treated as
            positives and entries equal to 0 as negatives; any other value
            is ignored.
        y_score: Array-like of scores, aligned element-wise with *y_true*.

    Returns:
        The AUROC as a Python ``float``, or ``nan`` when there is no
        positive or no negative sample.
    """
    # Coerce so that plain lists/tuples work as well as NumPy arrays;
    # boolean-mask indexing below is only defined on arrays.
    labels = np.asarray(y_true)
    scores = np.asarray(y_score)

    pos = scores[labels == 1]
    neg = scores[labels == 0]
    if pos.size == 0 or neg.size == 0:
        return float("nan")

    # Broadcast to a (n_pos, n_neg) grid of all pairs; build the views once.
    pos_col = pos[:, None]
    neg_row = neg[None, :]
    # Probability that pos > neg (with 0.5 for ties)
    return float(np.mean(pos_col > neg_row) + 0.5 * np.mean(pos_col == neg_row))
+
# Directory holding the per-checkpoint diagnostics and receiving the plot.
ROOT = "/home/yurenh2/rrm/research/flossing"
# (step, epoch) pairs: `step` selects the diagnostics file, `epoch` is only
# used for panel titles and the printed summary.
CKPTS = [
    (26041, 5000),
    (78123, 15000),
    (130205, 25000),
    (182287, 35000),
    (234369, 45000),
]


def _load_checkpoint(step):
    """Return ``(lam1, succ)`` read from the diagnostics file for *step*.

    ``lam1`` is column 0 of the stored ``lyap_spec`` array and ``succ`` is
    the boolean mask ``exact_correct > 0.5``.
    """
    # The context manager closes the .npz archive once the arrays are read.
    with np.load(f"{ROOT}/diag_trm_singleGPU_step{step}_512.npz") as d:
        lam1 = d["lyap_spec"][:, 0]
        succ = d["exact_correct"] > 0.5
    return lam1, succ


# Load every checkpoint exactly once; the histogram panels and the summary
# table below both read from this list.
results = []
for step, epoch in CKPTS:
    lam1, succ = _load_checkpoint(step)
    # AUROC: predict fail (label=1) from λ_1.  roc_auc_score itself returns
    # NaN when either class is empty, so no extra guard is needed here.
    fail_label = (~succ).astype(int)
    auroc = roc_auc_score(fail_label, lam1)
    results.append((epoch, lam1, succ, auroc))

fig, axes = plt.subplots(2, 3, figsize=(16, 8), sharex=True)
axes = axes.flatten()

# Shared bin edges so the panels are directly comparable.
bins = np.linspace(-0.2, 0.3, 60)

for ax, (epoch, lam1, succ, auroc) in zip(axes, results):
    ax.hist(lam1[succ], bins=bins, alpha=0.55, color="C2", density=False,
            label=f"succ (n={succ.sum()}, μ={lam1[succ].mean():+.3f})")
    ax.hist(lam1[~succ], bins=bins, alpha=0.55, color="C3", density=False,
            label=f"fail (n={(~succ).sum()}, μ={lam1[~succ].mean():+.3f})")
    ax.axvline(0, color="k", lw=0.5, ls=":")
    ax.set_title(f"epoch {epoch:>5} | acc={succ.mean():.3f} AUROC={auroc:.3f}")
    ax.set_xlabel(r"$\lambda_1$")
    ax.set_ylabel("count")
    ax.legend(fontsize=8, loc="upper left")
    ax.grid(alpha=0.3)

# Hide only the panels that received no checkpoint, so a populated panel is
# never blanked if the number of checkpoints changes.
for unused_ax in axes[len(results):]:
    unused_ax.axis("off")

fig.suptitle("TRM per-sample $\\lambda_1$ distribution: success (green) vs failure (red)\n"
             "across training checkpoints. As acc rises, both distributions drift right; gap stays large.",
             fontsize=11)
fig.tight_layout()
out = f"{ROOT}/plots_trm_lyap_hist.png"
fig.savefig(out, dpi=130)
print(f"→ {out}")

# Print AUROC summary
print(f"\n{'epoch':>6} {'acc':>7} {'AUROC':>7}")
for epoch, _lam1, succ, auroc in results:
    print(f"{epoch:>6} {succ.mean():>7.3f} {auroc:>7.3f}")