"""2x2 analysis: (terminal convergence) x (answer correctness), per-example FTLE per cell. Inputs: existing diagnostic npz files produced by diagnose_hrm*.py / diagnose_trm_joint.py: drift_zH/drift_zL (N,16) per-ACT-step state displacement norms, lyap_spec (N,8) full-window joint FTLE spectrum, exact_correct (N,), token_acc, halted_at, q_halt/q_continue (N,16), idx. Convergence metric (measurement choice, reported alongside robustness sweep): d_late = mean(drift_zH[:, -4:]) (late-trajectory z_H velocity, ACT steps 13-16) primary threshold tau = Otsu on log10(d_late) pooled per dataset; sensitivity: tau swept over pooled percentiles 5..95. Cells: A = converged & correct, B = converged & wrong, C = non-converged & correct, D = non-converged & wrong. Outputs (in this directory): results_.json, cells_.csv, sweep_.csv, fig__{drift_hist,lyap_by_cell,scatter,spectrum}.png, evolution_{hrm,trm}.{csv,png}, results.md (combined human-readable summary). Observational only: this script reports counts, distributions and rank statistics; it does not test mechanisms. """ from __future__ import annotations import json from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np HERE = Path(__file__).resolve().parent FLOSS = HERE.parent LATE_K = 4 # ACT steps used for late drift def otsu_threshold(x: np.ndarray, nbins: int = 256) -> float: h, edges = np.histogram(x, bins=nbins) h = h.astype(np.float64) centers = 0.5 * (edges[:-1] + edges[1:]) w = h.sum() if w == 0: return float(np.median(x)) p = h / w omega = np.cumsum(p) mu = np.cumsum(p * centers) mu_t = mu[-1] denom = omega * (1.0 - omega) denom[denom <= 0] = np.nan sigma_b2 = (mu_t * omega - mu) ** 2 / denom k = np.nanargmax(sigma_b2) return float(centers[k]) def auc_rank(score: np.ndarray, label: np.ndarray) -> float: """AUC of `score` for predicting label==1 (rank-based, ties averaged).""" pos = score[label == 1] neg = score[label == 0] if len(pos) == 0 or len(neg) == 0: return float("nan") allv = np.concatenate([pos, neg]) order = np.argsort(allv, kind="mergesort") ranks = np.empty_like(order, dtype=np.float64) ranks[order] = np.arange(1, len(allv) + 1) # average ranks for ties sv = allv[order] i = 0 while i < len(sv): j = i while j + 1 < len(sv) and sv[j + 1] == sv[i]: j += 1 if j > i: ranks[order[i : j + 1]] = ranks[order[i : j + 1]].mean() i = j + 1 r_pos = ranks[: len(pos)].sum() return float((r_pos - len(pos) * (len(pos) + 1) / 2) / (len(pos) * len(neg))) def cell_stats(lyap: np.ndarray, tok: np.ndarray, halted: np.ndarray, mask: np.ndarray) -> dict: if mask.sum() == 0: return {"n": 0} l1 = lyap[mask, 0] return { "n": int(mask.sum()), "lam1_median": float(np.median(l1)), "lam1_mean": float(l1.mean()), "lam1_iqr": [float(np.percentile(l1, 25)), float(np.percentile(l1, 75))], "lam8_median": float(np.median(lyap[mask, -1])), "spectrum_median": [float(np.median(lyap[mask, i])) for i in range(lyap.shape[1])], "token_acc_median": float(np.median(tok[mask])), "halted_at_median": float(np.median(halted[mask])), } def analyze(npz_path: Path, tag: str, make_figs: bool = True) -> dict: d = np.load(npz_path) lyap = d["lyap_spec"].astype(np.float64) correct = d["exact_correct"].astype(int) tok = d["token_acc"].astype(np.float64) halted = d["halted_at"].astype(np.float64) drift_h = d["drift_zH"].astype(np.float64) drift_l = d["drift_zL"].astype(np.float64) d_late = drift_h[:, -LATE_K:].mean(axis=1) d_late_l = drift_l[:, -LATE_K:].mean(axis=1) logd = np.log10(np.clip(d_late, 1e-12, None)) tau = otsu_threshold(logd) conv = logd < tau cells = { "A_conv_correct": (conv) & (correct == 1), "B_conv_wrong": (conv) & (correct == 0), "C_nonconv_correct": (~conv) & (correct == 1), "D_nonconv_wrong": (~conv) & (correct == 0), } res = { "npz": str(npz_path), "n": int(len(correct)), "exact_acc": float(correct.mean()), "late_drift_def": f"mean(drift_zH[:, -{LATE_K}:])", "otsu_tau_log10": tau, "frac_converged": float(conv.mean()), "cells": {k: cell_stats(lyap, tok, halted, m) for k, m in cells.items()}, "mixture": { "wrong_that_converged": float( cells["B_conv_wrong"].sum() / max((correct == 0).sum(), 1) ), "correct_that_nonconverged": float( cells["C_nonconv_correct"].sum() / max((correct == 1).sum(), 1) ), }, "contrasts": { "dlam1_correct_minus_wrong_overall": float( np.median(lyap[correct == 1, 0]) - np.median(lyap[correct == 0, 0]) ), "dlam1_within_converged": float( np.median(lyap[cells["A_conv_correct"], 0]) - np.median(lyap[cells["B_conv_wrong"], 0]) ) if cells["B_conv_wrong"].sum() > 0 and cells["A_conv_correct"].sum() > 0 else float("nan"), "dlam1_within_nonconverged": float( np.median(lyap[cells["C_nonconv_correct"], 0]) - np.median(lyap[cells["D_nonconv_wrong"], 0]) ) if cells["C_nonconv_correct"].sum() > 0 and cells["D_nonconv_wrong"].sum() > 0 else float("nan"), "dlam1_wrong_conv_minus_wrong_nonconv": float( np.median(lyap[cells["B_conv_wrong"], 0]) - np.median(lyap[cells["D_nonconv_wrong"], 0]) ) if cells["B_conv_wrong"].sum() > 0 and cells["D_nonconv_wrong"].sum() > 0 else float("nan"), }, "auc": { "neg_lam1_predicts_correct_overall": auc_rank(-lyap[:, 0], correct), "neg_lam1_predicts_correct_within_conv": auc_rank(-lyap[conv, 0], correct[conv]), "neg_lam1_predicts_correct_within_nonconv": auc_rank(-lyap[~conv, 0], correct[~conv]), "neg_logdrift_predicts_correct": auc_rank(-logd, correct), "neg_lam1_predicts_converged": auc_rank(-lyap[:, 0], conv.astype(int)), }, } # threshold sensitivity sweep sweep_rows = [] for pct in range(5, 96, 5): t = np.percentile(logd, pct) c = logd < t row = { "pct": pct, "tau": float(t), "nA": int((c & (correct == 1)).sum()), "nB": int((c & (correct == 0)).sum()), "nC": int((~c & (correct == 1)).sum()), "nD": int((~c & (correct == 0)).sum()), } for nm, m in [ ("lam1_med_B", c & (correct == 0)), ("lam1_med_D", ~c & (correct == 0)), ]: row[nm] = float(np.median(lyap[m, 0])) if m.sum() > 0 else float("nan") sweep_rows.append(row) sweep_csv = HERE / f"sweep_{tag}.csv" with sweep_csv.open("w") as f: keys = list(sweep_rows[0].keys()) f.write(",".join(keys) + "\n") for r in sweep_rows: f.write(",".join(str(r[k]) for k in keys) + "\n") # per-cell csv with (HERE / f"cells_{tag}.csv").open("w") as f: f.write("cell,n,lam1_median,lam1_mean,lam1_q25,lam1_q75,lam8_median,token_acc_median,halted_at_median\n") for k, m in cells.items(): s = res["cells"][k] if s["n"] == 0: f.write(f"{k},0,,,,,,,\n") continue f.write( f"{k},{s['n']},{s['lam1_median']:.6f},{s['lam1_mean']:.6f}," f"{s['lam1_iqr'][0]:.6f},{s['lam1_iqr'][1]:.6f},{s['lam8_median']:.6f}," f"{s['token_acc_median']:.4f},{s['halted_at_median']:.1f}\n" ) if make_figs: colors = {"A_conv_correct": "tab:green", "B_conv_wrong": "tab:orange", "C_nonconv_correct": "tab:blue", "D_nonconv_wrong": "tab:red"} fig, ax = plt.subplots(figsize=(6, 4)) bins = np.linspace(logd.min(), logd.max(), 60) ax.hist(logd[correct == 1], bins=bins, alpha=0.55, label=f"correct (n={int(correct.sum())})", color="tab:green") ax.hist(logd[correct == 0], bins=bins, alpha=0.55, label=f"wrong (n={int((1-correct).sum())})", color="tab:red") ax.axvline(tau, color="k", ls="--", lw=1, label=f"Otsu tau={tau:.2f}") ax.set_xlabel("log10 late drift_zH (steps -4:)"); ax.set_ylabel("count") ax.set_title(f"{tag}: late-drift distribution by correctness"); ax.legend(fontsize=8) fig.tight_layout(); fig.savefig(HERE / f"fig_{tag}_drift_hist.png", dpi=150); plt.close(fig) fig, ax = plt.subplots(figsize=(6.5, 4)) for i, (k, m) in enumerate(cells.items()): if m.sum() == 0: continue y = lyap[m, 0] x = np.full(y.shape, i) + (np.random.default_rng(0).uniform(-0.18, 0.18, y.shape)) ax.plot(x, y, ".", ms=3, alpha=0.35, color=colors[k]) ax.hlines(np.median(y), i - 0.28, i + 0.28, color=colors[k], lw=2.5) ax.set_xticks(range(4)); ax.set_xticklabels([f"{k}\n(n={int(m.sum())})" for k, m in cells.items()], fontsize=7) ax.set_ylabel("lambda_1 (full-window FTLE)"); ax.axhline(0, color="gray", lw=0.6) ax.set_title(f"{tag}: lambda_1 by 2x2 cell") fig.tight_layout(); fig.savefig(HERE / f"fig_{tag}_lyap_by_cell.png", dpi=150); plt.close(fig) fig, ax = plt.subplots(figsize=(6, 4.5)) ax.scatter(logd[correct == 1], lyap[correct == 1, 0], s=5, alpha=0.4, c="tab:green", label="correct") ax.scatter(logd[correct == 0], lyap[correct == 0, 0], s=5, alpha=0.4, c="tab:red", label="wrong") ax.axvline(tau, color="k", ls="--", lw=1); ax.axhline(0, color="gray", lw=0.6) ax.set_xlabel("log10 late drift_zH"); ax.set_ylabel("lambda_1") ax.set_title(f"{tag}: drift vs lambda_1"); ax.legend(fontsize=8) fig.tight_layout(); fig.savefig(HERE / f"fig_{tag}_scatter.png", dpi=150); plt.close(fig) fig, ax = plt.subplots(figsize=(6, 4)) for k, m in cells.items(): if m.sum() < 3: continue ax.plot(range(1, lyap.shape[1] + 1), [np.median(lyap[m, i]) for i in range(lyap.shape[1])], "o-", ms=4, label=f"{k} (n={int(m.sum())})", color=colors[k]) ax.axhline(0, color="gray", lw=0.6) ax.set_xlabel("exponent index"); ax.set_ylabel("median lambda_i") ax.set_title(f"{tag}: median FTLE spectrum per cell"); ax.legend(fontsize=7) fig.tight_layout(); fig.savefig(HERE / f"fig_{tag}_spectrum.png", dpi=150); plt.close(fig) # secondary observables res["aux"] = { "late_drift_zL_corr_with_zH_log": float(np.corrcoef(logd, np.log10(np.clip(d_late_l, 1e-12, None)))[0, 1]), "q_halt_final_median_by_cell": { k: float(np.median(d["q_halt"][m, -1])) if m.sum() > 0 else float("nan") for k, m in cells.items() }, } (HERE / f"results_{tag}.json").write_text(json.dumps(res, indent=2)) return res def evolution(series: list[tuple[str, Path]], out_tag: str) -> None: rows = [] for label, p in series: if not p.exists(): continue d = np.load(p) lyap = d["lyap_spec"].astype(np.float64) correct = d["exact_correct"].astype(int) logd = np.log10(np.clip(d["drift_zH"][:, -LATE_K:].mean(axis=1), 1e-12, None)) tau = otsu_threshold(logd) conv = logd < tau row = dict(step=label, acc=float(correct.mean()), tau=tau, fA=float(((conv) & (correct == 1)).mean()), fB=float(((conv) & (correct == 0)).mean()), fC=float(((~conv) & (correct == 1)).mean()), fD=float(((~conv) & (correct == 0)).mean())) for nm, m in [("l1A", (conv) & (correct == 1)), ("l1B", (conv) & (correct == 0)), ("l1C", (~conv) & (correct == 1)), ("l1D", (~conv) & (correct == 0))]: row[nm] = float(np.median(lyap[m, 0])) if m.sum() > 2 else float("nan") rows.append(row) if not rows: return keys = list(rows[0].keys()) with (HERE / f"evolution_{out_tag}.csv").open("w") as f: f.write(",".join(keys) + "\n") for r in rows: f.write(",".join(str(r[k]) for k in keys) + "\n") fig, axes = plt.subplots(1, 2, figsize=(11, 4)) xs = range(len(rows)) for nm, c in [("fA", "tab:green"), ("fB", "tab:orange"), ("fC", "tab:blue"), ("fD", "tab:red")]: axes[0].plot(xs, [r[nm] for r in rows], "o-", label=nm, color=c) axes[0].set_xticks(list(xs)); axes[0].set_xticklabels([r["step"] for r in rows], rotation=45, fontsize=7) axes[0].set_ylabel("cell fraction"); axes[0].legend(fontsize=8); axes[0].set_title(f"{out_tag}: cell fractions") for nm, c in [("l1A", "tab:green"), ("l1B", "tab:orange"), ("l1C", "tab:blue"), ("l1D", "tab:red")]: axes[1].plot(xs, [r[nm] for r in rows], "o-", label=nm, color=c) axes[1].axhline(0, color="gray", lw=0.6) axes[1].set_xticks(list(xs)); axes[1].set_xticklabels([r["step"] for r in rows], rotation=45, fontsize=7) axes[1].set_ylabel("median lambda_1"); axes[1].legend(fontsize=8); axes[1].set_title(f"{out_tag}: per-cell lambda_1") fig.tight_layout(); fig.savefig(HERE / f"evolution_{out_tag}.png", dpi=150); plt.close(fig) def main() -> None: results = {} primary = [ ("hrm26040_n8192", FLOSS / "diag_8k.npz"), ("trm_singleGPU_step260410_n512", FLOSS / "diag_trm_singleGPU_step260410_512.npz"), ("trm_singleGPU_step130205_n512", FLOSS / "diag_trm_singleGPU_step130205_512.npz"), ("trm_step13020_n512", FLOSS / "diag_trm_step13020_512.npz"), ] for tag, p in primary: if p.exists(): results[tag] = analyze(p, tag) print(f"[done] {tag}") evolution( [(f"{s}", FLOSS / f"diag_hrm_step_{s}_512.npz") for s in [2604, 5208, 7812, 10416, 13020, 15624, 18228, 20832, 23436, 26040]], "hrm", ) evolution( [(f"{s}", FLOSS / f"diag_trm_singleGPU_step{s}_512.npz") for s in [26041, 52082, 78123, 104164, 130205, 156246, 182287, 208328, 234369, 260410]], "trm", ) # combined human-readable summary lines = ["# 2x2 analysis (convergence x correctness) — generated " + __import__("datetime").date.today().isoformat(), ""] for tag, r in results.items(): lines += [f"## {tag}", f"- npz: `{r['npz']}`, n={r['n']}, exact_acc={r['exact_acc']:.3f}", f"- late-drift def: {r['late_drift_def']}, Otsu tau(log10)={r['otsu_tau_log10']:.3f}, frac_converged={r['frac_converged']:.3f}", ""] lines.append("| cell | n | lam1 median | lam1 IQR | token_acc med |") lines.append("|---|---|---|---|---|") for k, s in r["cells"].items(): if s["n"] == 0: lines.append(f"| {k} | 0 | - | - | - |") else: lines.append(f"| {k} | {s['n']} | {s['lam1_median']:+.4f} | [{s['lam1_iqr'][0]:+.4f}, {s['lam1_iqr'][1]:+.4f}] | {s['token_acc_median']:.3f} |") c = r["contrasts"]; a = r["auc"]; m = r["mixture"] lines += ["", f"- mixture: wrong-that-converged = {m['wrong_that_converged']:.3f}; correct-that-nonconverged = {m['correct_that_nonconverged']:.3f}", f"- dlam1(correct-wrong): overall {c['dlam1_correct_minus_wrong_overall']:+.4f}; within-conv {c['dlam1_within_converged']:+.4f}; within-nonconv {c['dlam1_within_nonconverged']:+.4f}", f"- dlam1(wrong: conv - nonconv) = {c['dlam1_wrong_conv_minus_wrong_nonconv']:+.4f}", f"- AUC(-lam1 -> correct): overall {a['neg_lam1_predicts_correct_overall']:.3f}; within-conv {a['neg_lam1_predicts_correct_within_conv']:.3f}; within-nonconv {a['neg_lam1_predicts_correct_within_nonconv']:.3f}", f"- AUC(-log d_late -> correct) = {a['neg_logdrift_predicts_correct']:.3f}; AUC(-lam1 -> converged) = {a['neg_lam1_predicts_converged']:.3f}", ""] (HERE / "results.md").write_text("\n".join(lines)) print("wrote", HERE / "results.md") if __name__ == "__main__": main()