diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
| commit | 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch) | |
| tree | c29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/analyze_joint.py | |
Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible).
Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/analyze_joint.py')
| -rw-r--r-- | research/flossing/analyze_joint.py | 103 |
1 files changed, 103 insertions, 0 deletions
diff --git a/research/flossing/analyze_joint.py b/research/flossing/analyze_joint.py new file mode 100644 index 0000000..1b65d0b --- /dev/null +++ b/research/flossing/analyze_joint.py @@ -0,0 +1,103 @@ +"""Analyze the JOINT-tangent diagnostic (1024 samples).""" +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from scipy import stats +import glob, os + +ROOT = "/home/yurenh2/rrm/research/flossing" +OUT = f"{ROOT}/plots_joint" +os.makedirs(OUT, exist_ok=True) + +# Merge +files = sorted(glob.glob(f"{ROOT}/diag_joint_1k_shard*.npz")) +m = {} +for f in files: + d = np.load(f) + for k in d.files: m.setdefault(k, []).append(d[k]) +for k in list(m.keys()): m[k] = np.concatenate(m[k], 0) +np.savez_compressed(f"{ROOT}/diag_joint_1k.npz", **m) +print(f"merged: N={len(m['exact_correct'])} acc={m['exact_correct'].mean():.4f}") + +ls = m["lyap_spec"] # (N, K) +exact = m["exact_correct"].astype(bool) +N, K = ls.shape +succ = ls[exact]; fail = ls[~exact] + +print(f"\n--per-λ stats (joint tangent, k={K}) --") +print(f"{'i':>3} {'mean_all':>10} {'mean_succ':>10} {'mean_fail':>10} {'Δ':>8} {'d':>7} {'auroc':>7} {'p_t':>9}") +for i in range(K): + li = ls[:, i]; s_ = li[exact]; f_ = li[~exact] + pooled = np.sqrt(((s_.size-1)*s_.var() + (f_.size-1)*f_.var()) / (s_.size+f_.size-2)) + dc = (s_.mean() - f_.mean()) / pooled + delta = f_.mean() - s_.mean() + auc = stats.mannwhitneyu(f_, s_, alternative="greater").statistic / (f_.size * s_.size) + t, p = stats.ttest_ind(s_, f_, equal_var=False) + print(f"{i+1:>3} {li.mean():+10.4f} {s_.mean():+10.4f} {f_.mean():+10.4f} {delta:+8.4f} {abs(dc):>7.3f} {auc:>7.3f} {p:>9.2e}") + +def auc(score, target_pos): + return stats.mannwhitneyu(score[target_pos], score[~target_pos], alternative="greater").statistic / (target_pos.sum() * (~target_pos).sum()) + +print(f"\n--combined predictors AUROC (failure)--") +for label, score in [ + ("λ_1", ls[:, 0]), + ("λ_K (most-neg)", ls[:, -1]), + ("mean λ", ls.mean(1)), + ("max λ", ls.max(1)), + ("# of positive λ", (ls > 0).sum(1).astype(float)), +]: + print(f" {label:18s}: {auc(score, ~exact):.4f}") + +# --- plot 1: spectrum mean ± std --- +fig, ax = plt.subplots(1, 1, figsize=(7, 5)) +mean_s = succ.mean(0); std_s = succ.std(0) +mean_f = fail.mean(0); std_f = fail.std(0) +x = np.arange(1, K+1) +ax.fill_between(x, mean_s-std_s, mean_s+std_s, color="C0", alpha=0.25) +ax.plot(x, mean_s, "C0-o", label=f"success n={succ.shape[0]}", lw=2) +ax.fill_between(x, mean_f-std_f, mean_f+std_f, color="C3", alpha=0.25) +ax.plot(x, mean_f, "C3-o", label=f"failure n={fail.shape[0]}", lw=2) +ax.axhline(0, color="k", ls="--", lw=1, label="critical λ=0") +ax.set_xlabel(r"Lyapunov index $i$") +ax.set_ylabel(r"$\lambda_i$ (per inner cycle, joint $(z_H,z_L)$ tangent)") +ax.set_title(f"Joint-tangent top-{K} Lyapunov spectrum @ step_26040 (N={N}, acc={exact.mean():.3f})") +ax.legend() +ax.grid(alpha=0.3) +fig.tight_layout() +fig.savefig(f"{OUT}/lyap_spectrum_joint.png", dpi=130) +plt.close() + +# --- plot 2: λ_1 histogram with critical line --- +fig, ax = plt.subplots(1, 1, figsize=(7, 4)) +lo, hi = ls[:,0].min(), ls[:,0].max() +bins = np.linspace(lo, hi, 60) +ax.hist(succ[:,0], bins=bins, alpha=0.55, label=f"succ", color="C0", density=True) +ax.hist(fail[:,0], bins=bins, alpha=0.55, label=f"fail", color="C3", density=True) +ax.axvline(0, color="k", ls="--", lw=1, label="λ=0") +ax.axvline(succ[:,0].mean(), color="C0", ls=":", lw=1) +ax.axvline(fail[:,0].mean(), color="C3", ls=":", lw=1) +ax.set_xlabel(r"$\lambda_1$") +ax.set_ylabel("density") +ax.set_title(f"Joint $\\lambda_1$ distribution (AUROC={auc(ls[:,0], ~exact):.3f})") +ax.legend() +fig.tight_layout() +fig.savefig(f"{OUT}/lyap1_hist_joint.png", dpi=130) +plt.close() + +# --- plot 3: # positive Lyapunov per sample (each sample, how many of top-8 are positive?) --- +n_pos_per_sample = (ls > 0).sum(1) +fig, ax = plt.subplots(1, 1, figsize=(6.5, 4)) +bins = np.arange(K+2) - 0.5 +ax.hist(n_pos_per_sample[exact], bins=bins, alpha=0.55, label="success", color="C0", density=True, align="mid") +ax.hist(n_pos_per_sample[~exact], bins=bins, alpha=0.55, label="failure", color="C3", density=True, align="mid") +ax.set_xticks(range(K+1)) +ax.set_xlabel("# of positive (expansive) Lyapunov exponents in top-8") +ax.set_ylabel("fraction of samples in group") +ax.set_title("Counts of expansive modes per sample\n(joint dynamics has unstable manifold → no fixed point)") +ax.legend() +fig.tight_layout() +fig.savefig(f"{OUT}/n_positive_modes.png", dpi=130) +plt.close() + +print(f"\nplots → {OUT}/") |
