diff options
Diffstat (limited to 'research/flossing/plot_trm_lyap_hist.py')
| -rw-r--r-- | research/flossing/plot_trm_lyap_hist.py | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/research/flossing/plot_trm_lyap_hist.py b/research/flossing/plot_trm_lyap_hist.py new file mode 100644 index 0000000..ffa14b2 --- /dev/null +++ b/research/flossing/plot_trm_lyap_hist.py @@ -0,0 +1,69 @@ +"""Per-sample λ_1 histogram: success vs failure on TRM checkpoints.""" +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +def roc_auc_score(y_true, y_score): + """Simple AUROC: pair-counting.""" + pos = y_score[y_true == 1] + neg = y_score[y_true == 0] + if len(pos) == 0 or len(neg) == 0: + return float("nan") + # Probability that pos > neg (with 0.5 for ties) + return float(np.mean(pos[:, None] > neg[None, :]) + + 0.5 * np.mean(pos[:, None] == neg[None, :])) + +ROOT = "/home/yurenh2/rrm/research/flossing" +CKPTS = [ + (26041, 5000), + (78123, 15000), + (130205, 25000), + (182287, 35000), + (234369, 45000), +] + +fig, axes = plt.subplots(2, 3, figsize=(16, 8), sharex=True) +axes = axes.flatten() + +bins = np.linspace(-0.2, 0.3, 60) + +for ax, (step, epoch) in zip(axes, CKPTS): + d = np.load(f"{ROOT}/diag_trm_singleGPU_step{step}_512.npz") + lam1 = d["lyap_spec"][:, 0] + succ = d["exact_correct"] > 0.5 + + # AUROC: predict fail (label=1) from λ_1 + fail_label = (~succ).astype(int) + auroc = roc_auc_score(fail_label, lam1) if fail_label.sum() > 0 and (1 - fail_label).sum() > 0 else float("nan") + + ax.hist(lam1[succ], bins=bins, alpha=0.55, color="C2", density=False, + label=f"succ (n={succ.sum()}, μ={lam1[succ].mean():+.3f})") + ax.hist(lam1[~succ], bins=bins, alpha=0.55, color="C3", density=False, + label=f"fail (n={(~succ).sum()}, μ={lam1[~succ].mean():+.3f})") + ax.axvline(0, color="k", lw=0.5, ls=":") + ax.set_title(f"epoch {epoch:>5} | acc={succ.mean():.3f} AUROC={auroc:.3f}") + ax.set_xlabel(r"$\lambda_1$") + ax.set_ylabel("count") + ax.legend(fontsize=8, loc="upper left") + ax.grid(alpha=0.3) + +# Hide last subplot +axes[-1].axis("off") + +fig.suptitle("TRM per-sample $\\lambda_1$ distribution: success (green) vs failure (red)\n" + "across training checkpoints. As acc rises, both distributions drift right; gap stays large.", + fontsize=11) +fig.tight_layout() +out = f"{ROOT}/plots_trm_lyap_hist.png" +fig.savefig(out, dpi=130) +print(f"→ {out}") + +# Print AUROC summary +print(f"\n{'epoch':>6} {'acc':>7} {'AUROC':>7}") +for step, epoch in CKPTS: + d = np.load(f"{ROOT}/diag_trm_singleGPU_step{step}_512.npz") + lam1 = d["lyap_spec"][:, 0] + succ = d["exact_correct"] > 0.5 + fail_label = (~succ).astype(int) + auroc = roc_auc_score(fail_label, lam1) if fail_label.sum() > 0 else float("nan") + print(f"{epoch:>6} {succ.mean():>7.3f} {auroc:>7.3f}") |
