summaryrefslogtreecommitdiff
path: root/research/flossing/plot_trm_lyap_hist.py
blob: ffa14b26a67597fcb70c51684b519946931b4946 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Per-sample λ_1 histogram: success vs failure on TRM checkpoints."""
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
def roc_auc_score(y_true, y_score):
    """Area under the ROC curve via exhaustive pair counting.

    Computes P(score of a random positive > score of a random negative),
    counting tied pairs as 1/2 (the Mann-Whitney U statistic divided by
    n_pos * n_neg).

    Parameters
    ----------
    y_true : array-like
        Binary labels; entries equal to 1 are positives, entries equal to 0
        are negatives (any other value is ignored).
    y_score : array-like
        Scores aligned with *y_true*; higher means "more positive".

    Returns
    -------
    float
        AUROC in [0, 1], or NaN when either class has no samples.

    Notes
    -----
    Builds n_pos x n_neg comparison matrices, so memory is O(n_pos * n_neg).
    That is fine for a few hundred samples but not for very large inputs.
    """
    # Coerce so plain Python lists work too: list == 1 is a scalar bool,
    # which would break the boolean-mask indexing below.
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    if pos.size == 0 or neg.size == 0:
        return float("nan")
    # Broadcast every positive score against every negative score.
    wins = pos[:, None] > neg[None, :]
    ties = pos[:, None] == neg[None, :]
    return float(wins.mean() + 0.5 * ties.mean())

# Directory holding the diagnostic .npz inputs and the output figure.
ROOT = "/home/yurenh2/rrm/research/flossing"

# Checkpoints to compare, as (step, epoch) pairs. In this script `step` picks
# the diag_trm_singleGPU_step{step}_512.npz file; `epoch` is only used in
# titles and the printed summary.
CKPTS = [
    (26041, 5000), (78123, 15000), (130205, 25000),
    (182287, 35000), (234369, 45000),
]

def _load_diag(step):
    """Load one checkpoint's diagnostics as ``(lam1, succ)`` arrays.

    ``lam1`` is column 0 of the file's ``lyap_spec`` array (one value per
    sample); ``succ`` is the boolean mask ``exact_correct > 0.5``.
    """
    path = f"{ROOT}/diag_trm_singleGPU_step{step}_512.npz"
    # np.load on an .npz returns an NpzFile that keeps the underlying file
    # open until closed; the context manager releases it once the arrays we
    # need have been read into memory.
    with np.load(path) as d:
        lam1 = d["lyap_spec"][:, 0]
        succ = d["exact_correct"] > 0.5
    return lam1, succ


def _fmt_mean(x):
    """Return the signed mean of *x* to 3 decimals, or "n/a" if *x* is empty.

    Avoids numpy's "Mean of empty slice" RuntimeWarning when a checkpoint has
    no successes (or no failures).
    """
    return f"{x.mean():+.3f}" if x.size else "n/a"


# One panel per checkpoint on a 3-column grid sized from CKPTS; any spare
# panels in the last row are hidden after plotting.
ncols = 3
nrows = max(1, -(-len(CKPTS) // ncols))
fig, axes = plt.subplots(nrows, ncols, figsize=(16, 4 * nrows), sharex=True,
                         squeeze=False)
axes = axes.flatten()

# Shared bin edges for every panel. Values outside [-0.2, 0.3] are not drawn,
# but they are still included in the n= and μ= legend statistics.
bins = np.linspace(-0.2, 0.3, 60)

summary = []  # (epoch, accuracy, AUROC) per checkpoint, printed at the end
for ax, (step, epoch) in zip(axes, CKPTS):
    lam1, succ = _load_diag(step)
    fail = ~succ
    acc = succ.mean()

    # AUROC for predicting failure (label 1) from λ_1. roc_auc_score returns
    # NaN by itself when either class is empty, so no extra guard is needed.
    auroc = roc_auc_score(fail.astype(int), lam1)
    summary.append((epoch, acc, auroc))

    ax.hist(lam1[succ], bins=bins, alpha=0.55, color="C2", density=False,
            label=f"succ (n={succ.sum()}, μ={_fmt_mean(lam1[succ])})")
    ax.hist(lam1[fail], bins=bins, alpha=0.55, color="C3", density=False,
            label=f"fail (n={fail.sum()}, μ={_fmt_mean(lam1[fail])})")
    ax.axvline(0, color="k", lw=0.5, ls=":")
    ax.set_title(f"epoch {epoch:>5} | acc={acc:.3f}  AUROC={auroc:.3f}")
    ax.set_xlabel(r"$\lambda_1$")
    ax.set_ylabel("count")
    ax.legend(fontsize=8, loc="upper left")
    ax.grid(alpha=0.3)

# Hide every grid panel that has no checkpoint assigned to it.
for ax in axes[len(CKPTS):]:
    ax.axis("off")

# NOTE(review): the second title line states an observed conclusion; re-check
# it whenever CKPTS or the underlying diagnostics change.
fig.suptitle("TRM per-sample $\\lambda_1$ distribution: success (green) vs failure (red)\n"
             "across training checkpoints. As acc rises, both distributions drift right; gap stays large.",
             fontsize=11)
fig.tight_layout()
out = f"{ROOT}/plots_trm_lyap_hist.png"
fig.savefig(out, dpi=130)
print(f"→ {out}")

# Print AUROC summary, reusing the values computed for the plot instead of
# re-reading every .npz file.
print(f"\n{'epoch':>6} {'acc':>7} {'AUROC':>7}")
for epoch, acc, auroc in summary:
    print(f"{epoch:>6} {acc:>7.3f} {auroc:>7.3f}")