"""Plot λ-spectrum and accuracy evolution across training checkpoints.

For every step in CKPT_STEPS this reads ``{ROOT}/step_<step>.npz``, which must
contain ``exact_correct`` (per-sample score; > SUCCESS_THRESHOLD counts as a
success) and ``lyap_spec`` (2-D array with one row per entry of
``exact_correct`` and one column per exponent, column 0 being λ_1).
It writes two figures to OUT_DIR and prints a summary table to stdout.
"""
import os

import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: figures are only saved to disk
import matplotlib.pyplot as plt

CKPT_STEPS = [2604, 7812, 13020, 18228, 20832, 26040]
# Step -> epoch conversion factor used everywhere below (20000 epochs / 26041 steps).
# NOTE(review): assumes these totals match the training run — confirm.
EPOCHS_PER_STEP = 20000 / 26041
CKPT_EPOCHS = [round(s * EPOCHS_PER_STEP) for s in CKPT_STEPS]  # actual epoch (rounded)
ROOT = "/home/yurenh2/rrm/research/flossing/ckpt_evolution"
OUT_DIR = "/home/yurenh2/rrm/research/flossing/plots_evolution"
SUCCESS_THRESHOLD = 0.5  # exact_correct above this value counts as a success


def _load_checkpoint(step):
    """Return every array stored in ``{ROOT}/step_<step>.npz`` as a plain dict.

    The arrays are read inside the ``with`` block so the underlying file
    handle is closed right away instead of staying open for the whole run.
    """
    with np.load(f"{ROOT}/step_{step}.npz") as npz:
        return {key: npz[key] for key in npz.files}


def _group_stats(spec, mask):
    """Summarise the rows of the 2-D array ``spec`` selected by boolean ``mask``.

    Returns a dict with:
      lam1_mean / lam1_std - mean / std of column 0 (λ_1)
      avg_mean / avg_std   - mean / std of each row's average exponent
      spec_mean / spec_std - per-column mean / std (length ``spec.shape[1]``)
    When ``mask`` selects no rows every value is NaN, produced explicitly
    rather than via numpy's "Mean of empty slice" RuntimeWarning path.
    """
    rows = spec[mask]
    if rows.shape[0] == 0:
        nan_col = np.full(spec.shape[1], np.nan)
        return {
            "lam1_mean": np.nan, "lam1_std": np.nan,
            "avg_mean": np.nan, "avg_std": np.nan,
            "spec_mean": nan_col, "spec_std": nan_col.copy(),
        }
    row_avg = rows.mean(axis=1)
    return {
        "lam1_mean": rows[:, 0].mean(), "lam1_std": rows[:, 0].std(),
        "avg_mean": row_avg.mean(), "avg_std": row_avg.std(),
        "spec_mean": rows.mean(axis=0), "spec_std": rows.std(axis=0),
    }


def _summarise(data):
    """Split one checkpoint's samples into success/failure and summarise both groups."""
    succ = data["exact_correct"] > SUCCESS_THRESHOLD
    spec = data["lyap_spec"]
    return {
        "acc": succ.mean(),
        "n_succ": succ.sum(),
        "n_fail": (~succ).sum(),
        "n_exp": spec.shape[1],
        "succ": _group_stats(spec, succ),
        "fail": _group_stats(spec, ~succ),
    }


def _plot_training_curves(summaries, epochs, out_path):
    """Plot 1 — left: λ_1 and mean exponent per group vs epoch; right: λ_1 gap vs accuracy."""
    def series(group, key):
        # One value per checkpoint, in checkpoint order.
        return np.array([sm[group][key] for sm in summaries])

    n_exp = summaries[0]["n_exp"]
    fig, (ax, ax2) = plt.subplots(1, 2, figsize=(13, 5))

    ax.errorbar(epochs, series("succ", "lam1_mean"), yerr=series("succ", "lam1_std"),
                fmt="C0-o", capsize=4, label=r"$\lambda_1$ success")
    ax.errorbar(epochs, series("fail", "lam1_mean"), yerr=series("fail", "lam1_std"),
                fmt="C3-o", capsize=4, label=r"$\lambda_1$ failure")
    ax.errorbar(epochs, series("succ", "avg_mean"), yerr=series("succ", "avg_std"),
                fmt="C0--s", alpha=0.6, capsize=3, label=rf"mean$_{{1:{n_exp}}}$ success")
    ax.errorbar(epochs, series("fail", "avg_mean"), yerr=series("fail", "avg_std"),
                fmt="C3--s", alpha=0.6, capsize=3, label=rf"mean$_{{1:{n_exp}}}$ failure")
    ax.axhline(0, color="k", ls=":", lw=0.6)
    ax.set_xlabel("training epoch")
    ax.set_ylabel(r"Lyapunov exponent (per inner cycle)")
    ax.set_title("λ evolution during training\n(error bars = within-group std)")
    ax.legend(loc="best", fontsize=9)
    ax.grid(alpha=0.3)

    gap = series("fail", "lam1_mean") - series("succ", "lam1_mean")
    accs = [sm["acc"] for sm in summaries]
    ax2.plot(epochs, gap, "C2-o", label=r"$\lambda_1^{fail} - \lambda_1^{succ}$", lw=2)
    ax2.axhline(0, color="k", ls=":", lw=0.6)
    ax2b = ax2.twinx()  # accuracy shares the x axis but gets its own y scale
    ax2b.plot(epochs, accs, "C4-^", label="exact accuracy")
    ax2b.set_ylabel("test exact accuracy", color="C4")
    ax2b.tick_params(axis="y", labelcolor="C4")
    ax2.set_xlabel("training epoch")
    ax2.set_ylabel("λ gap (fail − succ)", color="C2")
    ax2.tick_params(axis="y", labelcolor="C2")
    ax2.set_title("succ/fail λ gap vs accuracy")
    ax2.legend(loc="upper left")
    ax2b.legend(loc="lower right")
    ax2.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(out_path, dpi=130)
    plt.close(fig)  # close this figure explicitly, not whichever is "current"


