summaryrefslogtreecommitdiff
path: root/research/flossing/analysis_2x2/analyze_2x2.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-06-13 12:35:36 -0500
commit66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch)
treec29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/analysis_2x2/analyze_2x2.py
rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipeline (refs: HEAD, main)
Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/analysis_2x2/analyze_2x2.py')
-rw-r--r--research/flossing/analysis_2x2/analyze_2x2.py350
1 files changed, 350 insertions, 0 deletions
diff --git a/research/flossing/analysis_2x2/analyze_2x2.py b/research/flossing/analysis_2x2/analyze_2x2.py
new file mode 100644
index 0000000..315ffc6
--- /dev/null
+++ b/research/flossing/analysis_2x2/analyze_2x2.py
@@ -0,0 +1,350 @@
+"""2x2 analysis: (terminal convergence) x (answer correctness), per-example FTLE per cell.
+
+Inputs: existing diagnostic npz files produced by diagnose_hrm*.py / diagnose_trm_joint.py:
+ drift_zH/drift_zL (N,16) per-ACT-step state displacement norms,
+ lyap_spec (N,8) full-window joint FTLE spectrum, exact_correct (N,),
+ token_acc, halted_at, q_halt/q_continue (N,16), idx.
+
+Convergence metric (measurement choice, reported alongside robustness sweep):
+ d_late = mean(drift_zH[:, -4:]) (late-trajectory z_H velocity, ACT steps 13-16)
+ primary threshold tau = Otsu on log10(d_late) pooled per dataset;
+ sensitivity: tau swept over pooled percentiles 5..95.
+
+Cells: A = converged & correct, B = converged & wrong,
+ C = non-converged & correct, D = non-converged & wrong.
+
+Outputs (in this directory): results_<tag>.json, cells_<tag>.csv, sweep_<tag>.csv,
+ fig_<tag>_{drift_hist,lyap_by_cell,scatter,spectrum}.png, evolution_{hrm,trm}.{csv,png},
+ results.md (combined human-readable summary).
+
+Observational only: this script reports counts, distributions and rank statistics; it does
+not test mechanisms.
+"""
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+import matplotlib
+
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import numpy as np
+
# Directory that contains this script; every output file is written here.
HERE = Path(__file__).resolve().parent
# Parent directory, where the diagnostic npz inputs are looked up.
FLOSS = HERE.parent

# Number of trailing ACT steps averaged into the late-drift metric.
LATE_K = 4
+
+
def otsu_threshold(x: np.ndarray, nbins: int = 256) -> float:
    """Return the Otsu threshold of a 1-D sample.

    The sample is binned into ``nbins`` equal-width bins. For every bin ``k``
    the histogram is split into bins ``0..k`` and bins ``k+1..``; the centre
    of the bin whose split maximises the between-class variance is returned
    (the first such bin when several splits tie).

    Degenerate samples have no valid split: if ``x`` is empty, NaN is
    returned, and if every value lands in a single bin (for example all
    values identical) the median of ``x`` is returned.
    """
    x = np.asarray(x)
    if x.size == 0:
        return float("nan")

    h, edges = np.histogram(x, bins=nbins)
    h = h.astype(np.float64)
    total = h.sum()
    if total == 0:
        return float(np.median(x))

    centers = 0.5 * (edges[:-1] + edges[1:])
    p = h / total
    omega = np.cumsum(p)          # probability mass of the lower class
    mu = np.cumsum(p * centers)   # unnormalised mean of the lower class
    denom = omega * (1.0 - omega)

    # A split is only meaningful when both classes carry mass. With all the
    # mass in one bin no split qualifies, and an arg-max over the (all
    # invalid) scores would raise, so fall back to the median instead.
    valid = denom > 0
    if not valid.any():
        return float(np.median(x))

    sigma_b2 = np.full(denom.shape, -np.inf)
    sigma_b2[valid] = (mu[-1] * omega[valid] - mu[valid]) ** 2 / denom[valid]
    return float(centers[int(np.argmax(sigma_b2))])
