"""Analyze the JOINT-tangent diagnostic (1024 samples).

Merges the per-shard ``diag_joint_1k_shard*.npz`` files into a single
``diag_joint_1k.npz``. It then prints, for each exponent of the Lyapunov
spectrum ``lyap_spec``, success-vs-failure statistics, using
``exact_correct`` to split samples. It also prints the failure-prediction
AUROC of a few scalar summaries of the spectrum and writes three plots to
``plots_joint/``.
"""
import glob
import os

import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
from scipy import stats

ROOT = "/home/yurenh2/rrm/research/flossing"
OUT = f"{ROOT}/plots_joint"
os.makedirs(OUT, exist_ok=True)


def auc(score, target_pos):
    """Return the AUROC of ``score`` for separating ``target_pos`` samples from the rest.

    Computed as the Mann-Whitney U statistic of the positive group divided by
    ``n_pos * n_neg``, i.e. P(score_pos > score_neg) + 0.5 * P(tie).

    Args:
        score: 1-D array of per-sample scores (higher = more "positive").
        target_pos: 1-D boolean-like mask, True for the positive class.

    Returns:
        The AUROC as a float, or NaN when either class is empty.
    """
    score = np.asarray(score)
    # Coerce to bool: `~` on an integer mask would be bitwise NOT, not logical NOT.
    target_pos = np.asarray(target_pos, dtype=bool)
    n_pos = int(target_pos.sum())
    n_neg = target_pos.size - n_pos
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    u = stats.mannwhitneyu(score[target_pos], score[~target_pos],
                           alternative="greater").statistic
    return u / (n_pos * n_neg)


def _cohens_d(a, b):
    """Return Cohen's d, (mean(a) - mean(b)) / pooled SD, using unbiased (ddof=1) variances.

    Returns NaN when either group has fewer than 2 samples or the pooled SD is 0.
    """
    if a.size < 2 or b.size < 2:
        return float("nan")
    pooled = np.sqrt(((a.size - 1) * a.var(ddof=1) + (b.size - 1) * b.var(ddof=1))
                     / (a.size + b.size - 2))
    return (a.mean() - b.mean()) / pooled if pooled > 0 else float("nan")


def _welch_p(a, b):
    """Return the two-sided Welch t-test p-value, or NaN if either group has < 2 samples."""
    if a.size < 2 or b.size < 2:
        return float("nan")
    return stats.ttest_ind(a, b, equal_var=False).pvalue


# Merge
SHARD_GLOB = f"{ROOT}/diag_joint_1k_shard*.npz"
files = sorted(glob.glob(SHARD_GLOB))
if not files:
    raise FileNotFoundError(f"no shard files matching {SHARD_GLOB}")
m = {}
for f in files:
    # NpzFile keeps the archive open; `d[k]` reads the array into memory, so it
    # is safe to close the file as soon as the loop over keys finishes.
    with np.load(f) as d:
        for k in d.files:
            m.setdefault(k, []).append(d[k])
# Shards are concatenated in (lexicographic) filename order. Every key uses the
# same order, so rows stay aligned across keys.
m = {k: np.concatenate(v, 0) for k, v in m.items()}
np.savez_compressed(f"{ROOT}/diag_joint_1k.npz", **m)
print(f"merged: N={len(m['exact_correct'])} acc={m['exact_correct'].mean():.4f}")

ls = m["lyap_spec"]  # (N, K)
exact = m["exact_correct"].astype(bool)
N, K = ls.shape
succ = ls[exact]
fail = ls[~exact]
if succ.shape[0] == 0 or fail.shape[0] == 0:
    print("WARNING: one outcome group is empty; group statistics will be NaN")

# Per-exponent table. Δ = mean_fail - mean_succ; d = |Cohen's d|;
# auroc = AUROC of λ_i for predicting failure; p_t = Welch t-test p-value.
print(f"\n--per-λ stats (joint tangent, k={K}) --")
print(f"{'i':>3} {'mean_all':>10} {'mean_succ':>10} {'mean_fail':>10} {'Δ':>8} {'d':>7} {'auroc':>7} {'p_t':>9}")
for i in range(K):
    li = ls[:, i]
    s_, f_ = li[exact], li[~exact]
    dc = _cohens_d(s_, f_)
    delta = f_.mean() - s_.mean()
    auc_i = auc(li, ~exact)
    p = _welch_p(s_, f_)
    print(f"{i+1:>3} {li.mean():+10.4f} {s_.mean():+10.4f} {f_.mean():+10.4f} {delta:+8.4f} {abs(dc):>7.3f} {auc_i:>7.3f} {p:>9.2e}")

