summaryrefslogtreecommitdiff
path: root/research/flossing/analyze_joint.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/analyze_joint.py')
-rw-r--r--research/flossing/analyze_joint.py103
1 files changed, 103 insertions, 0 deletions
diff --git a/research/flossing/analyze_joint.py b/research/flossing/analyze_joint.py
new file mode 100644
index 0000000..1b65d0b
--- /dev/null
+++ b/research/flossing/analyze_joint.py
@@ -0,0 +1,103 @@
+"""Analyze the JOINT-tangent diagnostic (1024 samples)."""
+import numpy as np
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+from scipy import stats
+import glob, os
+
+ROOT = "/home/yurenh2/rrm/research/flossing"
+OUT = f"{ROOT}/plots_joint"
+os.makedirs(OUT, exist_ok=True)
+
+# Merge
+files = sorted(glob.glob(f"{ROOT}/diag_joint_1k_shard*.npz"))
+m = {}
+for f in files:
+ d = np.load(f)
+ for k in d.files: m.setdefault(k, []).append(d[k])
+for k in list(m.keys()): m[k] = np.concatenate(m[k], 0)
+np.savez_compressed(f"{ROOT}/diag_joint_1k.npz", **m)
+print(f"merged: N={len(m['exact_correct'])} acc={m['exact_correct'].mean():.4f}")
+
+ls = m["lyap_spec"] # (N, K)
+exact = m["exact_correct"].astype(bool)
+N, K = ls.shape
+succ = ls[exact]; fail = ls[~exact]
+
+print(f"\n--per-λ stats (joint tangent, k={K}) --")
+print(f"{'i':>3} {'mean_all':>10} {'mean_succ':>10} {'mean_fail':>10} {'Δ':>8} {'d':>7} {'auroc':>7} {'p_t':>9}")
+for i in range(K):
+ li = ls[:, i]; s_ = li[exact]; f_ = li[~exact]
+ pooled = np.sqrt(((s_.size-1)*s_.var() + (f_.size-1)*f_.var()) / (s_.size+f_.size-2))
+ dc = (s_.mean() - f_.mean()) / pooled
+ delta = f_.mean() - s_.mean()
+ auc = stats.mannwhitneyu(f_, s_, alternative="greater").statistic / (f_.size * s_.size)
+ t, p = stats.ttest_ind(s_, f_, equal_var=False)
+ print(f"{i+1:>3} {li.mean():+10.4f} {s_.mean():+10.4f} {f_.mean():+10.4f} {delta:+8.4f} {abs(dc):>7.3f} {auc:>7.3f} {p:>9.2e}")
+
+def auc(score, target_pos):
+ return stats.mannwhitneyu(score[target_pos], score[~target_pos], alternative="greater").statistic / (target_pos.sum() * (~target_pos).sum())
+
+print(f"\n--combined predictors AUROC (failure)--")
+for label, score in [
+ ("λ_1", ls[:, 0]),
+ ("λ_K (most-neg)", ls[:, -1]),
+ ("mean λ", ls.mean(1)),
+ ("max λ", ls.max(1)),
+ ("# of positive λ", (ls > 0).sum(1).astype(float)),
+]:
+ print(f" {label:18s}: {auc(score, ~exact):.4f}")
+
+# --- plot 1: spectrum mean ± std ---
+fig, ax = plt.subplots(1, 1, figsize=(7, 5))
+mean_s = succ.mean(0); std_s = succ.std(0)
+mean_f = fail.mean(0); std_f = fail.std(0)
+x = np.arange(1, K+1)
+ax.fill_between(x, mean_s-std_s, mean_s+std_s, color="C0", alpha=0.25)
+ax.plot(x, mean_s, "C0-o", label=f"success n={succ.shape[0]}", lw=2)
+ax.fill_between(x, mean_f-std_f, mean_f+std_f, color="C3", alpha=0.25)
+ax.plot(x, mean_f, "C3-o", label=f"failure n={fail.shape[0]}", lw=2)
+ax.axhline(0, color="k", ls="--", lw=1, label="critical λ=0")
+ax.set_xlabel(r"Lyapunov index $i$")
+ax.set_ylabel(r"$\lambda_i$ (per inner cycle, joint $(z_H,z_L)$ tangent)")
+ax.set_title(f"Joint-tangent top-{K} Lyapunov spectrum @ step_26040 (N={N}, acc={exact.mean():.3f})")
+ax.legend()
+ax.grid(alpha=0.3)
+fig.tight_layout()
+fig.savefig(f"{OUT}/lyap_spectrum_joint.png", dpi=130)
+plt.close()
+
+# --- plot 2: λ_1 histogram with critical line ---
+fig, ax = plt.subplots(1, 1, figsize=(7, 4))
+lo, hi = ls[:,0].min(), ls[:,0].max()
+bins = np.linspace(lo, hi, 60)
+ax.hist(succ[:,0], bins=bins, alpha=0.55, label=f"succ", color="C0", density=True)
+ax.hist(fail[:,0], bins=bins, alpha=0.55, label=f"fail", color="C3", density=True)
+ax.axvline(0, color="k", ls="--", lw=1, label="λ=0")
+ax.axvline(succ[:,0].mean(), color="C0", ls=":", lw=1)
+ax.axvline(fail[:,0].mean(), color="C3", ls=":", lw=1)
+ax.set_xlabel(r"$\lambda_1$")
+ax.set_ylabel("density")
+ax.set_title(f"Joint $\\lambda_1$ distribution (AUROC={auc(ls[:,0], ~exact):.3f})")
+ax.legend()
+fig.tight_layout()
+fig.savefig(f"{OUT}/lyap1_hist_joint.png", dpi=130)
+plt.close()
+
+# --- plot 3: # positive Lyapunov per sample (each sample, how many of top-8 are positive?) ---
+n_pos_per_sample = (ls > 0).sum(1)
+fig, ax = plt.subplots(1, 1, figsize=(6.5, 4))
+bins = np.arange(K+2) - 0.5
+ax.hist(n_pos_per_sample[exact], bins=bins, alpha=0.55, label="success", color="C0", density=True, align="mid")
+ax.hist(n_pos_per_sample[~exact], bins=bins, alpha=0.55, label="failure", color="C3", density=True, align="mid")
+ax.set_xticks(range(K+1))
+ax.set_xlabel("# of positive (expansive) Lyapunov exponents in top-8")
+ax.set_ylabel("fraction of samples in group")
+ax.set_title("Counts of expansive modes per sample\n(joint dynamics has unstable manifold → no fixed point)")
+ax.legend()
+fig.tight_layout()
+fig.savefig(f"{OUT}/n_positive_modes.png", dpi=130)
+plt.close()
+
+print(f"\nplots → {OUT}/")