+
+
def auc_rank(score: np.ndarray, label: np.ndarray) -> float:
    """Rank-based AUC of ``score`` as a predictor of ``label == 1``.

    Scores of the ``label == 1`` and ``label == 0`` samples are pooled and
    ranked (1-based); tied scores share the average of their ranks. The
    result is the positive-class rank sum, shifted and normalised by
    ``n_pos * n_neg``. Returns NaN when either class is empty.
    """
    positives = score[label == 1]
    negatives = score[label == 0]
    n_pos, n_neg = len(positives), len(negatives)
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    pooled = np.concatenate([positives, negatives])
    n_all = len(pooled)
    order = np.argsort(pooled, kind="mergesort")
    sorted_vals = pooled[order]

    # Each run of equal values in sorted order is one tie group covering
    # sorted positions [starts[g], stops[g]).
    breaks = np.flatnonzero(sorted_vals[1:] != sorted_vals[:-1]) + 1
    starts = np.concatenate(([0], breaks))
    stops = np.concatenate((breaks, [n_all]))
    # The 1-based ranks starts+1 .. stops average to (starts + 1 + stops) / 2.
    group_rank = (starts + 1 + stops) / 2.0

    ranks = np.empty(n_all, dtype=np.float64)
    ranks[order] = np.repeat(group_rank, stops - starts)

    # Positives occupy the first n_pos slots of the pooled array.
    pos_rank_sum = ranks[:n_pos].sum()
    return float((pos_rank_sum - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg))
+
+
def cell_stats(lyap: np.ndarray, tok: np.ndarray, halted: np.ndarray, mask: np.ndarray) -> dict:
    """Summarise the rows selected by ``mask`` for one 2x2 cell.

    Returns ``{"n": 0}`` when the mask selects nothing. Otherwise the dict
    holds the row count, median / mean / 25th-75th percentiles of the first
    ``lyap`` column, the median of the last ``lyap`` column, the per-column
    medians of ``lyap``, and the medians of ``tok`` and ``halted``.
    """
    count = int(mask.sum())
    if count == 0:
        return {"n": 0}

    selected = lyap[mask]
    lam1 = selected[:, 0]
    q25, q75 = (float(np.percentile(lam1, q)) for q in (25, 75))
    column_medians = [float(np.median(selected[:, i])) for i in range(lyap.shape[1])]

    summary = {"n": count}
    summary["lam1_median"] = float(np.median(lam1))
    summary["lam1_mean"] = float(lam1.mean())
    summary["lam1_iqr"] = [q25, q75]
    # "lam8" is the last exponent of the spectrum, whatever its length.
    summary["lam8_median"] = column_medians[-1]
    summary["spectrum_median"] = column_medians
    summary["token_acc_median"] = float(np.median(tok[mask]))
    summary["halted_at_median"] = float(np.median(halted[mask]))
    return summary
+
+
def _median_gap(values: np.ndarray, first: np.ndarray, second: np.ndarray) -> float:
    """Return median(values[first]) - median(values[second]); NaN if either mask is empty."""
    if first.sum() == 0 or second.sum() == 0:
        return float("nan")
    return float(np.median(values[first]) - np.median(values[second]))


def _threshold_sweep(logd: np.ndarray, lam1: np.ndarray, correct: np.ndarray) -> list:
    """Recompute cell counts and wrong-cell lambda_1 medians for tau at percentiles 5..95."""
    is_correct = correct == 1
    is_wrong = correct == 0
    rows = []
    for pct in range(5, 96, 5):
        t = np.percentile(logd, pct)
        c = logd < t
        row = {
            "pct": pct,
            "tau": float(t),
            "nA": int((c & is_correct).sum()),
            "nB": int((c & is_wrong).sum()),
            "nC": int((~c & is_correct).sum()),
            "nD": int((~c & is_wrong).sum()),
        }
        for nm, m in [("lam1_med_B", c & is_wrong), ("lam1_med_D", ~c & is_wrong)]:
            row[nm] = float(np.median(lam1[m])) if m.sum() > 0 else float("nan")
        rows.append(row)
    return rows


def _write_sweep_csv(path: Path, rows: list) -> None:
    """Write the sweep rows as CSV, using the first row's keys as the header."""
    keys = list(rows[0].keys())
    with path.open("w") as f:
        f.write(",".join(keys) + "\n")
        for r in rows:
            f.write(",".join(str(r[k]) for k in keys) + "\n")