print(f"\n--combined predictors AUROC (failure)--")
for label, score in [
    ("λ_1", ls[:, 0]),
    ("λ_K (most-neg)", ls[:, -1]),
    ("mean λ", ls.mean(1)),
    ("max λ", ls.max(1)),
    ("# of positive λ", (ls > 0).sum(1).astype(float)),
]:
    print(f"  {label:18s}: {auc(score, ~exact):.4f}")

# --- plot 1: spectrum mean ± std ---
fig, ax = plt.subplots(1, 1, figsize=(7, 5))
mean_s, std_s = succ.mean(0), succ.std(0)
mean_f, std_f = fail.mean(0), fail.std(0)
x = np.arange(1, K + 1)
ax.fill_between(x, mean_s - std_s, mean_s + std_s, color="C0", alpha=0.25)
ax.plot(x, mean_s, "C0-o", label=f"success n={succ.shape[0]}", lw=2)
ax.fill_between(x, mean_f - std_f, mean_f + std_f, color="C3", alpha=0.25)
ax.plot(x, mean_f, "C3-o", label=f"failure n={fail.shape[0]}", lw=2)
ax.axhline(0, color="k", ls="--", lw=1, label="critical λ=0")
ax.set_xlabel(r"Lyapunov index $i$")
ax.set_ylabel(r"$\lambda_i$ (per inner cycle, joint $(z_H,z_L)$ tangent)")
ax.set_title(f"Joint-tangent top-{K} Lyapunov spectrum @ step_26040 (N={N}, acc={exact.mean():.3f})")
ax.legend()
ax.grid(alpha=0.3)
fig.tight_layout()
fig.savefig(f"{OUT}/lyap_spectrum_joint.png", dpi=130)
plt.close(fig)

# --- plot 2: λ_1 histogram with critical line ---
fig, ax = plt.subplots(1, 1, figsize=(7, 4))
lo, hi = ls[:, 0].min(), ls[:, 0].max()
# Shared edges so both groups are binned identically; fall back to automatic
# binning when λ_1 is constant (linspace would give zero-width bins).
bins = np.linspace(lo, hi, 60) if hi > lo else 60
ax.hist(succ[:, 0], bins=bins, alpha=0.55, label="succ", color="C0", density=True)
ax.hist(fail[:, 0], bins=bins, alpha=0.55, label="fail", color="C3", density=True)
ax.axvline(0, color="k", ls="--", lw=1, label="λ=0")
ax.axvline(succ[:, 0].mean(), color="C0", ls=":", lw=1)
ax.axvline(fail[:, 0].mean(), color="C3", ls=":", lw=1)
ax.set_xlabel(r"$\lambda_1$")
ax.set_ylabel("density")
ax.set_title(f"Joint $\\lambda_1$ distribution (AUROC={auc(ls[:,0], ~exact):.3f})")
ax.legend()
fig.tight_layout()
fig.savefig(f"{OUT}/lyap1_hist_joint.png", dpi=130)
plt.close(fig)

# --- plot 3: # positive Lyapunov exponents per sample (how many of the top-K are positive?) ---
n_pos_per_sample = (ls > 0).sum(1)
fig, ax = plt.subplots(1, 1, figsize=(6.5, 4))
bins = np.arange(K + 2) - 0.5  # one unit-wide bin centred on each count 0..K
ax.hist(n_pos_per_sample[exact], bins=bins, alpha=0.55, label="success", color="C0", density=True, align="mid")
ax.hist(n_pos_per_sample[~exact], bins=bins, alpha=0.55, label="failure", color="C3", density=True, align="mid")
ax.set_xticks(range(K + 1))
ax.set_xlabel(f"# of positive (expansive) Lyapunov exponents in top-{K}")
ax.set_ylabel("fraction of samples in group")
ax.set_title("Counts of expansive modes per sample\n(joint dynamics has unstable manifold → no fixed point)")
ax.legend()
fig.tight_layout()
fig.savefig(f"{OUT}/n_positive_modes.png", dpi=130)
plt.close(fig)

print(f"\nplots → {OUT}/")