summaryrefslogtreecommitdiff
path: root/research/flossing/plot_ckpt_evolution.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/plot_ckpt_evolution.py')
-rw-r--r--research/flossing/plot_ckpt_evolution.py105
1 files changed, 105 insertions, 0 deletions
diff --git a/research/flossing/plot_ckpt_evolution.py b/research/flossing/plot_ckpt_evolution.py
new file mode 100644
index 0000000..1bde7b4
--- /dev/null
+++ b/research/flossing/plot_ckpt_evolution.py
@@ -0,0 +1,105 @@
+"""Plot λ-spectrum and accuracy evolution across training checkpoints."""
+import numpy as np
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+
+CKPT_STEPS = [2604, 7812, 13020, 18228, 20832, 26040]
+CKPT_EPOCHS = [int(s * 20000 / 26041) for s in CKPT_STEPS] # actual epoch (rounded)
+ROOT = "/home/yurenh2/rrm/research/flossing/ckpt_evolution"
+OUT_DIR = "/home/yurenh2/rrm/research/flossing/plots_evolution"
+import os; os.makedirs(OUT_DIR, exist_ok=True)
+
+datas = {}
+for s in CKPT_STEPS:
+ d = np.load(f"{ROOT}/step_{s}.npz")
+ datas[s] = d
+
+# --- plot 1: λ_1 vs training step, split by success/failure ---
+fig, axes = plt.subplots(1, 2, figsize=(13, 5))
+
+ax = axes[0]
+lam1_s_mean, lam1_f_mean = [], []
+lam1_s_std, lam1_f_std = [], []
+mean_s, mean_f = [], []
+mean_s_std, mean_f_std = [], []
+accs = []
+for s in CKPT_STEPS:
+ d = datas[s]
+ succ = d["exact_correct"] > 0.5
+ L = d["lyap_spec"]
+ lam1_s_mean.append(L[succ,0].mean()); lam1_s_std.append(L[succ,0].std())
+ lam1_f_mean.append(L[~succ,0].mean()); lam1_f_std.append(L[~succ,0].std())
+ mean_s.append(L[succ].mean(axis=1).mean()); mean_s_std.append(L[succ].mean(axis=1).std())
+ mean_f.append(L[~succ].mean(axis=1).mean()); mean_f_std.append(L[~succ].mean(axis=1).std())
+ accs.append(succ.mean())
+
+x = CKPT_EPOCHS
+ax.errorbar(x, lam1_s_mean, yerr=lam1_s_std, fmt="C0-o", capsize=4, label=r"$\lambda_1$ success")
+ax.errorbar(x, lam1_f_mean, yerr=lam1_f_std, fmt="C3-o", capsize=4, label=r"$\lambda_1$ failure")
+ax.errorbar(x, mean_s, yerr=mean_s_std, fmt="C0--s", alpha=0.6, capsize=3, label=r"mean$_{1:8}$ success")
+ax.errorbar(x, mean_f, yerr=mean_f_std, fmt="C3--s", alpha=0.6, capsize=3, label=r"mean$_{1:8}$ failure")
+ax.axhline(0, color="k", ls=":", lw=0.6)
+ax.set_xlabel("training epoch")
+ax.set_ylabel(r"Lyapunov exponent (per inner cycle)")
+ax.set_title("λ evolution during training\n(error bars = within-group std)")
+ax.legend(loc="best", fontsize=9)
+ax.grid(alpha=0.3)
+
+ax2 = axes[1]
+gap = np.array(lam1_f_mean) - np.array(lam1_s_mean)
+ax2.plot(x, gap, "C2-o", label=r"$\lambda_1^{fail} - \lambda_1^{succ}$", lw=2)
+ax2.axhline(0, color="k", ls=":", lw=0.6)
+ax2b = ax2.twinx()
+ax2b.plot(x, accs, "C4-^", label="exact accuracy")
+ax2b.set_ylabel("test exact accuracy", color="C4")
+ax2b.tick_params(axis="y", labelcolor="C4")
+ax2.set_xlabel("training epoch")
+ax2.set_ylabel("λ gap (fail − succ)", color="C2")
+ax2.tick_params(axis="y", labelcolor="C2")
+ax2.set_title("succ/fail λ gap vs accuracy")
+ax2.legend(loc="upper left")
+ax2b.legend(loc="lower right")
+ax2.grid(alpha=0.3)
+
+fig.tight_layout()
+fig.savefig(f"{OUT_DIR}/lyap_vs_training.png", dpi=130)
+plt.close()
+
+# --- plot 2: per-checkpoint λ spectrum (all 8) shape ---
+fig, axes = plt.subplots(2, 3, figsize=(15, 7), sharey=True)
+for ax, s in zip(axes.flat, CKPT_STEPS):
+ d = datas[s]
+ succ = d["exact_correct"] > 0.5
+ L = d["lyap_spec"]
+ idx = np.arange(1, L.shape[1]+1)
+ mean_s_arr = L[succ].mean(0); std_s_arr = L[succ].std(0)
+ mean_f_arr = L[~succ].mean(0); std_f_arr = L[~succ].std(0)
+ ax.fill_between(idx, mean_s_arr-std_s_arr, mean_s_arr+std_s_arr, color="C0", alpha=0.2)
+ ax.plot(idx, mean_s_arr, "C0-o", label=f"succ n={succ.sum()}")
+ ax.fill_between(idx, mean_f_arr-std_f_arr, mean_f_arr+std_f_arr, color="C3", alpha=0.2)
+ ax.plot(idx, mean_f_arr, "C3-o", label=f"fail n={(~succ).sum()}")
+ ax.axhline(0, color="k", ls=":", lw=0.6)
+ ax.set_title(f"step_{s} (epoch≈{int(s*20000/26041)}, acc={succ.mean():.2%})")
+ ax.set_xlabel("λ index")
+ ax.legend(fontsize=8, loc="best")
+ ax.grid(alpha=0.3)
+axes[0,0].set_ylabel(r"$\lambda_i$")
+axes[1,0].set_ylabel(r"$\lambda_i$")
+fig.tight_layout()
+fig.savefig(f"{OUT_DIR}/spectrum_per_ckpt.png", dpi=130)
+plt.close()
+
+# --- summary table ---
+print(f"{'epoch':>6} {'step':>6} {'acc':>7} {'λ_1(s)':>9} {'λ_1(f)':>9} {'gap':>7} "
+ f"{'mean(s)':>9} {'mean(f)':>9} {'mean_gap':>9}")
+for s, ep in zip(CKPT_STEPS, CKPT_EPOCHS):
+ d = datas[s]
+ succ = d["exact_correct"] > 0.5
+ L = d["lyap_spec"]
+ l1s, l1f = L[succ,0].mean(), L[~succ,0].mean()
+ ms, mf = L[succ].mean(), L[~succ].mean()
+ print(f"{ep:>6} {s:>6} {succ.mean():>7.2%} {l1s:+9.3f} {l1f:+9.3f} {l1f-l1s:+7.3f} "
+ f"{ms:+9.3f} {mf:+9.3f} {mf-ms:+9.3f}")
+
+print(f"\nplots → {OUT_DIR}/")