def _write_cells_csv(path: Path, cell_summaries: dict) -> None:
    """Write one CSV row per 2x2 cell; empty cells get a count of 0 and blank fields."""
    with path.open("w") as f:
        f.write("cell,n,lam1_median,lam1_mean,lam1_q25,lam1_q75,lam8_median,token_acc_median,halted_at_median\n")
        for k, s in cell_summaries.items():
            if s["n"] == 0:
                f.write(f"{k},0,,,,,,,\n")
                continue
            f.write(
                f"{k},{s['n']},{s['lam1_median']:.6f},{s['lam1_mean']:.6f},"
                f"{s['lam1_iqr'][0]:.6f},{s['lam1_iqr'][1]:.6f},{s['lam8_median']:.6f},"
                f"{s['token_acc_median']:.4f},{s['halted_at_median']:.1f}\n"
            )


def _save_analysis_figures(tag: str, logd: np.ndarray, lyap: np.ndarray, correct: np.ndarray,
                           tau: float, cells: dict) -> None:
    """Save the four per-dataset figures (drift histogram, lambda_1 strip plot, scatter, spectrum)."""
    colors = {"A_conv_correct": "tab:green", "B_conv_wrong": "tab:orange",
              "C_nonconv_correct": "tab:blue", "D_nonconv_wrong": "tab:red"}

    # Late-drift histogram split by correctness, with the Otsu threshold marked.
    fig, ax = plt.subplots(figsize=(6, 4))
    bins = np.linspace(logd.min(), logd.max(), 60)
    ax.hist(logd[correct == 1], bins=bins, alpha=0.55, label=f"correct (n={int(correct.sum())})", color="tab:green")
    ax.hist(logd[correct == 0], bins=bins, alpha=0.55, label=f"wrong (n={int((1-correct).sum())})", color="tab:red")
    ax.axvline(tau, color="k", ls="--", lw=1, label=f"Otsu tau={tau:.2f}")
    # Label derives from LATE_K so it stays correct if the window is changed.
    ax.set_xlabel(f"log10 late drift_zH (steps -{LATE_K}:)")
    ax.set_ylabel("count")
    ax.set_title(f"{tag}: late-drift distribution by correctness")
    ax.legend(fontsize=8)
    fig.tight_layout()
    fig.savefig(HERE / f"fig_{tag}_drift_hist.png", dpi=150)
    plt.close(fig)

    # Jittered strip plot of lambda_1 per cell with the cell median overlaid.
    fig, ax = plt.subplots(figsize=(6.5, 4))
    for i, (k, m) in enumerate(cells.items()):
        if m.sum() == 0:
            continue
        y = lyap[m, 0]
        # Generator is re-seeded for every cell, so the jitter is reproducible.
        x = np.full(y.shape, i) + (np.random.default_rng(0).uniform(-0.18, 0.18, y.shape))
        ax.plot(x, y, ".", ms=3, alpha=0.35, color=colors[k])
        ax.hlines(np.median(y), i - 0.28, i + 0.28, color=colors[k], lw=2.5)
    ax.set_xticks(range(4))
    ax.set_xticklabels([f"{k}\n(n={int(m.sum())})" for k, m in cells.items()], fontsize=7)
    ax.set_ylabel("lambda_1 (full-window FTLE)")
    ax.axhline(0, color="gray", lw=0.6)
    ax.set_title(f"{tag}: lambda_1 by 2x2 cell")
    fig.tight_layout()
    fig.savefig(HERE / f"fig_{tag}_lyap_by_cell.png", dpi=150)
    plt.close(fig)

    # Late drift against lambda_1, coloured by correctness.
    fig, ax = plt.subplots(figsize=(6, 4.5))
    ax.scatter(logd[correct == 1], lyap[correct == 1, 0], s=5, alpha=0.4, c="tab:green", label="correct")
    ax.scatter(logd[correct == 0], lyap[correct == 0, 0], s=5, alpha=0.4, c="tab:red", label="wrong")
    ax.axvline(tau, color="k", ls="--", lw=1)
    ax.axhline(0, color="gray", lw=0.6)
    ax.set_xlabel("log10 late drift_zH")
    ax.set_ylabel("lambda_1")
    ax.set_title(f"{tag}: drift vs lambda_1")
    ax.legend(fontsize=8)
    fig.tight_layout()
    fig.savefig(HERE / f"fig_{tag}_scatter.png", dpi=150)
    plt.close(fig)

    # Median spectrum per cell; cells with fewer than 3 examples are omitted.
    fig, ax = plt.subplots(figsize=(6, 4))
    for k, m in cells.items():
        if m.sum() < 3:
            continue
        ax.plot(range(1, lyap.shape[1] + 1), [np.median(lyap[m, i]) for i in range(lyap.shape[1])],
                "o-", ms=4, label=f"{k} (n={int(m.sum())})", color=colors[k])
    ax.axhline(0, color="gray", lw=0.6)
    ax.set_xlabel("exponent index")
    ax.set_ylabel("median lambda_i")
    ax.set_title(f"{tag}: median FTLE spectrum per cell")
    ax.legend(fontsize=7)
    fig.tight_layout()
    fig.savefig(HERE / f"fig_{tag}_spectrum.png", dpi=150)
    plt.close(fig)


