"""Per-sample λ_1 histogram: success vs failure on TRM checkpoints.""" import numpy as np import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt def roc_auc_score(y_true, y_score): """Simple AUROC: pair-counting.""" pos = y_score[y_true == 1] neg = y_score[y_true == 0] if len(pos) == 0 or len(neg) == 0: return float("nan") # Probability that pos > neg (with 0.5 for ties) return float(np.mean(pos[:, None] > neg[None, :]) + 0.5 * np.mean(pos[:, None] == neg[None, :])) ROOT = "/home/yurenh2/rrm/research/flossing" CKPTS = [ (26041, 5000), (78123, 15000), (130205, 25000), (182287, 35000), (234369, 45000), ] fig, axes = plt.subplots(2, 3, figsize=(16, 8), sharex=True) axes = axes.flatten() bins = np.linspace(-0.2, 0.3, 60) for ax, (step, epoch) in zip(axes, CKPTS): d = np.load(f"{ROOT}/diag_trm_singleGPU_step{step}_512.npz") lam1 = d["lyap_spec"][:, 0] succ = d["exact_correct"] > 0.5 # AUROC: predict fail (label=1) from λ_1 fail_label = (~succ).astype(int) auroc = roc_auc_score(fail_label, lam1) if fail_label.sum() > 0 and (1 - fail_label).sum() > 0 else float("nan") ax.hist(lam1[succ], bins=bins, alpha=0.55, color="C2", density=False, label=f"succ (n={succ.sum()}, μ={lam1[succ].mean():+.3f})") ax.hist(lam1[~succ], bins=bins, alpha=0.55, color="C3", density=False, label=f"fail (n={(~succ).sum()}, μ={lam1[~succ].mean():+.3f})") ax.axvline(0, color="k", lw=0.5, ls=":") ax.set_title(f"epoch {epoch:>5} | acc={succ.mean():.3f} AUROC={auroc:.3f}") ax.set_xlabel(r"$\lambda_1$") ax.set_ylabel("count") ax.legend(fontsize=8, loc="upper left") ax.grid(alpha=0.3) # Hide last subplot axes[-1].axis("off") fig.suptitle("TRM per-sample $\\lambda_1$ distribution: success (green) vs failure (red)\n" "across training checkpoints. As acc rises, both distributions drift right; gap stays large.", fontsize=11) fig.tight_layout() out = f"{ROOT}/plots_trm_lyap_hist.png" fig.savefig(out, dpi=130) print(f"→ {out}") # Print AUROC summary print(f"\n{'epoch':>6} {'acc':>7} {'AUROC':>7}") for step, epoch in CKPTS: d = np.load(f"{ROOT}/diag_trm_singleGPU_step{step}_512.npz") lam1 = d["lyap_spec"][:, 0] succ = d["exact_correct"] > 0.5 fail_label = (~succ).astype(int) auroc = roc_auc_score(fail_label, lam1) if fail_label.sum() > 0 else float("nan") print(f"{epoch:>6} {succ.mean():>7.3f} {auroc:>7.3f}")