def _plot_spectra(summaries, steps, epochs, out_path, ncols=3):
    """Plot 2 — one panel per checkpoint: mean ± std of every λ_i, success vs failure.

    The grid grows with the number of checkpoints (``ncols`` per row) so no
    checkpoint is silently dropped; unused panels are hidden.
    """
    nrows = -(-len(summaries) // ncols)  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 3.5 * nrows),
                             sharey=True, squeeze=False)
    for ax, sm, step, ep in zip(axes.flat, summaries, steps, epochs):
        idx = np.arange(1, sm["n_exp"] + 1)
        for group, color, count in (("succ", "C0", sm["n_succ"]),
                                    ("fail", "C3", sm["n_fail"])):
            mean, std = sm[group]["spec_mean"], sm[group]["spec_std"]
            ax.fill_between(idx, mean - std, mean + std, color=color, alpha=0.2)
            ax.plot(idx, mean, f"{color}-o", label=f"{group} n={count}")
        ax.axhline(0, color="k", ls=":", lw=0.6)
        ax.set_title(f"step_{step} (epoch≈{ep}, acc={sm['acc']:.2%})")
        ax.set_xlabel("λ index")
        ax.legend(fontsize=8, loc="best")
        ax.grid(alpha=0.3)
    for ax in axes.flat[len(summaries):]:
        ax.set_visible(False)
    for row in axes:
        row[0].set_ylabel(r"$\lambda_i$")
    fig.tight_layout()
    fig.savefig(out_path, dpi=130)
    plt.close(fig)


def _print_summary(summaries, steps, epochs):
    """Print one row per checkpoint: accuracy, λ_1 and mean exponent per group, and gaps."""
    print(f"{'epoch':>6} {'step':>6} {'acc':>7} {'λ_1(s)':>9} {'λ_1(f)':>9} {'gap':>7} "
          f"{'mean(s)':>9} {'mean(f)':>9} {'mean_gap':>9}")
    for sm, step, ep in zip(summaries, steps, epochs):
        l1s, l1f = sm["succ"]["lam1_mean"], sm["fail"]["lam1_mean"]
        # Mean of per-row averages equals the grand mean for a rectangular array.
        ms, mf = sm["succ"]["avg_mean"], sm["fail"]["avg_mean"]
        print(f"{ep:>6} {step:>6} {sm['acc']:>7.2%} {l1s:+9.3f} {l1f:+9.3f} {l1f-l1s:+7.3f} "
              f"{ms:+9.3f} {mf:+9.3f} {mf-ms:+9.3f}")


os.makedirs(OUT_DIR, exist_ok=True)
datas = {s: _load_checkpoint(s) for s in CKPT_STEPS}
summaries = [_summarise(datas[s]) for s in CKPT_STEPS]

# --- plot 1: λ_1 vs training step, split by success/failure ---
_plot_training_curves(summaries, CKPT_EPOCHS, f"{OUT_DIR}/lyap_vs_training.png")
# --- plot 2: per-checkpoint λ spectrum shape ---
_plot_spectra(summaries, CKPT_STEPS, CKPT_EPOCHS, f"{OUT_DIR}/spectrum_per_ckpt.png")
# --- summary table ---
_print_summary(summaries, CKPT_STEPS, CKPT_EPOCHS)
print(f"\nplots → {OUT_DIR}/")