def analyze(npz_path: Path, tag: str, make_figs: bool = True) -> dict:
    """Run the 2x2 (converged x correct) analysis on one diagnostic npz file.

    Reads ``lyap_spec``, ``exact_correct``, ``token_acc``, ``halted_at``,
    ``drift_zH``, ``drift_zL`` and ``q_halt`` from the archive. An example is
    labelled converged when log10 of its mean ``drift_zH`` over the last
    ``LATE_K`` columns is below the Otsu threshold of that quantity.

    Side effects, all written next to this script: ``sweep_<tag>.csv``,
    ``cells_<tag>.csv``, ``results_<tag>.json`` and, when ``make_figs`` is
    true, four ``fig_<tag>_*.png`` files.

    Args:
        npz_path: Path of the diagnostic npz archive.
        tag: Label used in output file names and figure titles.
        make_figs: Whether to render the figures.

    Returns:
        The result dict that is also serialised to ``results_<tag>.json``.
    """
    # np.load returns an NpzFile that holds the archive open; pull out every
    # array that is needed and let the context manager close the file.
    with np.load(npz_path) as d:
        lyap = d["lyap_spec"].astype(np.float64)
        correct = d["exact_correct"].astype(int)
        tok = d["token_acc"].astype(np.float64)
        halted = d["halted_at"].astype(np.float64)
        drift_h = d["drift_zH"].astype(np.float64)
        drift_l = d["drift_zL"].astype(np.float64)
        # Read once here instead of once per cell further down.
        q_halt_final = d["q_halt"][:, -1]

    lam1 = lyap[:, 0]
    d_late = drift_h[:, -LATE_K:].mean(axis=1)
    d_late_l = drift_l[:, -LATE_K:].mean(axis=1)
    # Clip before the log so an exactly-zero drift does not produce -inf.
    logd = np.log10(np.clip(d_late, 1e-12, None))

    tau = otsu_threshold(logd)
    conv = logd < tau
    is_correct = correct == 1
    is_wrong = correct == 0

    cells = {
        "A_conv_correct": conv & is_correct,
        "B_conv_wrong": conv & is_wrong,
        "C_nonconv_correct": ~conv & is_correct,
        "D_nonconv_wrong": ~conv & is_wrong,
    }

    res = {
        "npz": str(npz_path),
        "n": int(len(correct)),
        "exact_acc": float(correct.mean()),
        "late_drift_def": f"mean(drift_zH[:, -{LATE_K}:])",
        "otsu_tau_log10": tau,
        "frac_converged": float(conv.mean()),
        "cells": {k: cell_stats(lyap, tok, halted, m) for k, m in cells.items()},
        "mixture": {
            "wrong_that_converged": float(
                cells["B_conv_wrong"].sum() / max(is_wrong.sum(), 1)
            ),
            "correct_that_nonconverged": float(
                cells["C_nonconv_correct"].sum() / max(is_correct.sum(), 1)
            ),
        },
        # Every contrast is NaN when one of its two groups is empty.
        "contrasts": {
            "dlam1_correct_minus_wrong_overall": _median_gap(lam1, is_correct, is_wrong),
            "dlam1_within_converged": _median_gap(
                lam1, cells["A_conv_correct"], cells["B_conv_wrong"]
            ),
            "dlam1_within_nonconverged": _median_gap(
                lam1, cells["C_nonconv_correct"], cells["D_nonconv_wrong"]
            ),
            "dlam1_wrong_conv_minus_wrong_nonconv": _median_gap(
                lam1, cells["B_conv_wrong"], cells["D_nonconv_wrong"]
            ),
        },
        "auc": {
            "neg_lam1_predicts_correct_overall": auc_rank(-lam1, correct),
            "neg_lam1_predicts_correct_within_conv": auc_rank(-lyap[conv, 0], correct[conv]),
            "neg_lam1_predicts_correct_within_nonconv": auc_rank(-lyap[~conv, 0], correct[~conv]),
            "neg_logdrift_predicts_correct": auc_rank(-logd, correct),
            "neg_lam1_predicts_converged": auc_rank(-lam1, conv.astype(int)),
        },
    }

    # threshold sensitivity sweep
    _write_sweep_csv(HERE / f"sweep_{tag}.csv", _threshold_sweep(logd, lam1, correct))

    # per-cell csv
    _write_cells_csv(HERE / f"cells_{tag}.csv", res["cells"])

    if make_figs:
        _save_analysis_figures(tag, logd, lyap, correct, tau, cells)

    # secondary observables
    res["aux"] = {
        "late_drift_zL_corr_with_zH_log": float(
            np.corrcoef(logd, np.log10(np.clip(d_late_l, 1e-12, None)))[0, 1]
        ),
        "q_halt_final_median_by_cell": {
            k: float(np.median(q_halt_final[m])) if m.sum() > 0 else float("nan")
            for k, m in cells.items()
        },
    }
    (HERE / f"results_{tag}.json").write_text(json.dumps(res, indent=2))
    return res
