summaryrefslogtreecommitdiff
path: root/research/flossing/plot_ckpt_evolution.py
blob: 1bde7b4836b2e9cecf010ad39e88ae2c4bc07fa6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Plot λ-spectrum and accuracy evolution across training checkpoints."""
import os

import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

CKPT_STEPS = [2604, 7812, 13020, 18228, 20832, 26040]
# Step -> epoch conversion: epoch = step * TOTAL_EPOCHS / TOTAL_STEPS.
# NOTE(review): with TOTAL_STEPS = 26040 every checkpoint above lands exactly on
# a multiple of 2000 epochs (2604 / 1.302 = 2000, ...); with 26041 the later ones
# round to 13999 / 15999 / 19999. Confirm which total step count is intended.
TOTAL_EPOCHS = 20000
TOTAL_STEPS = 26041
# round(), not int(): int() truncates, e.g. 1999.92 -> 1999 instead of 2000.
CKPT_EPOCHS = [round(s * TOTAL_EPOCHS / TOTAL_STEPS) for s in CKPT_STEPS]
ROOT = "/home/yurenh2/rrm/research/flossing/ckpt_evolution"
OUT_DIR = "/home/yurenh2/rrm/research/flossing/plots_evolution"
os.makedirs(OUT_DIR, exist_ok=True)

# The only arrays the rest of this script reads from each checkpoint file.
NPZ_KEYS = ("exact_correct", "lyap_spec")

datas = {}
for s in CKPT_STEPS:
    # np.load on an .npz returns a lazy NpzFile: it keeps the archive open and
    # re-reads the member from disk on every d[key] access. Copy the needed
    # arrays into a plain dict once and close the file immediately.
    with np.load(f"{ROOT}/step_{s}.npz") as npz:
        datas[s] = {key: npz[key] for key in NPZ_KEYS}

# --- plot 1: λ_1 vs training step, split by success/failure ---


def _mean_std(values):
    """Return ``(mean, std)`` of a 1-D array, or ``(nan, nan)`` if it is empty.

    A checkpoint may have no successes (or no failures). numpy also yields NaN
    for an empty group but emits "Mean of empty slice" RuntimeWarnings; this
    makes the empty case explicit and quiet.
    """
    if values.size == 0:
        return np.nan, np.nan
    return values.mean(), values.std()


fig, axes = plt.subplots(1, 2, figsize=(13, 5))

lam1_s_mean, lam1_f_mean = [], []
lam1_s_std, lam1_f_std = [], []
mean_s, mean_f = [], []
mean_s_std, mean_f_std = [], []
accs = []
n_exp = None  # spectrum width; assumed identical at every checkpoint
for s in CKPT_STEPS:
    d = datas[s]
    succ = d["exact_correct"] > 0.5  # per-sample success mask
    L = d["lyap_spec"]               # indexed [sample, exponent]
    n_exp = L.shape[1]
    spec_mean = L.mean(axis=1)       # per-sample mean over all exponents

    m, sd = _mean_std(L[succ, 0])
    lam1_s_mean.append(m)
    lam1_s_std.append(sd)
    m, sd = _mean_std(L[~succ, 0])
    lam1_f_mean.append(m)
    lam1_f_std.append(sd)
    m, sd = _mean_std(spec_mean[succ])
    mean_s.append(m)
    mean_s_std.append(sd)
    m, sd = _mean_std(spec_mean[~succ])
    mean_f.append(m)
    mean_f_std.append(sd)
    accs.append(succ.mean())

x = CKPT_EPOCHS
# Label reflects the actual number of exponents, e.g. "mean$_{1:8}$".
mean_label = rf"mean$_{{1:{n_exp}}}$"

ax = axes[0]
ax.errorbar(x, lam1_s_mean, yerr=lam1_s_std, fmt="C0-o", capsize=4, label=r"$\lambda_1$ success")
ax.errorbar(x, lam1_f_mean, yerr=lam1_f_std, fmt="C3-o", capsize=4, label=r"$\lambda_1$ failure")
ax.errorbar(x, mean_s, yerr=mean_s_std, fmt="C0--s", alpha=0.6, capsize=3, label=f"{mean_label} success")
ax.errorbar(x, mean_f, yerr=mean_f_std, fmt="C3--s", alpha=0.6, capsize=3, label=f"{mean_label} failure")
ax.axhline(0, color="k", ls=":", lw=0.6)
ax.set_xlabel("training epoch")
ax.set_ylabel(r"Lyapunov exponent (per inner cycle)")
ax.set_title("λ evolution during training\n(error bars = within-group std)")
ax.legend(loc="best", fontsize=9)
ax.grid(alpha=0.3)

ax2 = axes[1]
gap = np.array(lam1_f_mean) - np.array(lam1_s_mean)
ax2.plot(x, gap, "C2-o", label=r"$\lambda_1^{fail} - \lambda_1^{succ}$", lw=2)
ax2.axhline(0, color="k", ls=":", lw=0.6)
ax2b = ax2.twinx()
ax2b.plot(x, accs, "C4-^", label="exact accuracy")
ax2b.set_ylabel("test exact accuracy", color="C4")
ax2b.tick_params(axis="y", labelcolor="C4")
ax2.set_xlabel("training epoch")
ax2.set_ylabel("λ gap (fail − succ)", color="C2")
ax2.tick_params(axis="y", labelcolor="C2")
ax2.set_title("succ/fail λ gap vs accuracy")
# One combined legend on the twin axes, which is overlaid on ax2, so neither
# curve can be drawn over the legend box.
handles, labels = ax2.get_legend_handles_labels()
handles_b, labels_b = ax2b.get_legend_handles_labels()
ax2b.legend(handles + handles_b, labels + labels_b, loc="best")
ax2.grid(alpha=0.3)

fig.tight_layout()
fig.savefig(f"{OUT_DIR}/lyap_vs_training.png", dpi=130)
plt.close(fig)

# --- plot 2: per-checkpoint λ spectrum (all exponents) shape ---
# The grid is sized from the number of checkpoints (3 columns; 6 checkpoints
# give the 2x3 layout). Unused trailing panels are hidden.
n_cols = 3
n_rows = -(-len(CKPT_STEPS) // n_cols)  # ceil division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 3.5 * n_rows),
                         sharey=True, squeeze=False)
# Reuse CKPT_EPOCHS so panel titles agree with plot 1's epoch axis.
for ax, s, ep in zip(axes.flat, CKPT_STEPS, CKPT_EPOCHS):
    d = datas[s]
    succ = d["exact_correct"] > 0.5     # per-sample success mask
    L = d["lyap_spec"]                  # indexed [sample, exponent]
    idx = np.arange(1, L.shape[1] + 1)  # 1-based exponent index for the x-axis
    mean_s_arr, std_s_arr = L[succ].mean(0), L[succ].std(0)
    mean_f_arr, std_f_arr = L[~succ].mean(0), L[~succ].std(0)
    ax.fill_between(idx, mean_s_arr - std_s_arr, mean_s_arr + std_s_arr, color="C0", alpha=0.2)
    ax.plot(idx, mean_s_arr, "C0-o", label=f"succ n={succ.sum()}")
    ax.fill_between(idx, mean_f_arr - std_f_arr, mean_f_arr + std_f_arr, color="C3", alpha=0.2)
    ax.plot(idx, mean_f_arr, "C3-o", label=f"fail n={(~succ).sum()}")
    ax.axhline(0, color="k", ls=":", lw=0.6)
    ax.set_title(f"step_{s} (epoch≈{ep}, acc={succ.mean():.2%})")
    ax.set_xlabel("λ index")
    ax.legend(fontsize=8, loc="best")
    ax.grid(alpha=0.3)
for ax in axes.ravel()[len(CKPT_STEPS):]:
    ax.set_visible(False)
for ax in axes[:, 0]:
    ax.set_ylabel(r"$\lambda_i$")
fig.tight_layout()
fig.savefig(f"{OUT_DIR}/spectrum_per_ckpt.png", dpi=130)
plt.close(fig)

# --- summary table ---
# Per checkpoint: accuracy, mean λ_1 for successes/failures and their gap, then
# the mean over the whole spectrum for each group and that gap.
_TABLE_COLUMNS = (
    ("epoch", 6), ("step", 6), ("acc", 7),
    ("λ_1(s)", 9), ("λ_1(f)", 9), ("gap", 7),
    ("mean(s)", 9), ("mean(f)", 9), ("mean_gap", 9),
)
print(" ".join(f"{title:>{width}}" for title, width in _TABLE_COLUMNS))

for step, epoch in zip(CKPT_STEPS, CKPT_EPOCHS):
    ckpt = datas[step]
    solved = ckpt["exact_correct"] > 0.5
    spec = ckpt["lyap_spec"]
    won, lost = spec[solved], spec[~solved]
    lam1_won, lam1_lost = won[:, 0].mean(), lost[:, 0].mean()
    avg_won, avg_lost = won.mean(), lost.mean()
    print(
        f"{epoch:>6} {step:>6} {solved.mean():>7.2%} "
        f"{lam1_won:+9.3f} {lam1_lost:+9.3f} {lam1_lost - lam1_won:+7.3f} "
        f"{avg_won:+9.3f} {avg_lost:+9.3f} {avg_lost - avg_won:+9.3f}"
    )

print(f"\nplots → {OUT_DIR}/")