"""Plot TRM Lyapunov trajectory across training checkpoints.""" import numpy as np import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt ROOT = "/home/yurenh2/rrm/research/flossing" CKPTS = [ (26041, 5000), (52082, 10000), (78123, 15000), (104164, 20000), (130205, 25000), (156246, 30000), (182287, 35000), (208328, 40000), (234369, 45000), (260410, 50000), ] acc_list, succ_lam, fail_lam, dlam = [], [], [], [] for step, epoch in CKPTS: d = np.load(f"{ROOT}/diag_trm_singleGPU_step{step}_512.npz") succ = d["exact_correct"] > 0.5 lam_1 = d["lyap_spec"][:, 0] acc_list.append(succ.mean()) succ_lam.append(lam_1[succ].mean()) fail_lam.append(lam_1[~succ].mean()) dlam.append(lam_1[~succ].mean() - lam_1[succ].mean()) epochs = [e for _, e in CKPTS] fig, axes = plt.subplots(1, 3, figsize=(15, 4.5)) ax = axes[0] ax.plot(epochs, acc_list, "ko-", lw=2, ms=8) ax.set_xlabel("epoch"); ax.set_ylabel("exact_acc on 512 test") ax.set_title("TRM accuracy vs training") ax.grid(alpha=0.3) for e, a in zip(epochs, acc_list): ax.text(e, a + 0.01, f"{a:.3f}", ha="center", fontsize=9) ax = axes[1] ax.plot(epochs, succ_lam, "C2o-", lw=2, ms=8, label="succ λ_1") ax.plot(epochs, fail_lam, "C3o-", lw=2, ms=8, label="fail λ_1") ax.axhline(0, color="k", lw=0.5, ls=":") ax.set_xlabel("epoch"); ax.set_ylabel(r"$\lambda_1$ (top joint Lyap)") ax.set_title("TRM Lyapunov drift: succ → criticality, fail → chaos") ax.legend(); ax.grid(alpha=0.3) ax = axes[2] ax.plot(epochs, dlam, "C0o-", lw=2, ms=8) ax.set_xlabel("epoch"); ax.set_ylabel(r"$\Delta\lambda$ = fail - succ") ax.set_title("Discrimination gap (non-monotonic)") ax.grid(alpha=0.3) for e, d in zip(epochs, dlam): ax.text(e, d + 0.003, f"{d:+.3f}", ha="center", fontsize=9) fig.suptitle("TRM chaos onset trajectory (single-GPU, epoch 5K-25K, 5 checkpoints)", fontsize=12) fig.tight_layout() out = f"{ROOT}/plots_trm_chaos_onset.png" fig.savefig(out, dpi=130) print(f"→ {out}") print(f"\n{'epoch':>6} {'acc':>7} {'succ_λ':>9} {'fail_λ':>9} {'Δλ':>8}") for e, a, sl, fl, dl in zip(epochs, acc_list, succ_lam, fail_lam, dlam): print(f"{e:>6} {a:>7.3f} {sl:>+9.4f} {fl:>+9.4f} {dl:>+8.4f}")