summaryrefslogtreecommitdiff
path: root/research/flossing/plot_trm_chaos_trajectory.py
blob: da909c1fd3414e8ce595e298b780b353ef64a13a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""Plot TRM Lyapunov trajectory across training checkpoints."""
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

ROOT = "/home/yurenh2/rrm/research/flossing"
# (training step, epoch) pairs. The step selects the diagnostics file to load;
# the epoch is used as the x-axis of every plot below.
CKPTS = [
    (26041, 5000),
    (52082, 10000),
    (78123, 15000),
    (104164, 20000),
    (130205, 25000),
    (156246, 30000),
    (182287, 35000),
    (208328, 40000),
    (234369, 45000),
    (260410, 50000),
]


def _masked_mean(values, mask):
    """Return the mean of ``values[mask]`` as a float, or NaN if nothing is selected.

    Returning NaN explicitly avoids numpy's "Mean of empty slice" RuntimeWarning
    when a checkpoint has no successes (or no failures).
    """
    selected = values[mask]
    return float(selected.mean()) if selected.size else float("nan")


# Per-checkpoint series: accuracy, mean top Lyapunov exponent of successful
# and failed examples, and the fail - succ gap.
acc_list, succ_lam, fail_lam, dlam = [], [], [], []
for step, epoch in CKPTS:
    # NpzFile keeps the archive open until closed; indexing it reads the member
    # into memory, so the arrays stay valid after the `with` block exits.
    with np.load(f"{ROOT}/diag_trm_singleGPU_step{step}_512.npz") as d:
        # NOTE(review): exact_correct presumably holds 0/1 per example — confirm.
        succ = d["exact_correct"] > 0.5
        # Column 0 of lyap_spec is plotted as λ_1 (the top exponent); assumes
        # lyap_spec is (n_examples, n_exponents) sorted descending — confirm.
        lam_1 = d["lyap_spec"][:, 0]
    succ_mean = _masked_mean(lam_1, succ)
    fail_mean = _masked_mean(lam_1, ~succ)
    acc_list.append(succ.mean())
    succ_lam.append(succ_mean)
    fail_lam.append(fail_mean)
    dlam.append(fail_mean - succ_mean)

epochs = [e for _, e in CKPTS]

fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))

# Panel 1: exact-match accuracy per checkpoint, annotated with its value.
ax = axes[0]
ax.plot(epochs, acc_list, "ko-", lw=2, ms=8)
ax.set_xlabel("epoch"); ax.set_ylabel("exact_acc on 512 test")
ax.set_title("TRM accuracy vs training")
ax.grid(alpha=0.3)
for e, a in zip(epochs, acc_list):
    ax.text(e, a + 0.01, f"{a:.3f}", ha="center", fontsize=9)

# Panel 2: mean λ_1 for successful vs failed examples, with a zero reference line.
ax = axes[1]
ax.plot(epochs, succ_lam, "C2o-", lw=2, ms=8, label="succ λ_1")
ax.plot(epochs, fail_lam, "C3o-", lw=2, ms=8, label="fail λ_1")
ax.axhline(0, color="k", lw=0.5, ls=":")
ax.set_xlabel("epoch"); ax.set_ylabel(r"$\lambda_1$ (top joint Lyap)")
ax.set_title("TRM Lyapunov drift: succ → criticality, fail → chaos")
ax.legend(); ax.grid(alpha=0.3)

# Panel 3: the fail - succ gap in λ_1, annotated with its signed value.
ax = axes[2]
ax.plot(epochs, dlam, "C0o-", lw=2, ms=8)
ax.set_xlabel("epoch"); ax.set_ylabel(r"$\Delta\lambda$ = fail - succ")
ax.set_title("Discrimination gap (non-monotonic)")
ax.grid(alpha=0.3)
for e, gap in zip(epochs, dlam):
    # A NaN gap (a checkpoint with no successes or no failures) has no
    # finite position; matplotlib would only warn, so skip the label.
    if np.isfinite(gap):
        ax.text(e, gap + 0.003, f"{gap:+.3f}", ha="center", fontsize=9)

# Derive the epoch range and checkpoint count from CKPTS so the title stays
# in sync with the data actually plotted.
fig.suptitle(
    f"TRM chaos onset trajectory (single-GPU, epoch "
    f"{min(epochs) / 1000:g}K-{max(epochs) / 1000:g}K, "
    f"{len(epochs)} checkpoints)",
    fontsize=12,
)
fig.tight_layout()
out = f"{ROOT}/plots_trm_chaos_onset.png"
fig.savefig(out, dpi=130)
print(f"→ {out}")

# Echo the plotted series as a fixed-width text table, one row per checkpoint.
_rows = zip(epochs, acc_list, succ_lam, fail_lam, dlam)
print(f"\n{'epoch':>6} {'acc':>7} {'succ_λ':>9} {'fail_λ':>9} {'Δλ':>8}")
for ep, acc, lam_s, lam_f, gap in _rows:
    line = f"{ep:>6} {acc:>7.3f} {lam_s:>+9.4f} {lam_f:>+9.4f} {gap:>+8.4f}"
    print(line)