+
+
def evolution(series: list[tuple[str, Path]], out_tag: str) -> None:
    """Track 2x2 cell fractions and per-cell median lambda_1 across checkpoints.

    For every ``(label, path)`` whose npz file exists, the examples are split
    into the four cells (A/B/C/D = converged-correct, converged-wrong,
    non-converged-correct, non-converged-wrong) using the same late-drift
    definition as ``analyze``, with an Otsu threshold computed per file.
    Missing files are skipped; if none exist nothing is written.

    Writes ``evolution_<out_tag>.csv`` and ``evolution_<out_tag>.png`` next to
    this script.

    Args:
        series: Checkpoint labels paired with their diagnostic npz paths, in
            plotting order.
        out_tag: Suffix used in the output file names and plot titles.
    """
    rows = []
    for label, p in series:
        if not p.exists():
            continue
        # Close the archive as soon as the needed arrays are in memory.
        with np.load(p) as d:
            lyap = d["lyap_spec"].astype(np.float64)
            correct = d["exact_correct"].astype(int)
            # Cast to float64 like analyze() does, so both compute the same
            # late drift (and hence the same tau) for the same checkpoint.
            drift_h = d["drift_zH"].astype(np.float64)
        logd = np.log10(np.clip(drift_h[:, -LATE_K:].mean(axis=1), 1e-12, None))
        tau = otsu_threshold(logd)
        conv = logd < tau
        masks = {
            "A": conv & (correct == 1),
            "B": conv & (correct == 0),
            "C": ~conv & (correct == 1),
            "D": ~conv & (correct == 0),
        }
        row = dict(step=label, acc=float(correct.mean()), tau=tau)
        for name, m in masks.items():
            row[f"f{name}"] = float(m.mean())
        for name, m in masks.items():
            # Medians over fewer than 3 examples are reported as NaN.
            row[f"l1{name}"] = float(np.median(lyap[m, 0])) if m.sum() > 2 else float("nan")
        rows.append(row)
    if not rows:
        return

    keys = list(rows[0].keys())
    with (HERE / f"evolution_{out_tag}.csv").open("w") as f:
        f.write(",".join(keys) + "\n")
        for r in rows:
            f.write(",".join(str(r[k]) for k in keys) + "\n")

    palette = [("A", "tab:green"), ("B", "tab:orange"), ("C", "tab:blue"), ("D", "tab:red")]
    step_labels = [r["step"] for r in rows]
    xs = range(len(rows))
    fig, axes = plt.subplots(1, 2, figsize=(11, 4))

    for name, c in palette:
        axes[0].plot(xs, [r[f"f{name}"] for r in rows], "o-", label=f"f{name}", color=c)
    axes[0].set_xticks(list(xs))
    axes[0].set_xticklabels(step_labels, rotation=45, fontsize=7)
    axes[0].set_ylabel("cell fraction")
    axes[0].legend(fontsize=8)
    axes[0].set_title(f"{out_tag}: cell fractions")

    for name, c in palette:
        axes[1].plot(xs, [r[f"l1{name}"] for r in rows], "o-", label=f"l1{name}", color=c)
    axes[1].axhline(0, color="gray", lw=0.6)
    axes[1].set_xticks(list(xs))
    axes[1].set_xticklabels(step_labels, rotation=45, fontsize=7)
    axes[1].set_ylabel("median lambda_1")
    axes[1].legend(fontsize=8)
    axes[1].set_title(f"{out_tag}: per-cell lambda_1")

    fig.tight_layout()
    fig.savefig(HERE / f"evolution_{out_tag}.png", dpi=150)
    plt.close(fig)
