summaryrefslogtreecommitdiff
path: root/research/flossing/plot_trm_chaos_trajectory.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
commit66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
treec29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/plot_trm_chaos_trajectory.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipelineHEADmain
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/plot_trm_chaos_trajectory.py')
-rw-r--r--research/flossing/plot_trm_chaos_trajectory.py67
1 files changed, 67 insertions, 0 deletions
diff --git a/research/flossing/plot_trm_chaos_trajectory.py b/research/flossing/plot_trm_chaos_trajectory.py
new file mode 100644
index 0000000..da909c1
--- /dev/null
+++ b/research/flossing/plot_trm_chaos_trajectory.py
@@ -0,0 +1,67 @@
+"""Plot TRM Lyapunov trajectory across training checkpoints."""
+import numpy as np
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+
+ROOT = "/home/yurenh2/rrm/research/flossing"
+CKPTS = [
+ (26041, 5000),
+ (52082, 10000),
+ (78123, 15000),
+ (104164, 20000),
+ (130205, 25000),
+ (156246, 30000),
+ (182287, 35000),
+ (208328, 40000),
+ (234369, 45000),
+ (260410, 50000),
+]
+
+acc_list, succ_lam, fail_lam, dlam = [], [], [], []
+for step, epoch in CKPTS:
+ d = np.load(f"{ROOT}/diag_trm_singleGPU_step{step}_512.npz")
+ succ = d["exact_correct"] > 0.5
+ lam_1 = d["lyap_spec"][:, 0]
+ acc_list.append(succ.mean())
+ succ_lam.append(lam_1[succ].mean())
+ fail_lam.append(lam_1[~succ].mean())
+ dlam.append(lam_1[~succ].mean() - lam_1[succ].mean())
+
+epochs = [e for _, e in CKPTS]
+
+fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))
+
+ax = axes[0]
+ax.plot(epochs, acc_list, "ko-", lw=2, ms=8)
+ax.set_xlabel("epoch"); ax.set_ylabel("exact_acc on 512 test")
+ax.set_title("TRM accuracy vs training")
+ax.grid(alpha=0.3)
+for e, a in zip(epochs, acc_list):
+ ax.text(e, a + 0.01, f"{a:.3f}", ha="center", fontsize=9)
+
+ax = axes[1]
+ax.plot(epochs, succ_lam, "C2o-", lw=2, ms=8, label="succ λ_1")
+ax.plot(epochs, fail_lam, "C3o-", lw=2, ms=8, label="fail λ_1")
+ax.axhline(0, color="k", lw=0.5, ls=":")
+ax.set_xlabel("epoch"); ax.set_ylabel(r"$\lambda_1$ (top joint Lyap)")
+ax.set_title("TRM Lyapunov drift: succ → criticality, fail → chaos")
+ax.legend(); ax.grid(alpha=0.3)
+
+ax = axes[2]
+ax.plot(epochs, dlam, "C0o-", lw=2, ms=8)
+ax.set_xlabel("epoch"); ax.set_ylabel(r"$\Delta\lambda$ = fail - succ")
+ax.set_title("Discrimination gap (non-monotonic)")
+ax.grid(alpha=0.3)
+for e, d in zip(epochs, dlam):
+ ax.text(e, d + 0.003, f"{d:+.3f}", ha="center", fontsize=9)
+
+fig.suptitle("TRM chaos onset trajectory (single-GPU, epoch 5K-25K, 5 checkpoints)", fontsize=12)
+fig.tight_layout()
+out = f"{ROOT}/plots_trm_chaos_onset.png"
+fig.savefig(out, dpi=130)
+print(f"→ {out}")
+
+print(f"\n{'epoch':>6} {'acc':>7} {'succ_λ':>9} {'fail_λ':>9} {'Δλ':>8}")
+for e, a, sl, fl, dl in zip(epochs, acc_list, succ_lam, fail_lam, dlam):
+ print(f"{e:>6} {a:>7.3f} {sl:>+9.4f} {fl:>+9.4f} {dl:>+8.4f}")