diff options
Diffstat (limited to 'research/flossing/analysis_2x2/analyze_2x2.py')
| -rw-r--r-- | research/flossing/analysis_2x2/analyze_2x2.py | 350 |
1 files changed, 350 insertions, 0 deletions
diff --git a/research/flossing/analysis_2x2/analyze_2x2.py b/research/flossing/analysis_2x2/analyze_2x2.py new file mode 100644 index 0000000..315ffc6 --- /dev/null +++ b/research/flossing/analysis_2x2/analyze_2x2.py @@ -0,0 +1,350 @@ +"""2x2 analysis: (terminal convergence) x (answer correctness), per-example FTLE per cell. + +Inputs: existing diagnostic npz files produced by diagnose_hrm*.py / diagnose_trm_joint.py: + drift_zH/drift_zL (N,16) per-ACT-step state displacement norms, + lyap_spec (N,8) full-window joint FTLE spectrum, exact_correct (N,), + token_acc, halted_at, q_halt/q_continue (N,16), idx. + +Convergence metric (measurement choice, reported alongside robustness sweep): + d_late = mean(drift_zH[:, -4:]) (late-trajectory z_H velocity, ACT steps 13-16) + primary threshold tau = Otsu on log10(d_late) pooled per dataset; + sensitivity: tau swept over pooled percentiles 5..95. + +Cells: A = converged & correct, B = converged & wrong, + C = non-converged & correct, D = non-converged & wrong. + +Outputs (in this directory): results_<tag>.json, cells_<tag>.csv, sweep_<tag>.csv, + fig_<tag>_{drift_hist,lyap_by_cell,scatter,spectrum}.png, evolution_{hrm,trm}.{csv,png}, + results.md (combined human-readable summary). + +Observational only: this script reports counts, distributions and rank statistics; it does +not test mechanisms. +""" +from __future__ import annotations + +import json +from pathlib import Path + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np + +HERE = Path(__file__).resolve().parent +FLOSS = HERE.parent + +LATE_K = 4 # ACT steps used for late drift + + +def otsu_threshold(x: np.ndarray, nbins: int = 256) -> float: + h, edges = np.histogram(x, bins=nbins) + h = h.astype(np.float64) + centers = 0.5 * (edges[:-1] + edges[1:]) + w = h.sum() + if w == 0: + return float(np.median(x)) + p = h / w + omega = np.cumsum(p) + mu = np.cumsum(p * centers) + mu_t = mu[-1] + denom = omega * (1.0 - omega) + denom[denom <= 0] = np.nan + sigma_b2 = (mu_t * omega - mu) ** 2 / denom + k = np.nanargmax(sigma_b2) + return float(centers[k]) + + +def auc_rank(score: np.ndarray, label: np.ndarray) -> float: + """AUC of `score` for predicting label==1 (rank-based, ties averaged).""" + pos = score[label == 1] + neg = score[label == 0] + if len(pos) == 0 or len(neg) == 0: + return float("nan") + allv = np.concatenate([pos, neg]) + order = np.argsort(allv, kind="mergesort") + ranks = np.empty_like(order, dtype=np.float64) + ranks[order] = np.arange(1, len(allv) + 1) + # average ranks for ties + sv = allv[order] + i = 0 + while i < len(sv): + j = i + while j + 1 < len(sv) and sv[j + 1] == sv[i]: + j += 1 + if j > i: + ranks[order[i : j + 1]] = ranks[order[i : j + 1]].mean() + i = j + 1 + r_pos = ranks[: len(pos)].sum() + return float((r_pos - len(pos) * (len(pos) + 1) / 2) / (len(pos) * len(neg))) + + +def cell_stats(lyap: np.ndarray, tok: np.ndarray, halted: np.ndarray, mask: np.ndarray) -> dict: + if mask.sum() == 0: + return {"n": 0} + l1 = lyap[mask, 0] + return { + "n": int(mask.sum()), + "lam1_median": float(np.median(l1)), + "lam1_mean": float(l1.mean()), + "lam1_iqr": [float(np.percentile(l1, 25)), float(np.percentile(l1, 75))], + "lam8_median": float(np.median(lyap[mask, -1])), + "spectrum_median": [float(np.median(lyap[mask, i])) for i in range(lyap.shape[1])], + "token_acc_median": float(np.median(tok[mask])), + "halted_at_median": float(np.median(halted[mask])), + } + + +def analyze(npz_path: Path, tag: str, make_figs: bool = True) -> dict: + d = np.load(npz_path) + lyap = d["lyap_spec"].astype(np.float64) + correct = d["exact_correct"].astype(int) + tok = d["token_acc"].astype(np.float64) + halted = d["halted_at"].astype(np.float64) + drift_h = d["drift_zH"].astype(np.float64) + drift_l = d["drift_zL"].astype(np.float64) + + d_late = drift_h[:, -LATE_K:].mean(axis=1) + d_late_l = drift_l[:, -LATE_K:].mean(axis=1) + logd = np.log10(np.clip(d_late, 1e-12, None)) + + tau = otsu_threshold(logd) + conv = logd < tau + + cells = { + "A_conv_correct": (conv) & (correct == 1), + "B_conv_wrong": (conv) & (correct == 0), + "C_nonconv_correct": (~conv) & (correct == 1), + "D_nonconv_wrong": (~conv) & (correct == 0), + } + + res = { + "npz": str(npz_path), + "n": int(len(correct)), + "exact_acc": float(correct.mean()), + "late_drift_def": f"mean(drift_zH[:, -{LATE_K}:])", + "otsu_tau_log10": tau, + "frac_converged": float(conv.mean()), + "cells": {k: cell_stats(lyap, tok, halted, m) for k, m in cells.items()}, + "mixture": { + "wrong_that_converged": float( + cells["B_conv_wrong"].sum() / max((correct == 0).sum(), 1) + ), + "correct_that_nonconverged": float( + cells["C_nonconv_correct"].sum() / max((correct == 1).sum(), 1) + ), + }, + "contrasts": { + "dlam1_correct_minus_wrong_overall": float( + np.median(lyap[correct == 1, 0]) - np.median(lyap[correct == 0, 0]) + ), + "dlam1_within_converged": float( + np.median(lyap[cells["A_conv_correct"], 0]) - np.median(lyap[cells["B_conv_wrong"], 0]) + ) + if cells["B_conv_wrong"].sum() > 0 and cells["A_conv_correct"].sum() > 0 + else float("nan"), + "dlam1_within_nonconverged": float( + np.median(lyap[cells["C_nonconv_correct"], 0]) - np.median(lyap[cells["D_nonconv_wrong"], 0]) + ) + if cells["C_nonconv_correct"].sum() > 0 and cells["D_nonconv_wrong"].sum() > 0 + else float("nan"), + "dlam1_wrong_conv_minus_wrong_nonconv": float( + np.median(lyap[cells["B_conv_wrong"], 0]) - np.median(lyap[cells["D_nonconv_wrong"], 0]) + ) + if cells["B_conv_wrong"].sum() > 0 and cells["D_nonconv_wrong"].sum() > 0 + else float("nan"), + }, + "auc": { + "neg_lam1_predicts_correct_overall": auc_rank(-lyap[:, 0], correct), + "neg_lam1_predicts_correct_within_conv": auc_rank(-lyap[conv, 0], correct[conv]), + "neg_lam1_predicts_correct_within_nonconv": auc_rank(-lyap[~conv, 0], correct[~conv]), + "neg_logdrift_predicts_correct": auc_rank(-logd, correct), + "neg_lam1_predicts_converged": auc_rank(-lyap[:, 0], conv.astype(int)), + }, + } + + # threshold sensitivity sweep + sweep_rows = [] + for pct in range(5, 96, 5): + t = np.percentile(logd, pct) + c = logd < t + row = { + "pct": pct, + "tau": float(t), + "nA": int((c & (correct == 1)).sum()), + "nB": int((c & (correct == 0)).sum()), + "nC": int((~c & (correct == 1)).sum()), + "nD": int((~c & (correct == 0)).sum()), + } + for nm, m in [ + ("lam1_med_B", c & (correct == 0)), + ("lam1_med_D", ~c & (correct == 0)), + ]: + row[nm] = float(np.median(lyap[m, 0])) if m.sum() > 0 else float("nan") + sweep_rows.append(row) + sweep_csv = HERE / f"sweep_{tag}.csv" + with sweep_csv.open("w") as f: + keys = list(sweep_rows[0].keys()) + f.write(",".join(keys) + "\n") + for r in sweep_rows: + f.write(",".join(str(r[k]) for k in keys) + "\n") + + # per-cell csv + with (HERE / f"cells_{tag}.csv").open("w") as f: + f.write("cell,n,lam1_median,lam1_mean,lam1_q25,lam1_q75,lam8_median,token_acc_median,halted_at_median\n") + for k, m in cells.items(): + s = res["cells"][k] + if s["n"] == 0: + f.write(f"{k},0,,,,,,,\n") + continue + f.write( + f"{k},{s['n']},{s['lam1_median']:.6f},{s['lam1_mean']:.6f}," + f"{s['lam1_iqr'][0]:.6f},{s['lam1_iqr'][1]:.6f},{s['lam8_median']:.6f}," + f"{s['token_acc_median']:.4f},{s['halted_at_median']:.1f}\n" + ) + + if make_figs: + colors = {"A_conv_correct": "tab:green", "B_conv_wrong": "tab:orange", + "C_nonconv_correct": "tab:blue", "D_nonconv_wrong": "tab:red"} + + fig, ax = plt.subplots(figsize=(6, 4)) + bins = np.linspace(logd.min(), logd.max(), 60) + ax.hist(logd[correct == 1], bins=bins, alpha=0.55, label=f"correct (n={int(correct.sum())})", color="tab:green") + ax.hist(logd[correct == 0], bins=bins, alpha=0.55, label=f"wrong (n={int((1-correct).sum())})", color="tab:red") + ax.axvline(tau, color="k", ls="--", lw=1, label=f"Otsu tau={tau:.2f}") + ax.set_xlabel("log10 late drift_zH (steps -4:)"); ax.set_ylabel("count") + ax.set_title(f"{tag}: late-drift distribution by correctness"); ax.legend(fontsize=8) + fig.tight_layout(); fig.savefig(HERE / f"fig_{tag}_drift_hist.png", dpi=150); plt.close(fig) + + fig, ax = plt.subplots(figsize=(6.5, 4)) + for i, (k, m) in enumerate(cells.items()): + if m.sum() == 0: + continue + y = lyap[m, 0] + x = np.full(y.shape, i) + (np.random.default_rng(0).uniform(-0.18, 0.18, y.shape)) + ax.plot(x, y, ".", ms=3, alpha=0.35, color=colors[k]) + ax.hlines(np.median(y), i - 0.28, i + 0.28, color=colors[k], lw=2.5) + ax.set_xticks(range(4)); ax.set_xticklabels([f"{k}\n(n={int(m.sum())})" for k, m in cells.items()], fontsize=7) + ax.set_ylabel("lambda_1 (full-window FTLE)"); ax.axhline(0, color="gray", lw=0.6) + ax.set_title(f"{tag}: lambda_1 by 2x2 cell") + fig.tight_layout(); fig.savefig(HERE / f"fig_{tag}_lyap_by_cell.png", dpi=150); plt.close(fig) + + fig, ax = plt.subplots(figsize=(6, 4.5)) + ax.scatter(logd[correct == 1], lyap[correct == 1, 0], s=5, alpha=0.4, c="tab:green", label="correct") + ax.scatter(logd[correct == 0], lyap[correct == 0, 0], s=5, alpha=0.4, c="tab:red", label="wrong") + ax.axvline(tau, color="k", ls="--", lw=1); ax.axhline(0, color="gray", lw=0.6) + ax.set_xlabel("log10 late drift_zH"); ax.set_ylabel("lambda_1") + ax.set_title(f"{tag}: drift vs lambda_1"); ax.legend(fontsize=8) + fig.tight_layout(); fig.savefig(HERE / f"fig_{tag}_scatter.png", dpi=150); plt.close(fig) + + fig, ax = plt.subplots(figsize=(6, 4)) + for k, m in cells.items(): + if m.sum() < 3: + continue + ax.plot(range(1, lyap.shape[1] + 1), [np.median(lyap[m, i]) for i in range(lyap.shape[1])], + "o-", ms=4, label=f"{k} (n={int(m.sum())})", color=colors[k]) + ax.axhline(0, color="gray", lw=0.6) + ax.set_xlabel("exponent index"); ax.set_ylabel("median lambda_i") + ax.set_title(f"{tag}: median FTLE spectrum per cell"); ax.legend(fontsize=7) + fig.tight_layout(); fig.savefig(HERE / f"fig_{tag}_spectrum.png", dpi=150); plt.close(fig) + + # secondary observables + res["aux"] = { + "late_drift_zL_corr_with_zH_log": float(np.corrcoef(logd, np.log10(np.clip(d_late_l, 1e-12, None)))[0, 1]), + "q_halt_final_median_by_cell": { + k: float(np.median(d["q_halt"][m, -1])) if m.sum() > 0 else float("nan") for k, m in cells.items() + }, + } + (HERE / f"results_{tag}.json").write_text(json.dumps(res, indent=2)) + return res + + +def evolution(series: list[tuple[str, Path]], out_tag: str) -> None: + rows = [] + for label, p in series: + if not p.exists(): + continue + d = np.load(p) + lyap = d["lyap_spec"].astype(np.float64) + correct = d["exact_correct"].astype(int) + logd = np.log10(np.clip(d["drift_zH"][:, -LATE_K:].mean(axis=1), 1e-12, None)) + tau = otsu_threshold(logd) + conv = logd < tau + row = dict(step=label, acc=float(correct.mean()), tau=tau, + fA=float(((conv) & (correct == 1)).mean()), fB=float(((conv) & (correct == 0)).mean()), + fC=float(((~conv) & (correct == 1)).mean()), fD=float(((~conv) & (correct == 0)).mean())) + for nm, m in [("l1A", (conv) & (correct == 1)), ("l1B", (conv) & (correct == 0)), + ("l1C", (~conv) & (correct == 1)), ("l1D", (~conv) & (correct == 0))]: + row[nm] = float(np.median(lyap[m, 0])) if m.sum() > 2 else float("nan") + rows.append(row) + if not rows: + return + keys = list(rows[0].keys()) + with (HERE / f"evolution_{out_tag}.csv").open("w") as f: + f.write(",".join(keys) + "\n") + for r in rows: + f.write(",".join(str(r[k]) for k in keys) + "\n") + fig, axes = plt.subplots(1, 2, figsize=(11, 4)) + xs = range(len(rows)) + for nm, c in [("fA", "tab:green"), ("fB", "tab:orange"), ("fC", "tab:blue"), ("fD", "tab:red")]: + axes[0].plot(xs, [r[nm] for r in rows], "o-", label=nm, color=c) + axes[0].set_xticks(list(xs)); axes[0].set_xticklabels([r["step"] for r in rows], rotation=45, fontsize=7) + axes[0].set_ylabel("cell fraction"); axes[0].legend(fontsize=8); axes[0].set_title(f"{out_tag}: cell fractions") + for nm, c in [("l1A", "tab:green"), ("l1B", "tab:orange"), ("l1C", "tab:blue"), ("l1D", "tab:red")]: + axes[1].plot(xs, [r[nm] for r in rows], "o-", label=nm, color=c) + axes[1].axhline(0, color="gray", lw=0.6) + axes[1].set_xticks(list(xs)); axes[1].set_xticklabels([r["step"] for r in rows], rotation=45, fontsize=7) + axes[1].set_ylabel("median lambda_1"); axes[1].legend(fontsize=8); axes[1].set_title(f"{out_tag}: per-cell lambda_1") + fig.tight_layout(); fig.savefig(HERE / f"evolution_{out_tag}.png", dpi=150); plt.close(fig) + + +def main() -> None: + results = {} + primary = [ + ("hrm26040_n8192", FLOSS / "diag_8k.npz"), + ("trm_singleGPU_step260410_n512", FLOSS / "diag_trm_singleGPU_step260410_512.npz"), + ("trm_singleGPU_step130205_n512", FLOSS / "diag_trm_singleGPU_step130205_512.npz"), + ("trm_step13020_n512", FLOSS / "diag_trm_step13020_512.npz"), + ] + for tag, p in primary: + if p.exists(): + results[tag] = analyze(p, tag) + print(f"[done] {tag}") + + evolution( + [(f"{s}", FLOSS / f"diag_hrm_step_{s}_512.npz") for s in + [2604, 5208, 7812, 10416, 13020, 15624, 18228, 20832, 23436, 26040]], + "hrm", + ) + evolution( + [(f"{s}", FLOSS / f"diag_trm_singleGPU_step{s}_512.npz") for s in + [26041, 52082, 78123, 104164, 130205, 156246, 182287, 208328, 234369, 260410]], + "trm", + ) + + # combined human-readable summary + lines = ["# 2x2 analysis (convergence x correctness) — generated " + __import__("datetime").date.today().isoformat(), ""] + for tag, r in results.items(): + lines += [f"## {tag}", f"- npz: `{r['npz']}`, n={r['n']}, exact_acc={r['exact_acc']:.3f}", + f"- late-drift def: {r['late_drift_def']}, Otsu tau(log10)={r['otsu_tau_log10']:.3f}, frac_converged={r['frac_converged']:.3f}", ""] + lines.append("| cell | n | lam1 median | lam1 IQR | token_acc med |") + lines.append("|---|---|---|---|---|") + for k, s in r["cells"].items(): + if s["n"] == 0: + lines.append(f"| {k} | 0 | - | - | - |") + else: + lines.append(f"| {k} | {s['n']} | {s['lam1_median']:+.4f} | [{s['lam1_iqr'][0]:+.4f}, {s['lam1_iqr'][1]:+.4f}] | {s['token_acc_median']:.3f} |") + c = r["contrasts"]; a = r["auc"]; m = r["mixture"] + lines += ["", + f"- mixture: wrong-that-converged = {m['wrong_that_converged']:.3f}; correct-that-nonconverged = {m['correct_that_nonconverged']:.3f}", + f"- dlam1(correct-wrong): overall {c['dlam1_correct_minus_wrong_overall']:+.4f}; within-conv {c['dlam1_within_converged']:+.4f}; within-nonconv {c['dlam1_within_nonconverged']:+.4f}", + f"- dlam1(wrong: conv - nonconv) = {c['dlam1_wrong_conv_minus_wrong_nonconv']:+.4f}", + f"- AUC(-lam1 -> correct): overall {a['neg_lam1_predicts_correct_overall']:.3f}; within-conv {a['neg_lam1_predicts_correct_within_conv']:.3f}; within-nonconv {a['neg_lam1_predicts_correct_within_nonconv']:.3f}", + f"- AUC(-log d_late -> correct) = {a['neg_logdrift_predicts_correct']:.3f}; AUC(-lam1 -> converged) = {a['neg_lam1_predicts_converged']:.3f}", + ""] + (HERE / "results.md").write_text("\n".join(lines)) + print("wrote", HERE / "results.md") + + +if __name__ == "__main__": + main() |