+
+
def main() -> None:
    """Run the 2x2 analysis on the primary datasets and write the combined report.

    Each primary npz that exists under ``FLOSS`` is analysed with ``analyze``;
    the HRM and TRM checkpoint series are passed to ``evolution``; finally a
    Markdown summary of the primary results is written to ``results.md`` next
    to this script. Missing input files are skipped silently.
    """
    from datetime import date  # only needed for the report header

    results = {}
    primary = [
        ("hrm26040_n8192", FLOSS / "diag_8k.npz"),
        ("trm_singleGPU_step260410_n512", FLOSS / "diag_trm_singleGPU_step260410_512.npz"),
        ("trm_singleGPU_step130205_n512", FLOSS / "diag_trm_singleGPU_step130205_512.npz"),
        ("trm_step13020_n512", FLOSS / "diag_trm_step13020_512.npz"),
    ]
    for tag, p in primary:
        if p.exists():
            results[tag] = analyze(p, tag)
            print(f"[done] {tag}")

    hrm_steps = [2604, 5208, 7812, 10416, 13020, 15624, 18228, 20832, 23436, 26040]
    evolution(
        [(str(s), FLOSS / f"diag_hrm_step_{s}_512.npz") for s in hrm_steps],
        "hrm",
    )
    trm_steps = [26041, 52082, 78123, 104164, 130205, 156246, 182287, 208328, 234369, 260410]
    evolution(
        [(str(s), FLOSS / f"diag_trm_singleGPU_step{s}_512.npz") for s in trm_steps],
        "trm",
    )

    # combined human-readable summary
    lines = ["# 2x2 analysis (convergence x correctness) — generated " + date.today().isoformat(), ""]
    for tag, r in results.items():
        lines += [
            f"## {tag}",
            f"- npz: `{r['npz']}`, n={r['n']}, exact_acc={r['exact_acc']:.3f}",
            f"- late-drift def: {r['late_drift_def']}, Otsu tau(log10)={r['otsu_tau_log10']:.3f}, frac_converged={r['frac_converged']:.3f}",
            "",
        ]
        lines.append("| cell | n | lam1 median | lam1 IQR | token_acc med |")
        lines.append("|---|---|---|---|---|")
        for k, s in r["cells"].items():
            if s["n"] == 0:
                lines.append(f"| {k} | 0 | - | - | - |")
            else:
                lines.append(f"| {k} | {s['n']} | {s['lam1_median']:+.4f} | [{s['lam1_iqr'][0]:+.4f}, {s['lam1_iqr'][1]:+.4f}] | {s['token_acc_median']:.3f} |")
        c = r["contrasts"]
        a = r["auc"]
        m = r["mixture"]
        lines += [
            "",
            f"- mixture: wrong-that-converged = {m['wrong_that_converged']:.3f}; correct-that-nonconverged = {m['correct_that_nonconverged']:.3f}",
            f"- dlam1(correct-wrong): overall {c['dlam1_correct_minus_wrong_overall']:+.4f}; within-conv {c['dlam1_within_converged']:+.4f}; within-nonconv {c['dlam1_within_nonconverged']:+.4f}",
            f"- dlam1(wrong: conv - nonconv) = {c['dlam1_wrong_conv_minus_wrong_nonconv']:+.4f}",
            f"- AUC(-lam1 -> correct): overall {a['neg_lam1_predicts_correct_overall']:.3f}; within-conv {a['neg_lam1_predicts_correct_within_conv']:.3f}; within-nonconv {a['neg_lam1_predicts_correct_within_nonconv']:.3f}",
            f"- AUC(-log d_late -> correct) = {a['neg_logdrift_predicts_correct']:.3f}; AUC(-lam1 -> converged) = {a['neg_lam1_predicts_converged']:.3f}",
            "",
        ]
    # The header contains a non-ASCII dash, so pin the encoding rather than
    # depending on the platform's locale default.
    (HERE / "results.md").write_text("\n".join(lines), encoding="utf-8")
    print("wrote", HERE / "results.md")
+
+
+if __name__ == "__main__":
+ main()