diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-06-13 12:35:36 -0500 |
| commit | 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a (patch) | |
| tree | c29cba61124018755a19b02c9d33e3ad5f2e05cc /research/flossing/analysis_2x2/offline_followups.py | |
Curated export for clone-and-run Maze training (2x A6000) + diagnostics.
trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible).
Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Diffstat (limited to 'research/flossing/analysis_2x2/offline_followups.py')
| -rw-r--r-- | research/flossing/analysis_2x2/offline_followups.py | 229 |
1 files changed, 229 insertions, 0 deletions
diff --git a/research/flossing/analysis_2x2/offline_followups.py b/research/flossing/analysis_2x2/offline_followups.py new file mode 100644 index 0000000..8101acb --- /dev/null +++ b/research/flossing/analysis_2x2/offline_followups.py @@ -0,0 +1,229 @@ +"""Offline follow-ups to the 2x2 analysis (no GPU): + +1. Residual outcome signal within the unsettled stratum (HRM diag_8k primary, + TRM official @58590 secondary): per-cell drift profiles over the 16 ACT steps, + end-of-window drift slope, q_halt trajectories, halted_at, lambda spectra with a + STRICT in-band threshold, and per-drift-decile AUC(lambda1 -> correct) within the + unsettled stratum (does lambda1 add signal beyond drift level?). +2. Per-example profile of the strict-band settled-but-wrong examples (HRM, n~21). +3. Difficulty control: #givens per puzzle (input tokens != 1) joined via idx; + lambda1 ~ givens rank correlation overall/within outcome, and per-givens-bin + AUC(-lambda1 -> correct). + +Observational only. Outputs to analysis_2x2/offline_followups/. +""" +from __future__ import annotations + +from pathlib import Path + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np + +HERE = Path(__file__).resolve().parent +FLOSS = HERE.parent +OUT = HERE / "offline_followups" +OUT.mkdir(exist_ok=True) + +DATA_TEST_INPUTS = Path("/home/yurenh2/rrm/data/sudoku-extreme-1k-aug-1000/test/all__inputs.npy") + +CELL_COLORS = {"A": "tab:green", "B": "tab:orange", "C": "tab:blue", "D": "tab:red"} + + +def auc_rank(score: np.ndarray, label: np.ndarray) -> float: + pos, neg = score[label == 1], score[label == 0] + if len(pos) == 0 or len(neg) == 0: + return float("nan") + allv = np.concatenate([pos, neg]) + order = np.argsort(allv, kind="mergesort") + ranks = np.empty(len(allv)); ranks[order] = np.arange(1, len(allv) + 1) + sv = allv[order]; i = 0 + while i < len(sv): + j = i + while j + 1 < len(sv) and sv[j + 1] == sv[i]: + j += 1 + if j > i: + ranks[order[i:j + 1]] = ranks[order[i:j + 1]].mean() + i = j + 1 + return float((ranks[:len(pos)].sum() - len(pos) * (len(pos) + 1) / 2) / (len(pos) * len(neg))) + + +def spearman(a: np.ndarray, b: np.ndarray) -> float: + ra = np.argsort(np.argsort(a)).astype(float) + rb = np.argsort(np.argsort(b)).astype(float) + return float(np.corrcoef(ra, rb)[0, 1]) + + +def load(npz_path: Path, strict_pct: float): + d = np.load(npz_path) + out = {k: d[k] for k in d.files} + out["logd_late"] = np.log10(np.clip(out["drift_zH"][:, -4:].mean(1), 1e-12, None)) + out["tau_strict"] = float(np.percentile(out["logd_late"], strict_pct)) + conv = out["logd_late"] < out["tau_strict"] + c = out["exact_correct"].astype(int) + out["cells"] = { + "A": conv & (c == 1), "B": conv & (c == 0), + "C": (~conv) & (c == 1), "D": (~conv) & (c == 0), + } + return out + + +def givens_for(idx: np.ndarray) -> np.ndarray: + inputs = np.load(DATA_TEST_INPUTS, mmap_mode="r") + return np.array([(inputs[i] != 1).sum() for i in idx]) + + +def drift_profiles_fig(ds, tag, lines): + fig, axes = plt.subplots(1, 2, figsize=(11, 4)) + steps = np.arange(1, ds["drift_zH"].shape[1] + 1) + for nm, m in ds["cells"].items(): + if m.sum() < 3: + continue + med = np.median(ds["drift_zH"][m], axis=0) + q1 = np.percentile(ds["drift_zH"][m], 25, axis=0) + q3 = np.percentile(ds["drift_zH"][m], 75, axis=0) + axes[0].plot(steps, med, "o-", ms=3, color=CELL_COLORS[nm], label=f"{nm} (n={int(m.sum())})") + axes[0].fill_between(steps, q1, q3, color=CELL_COLORS[nm], alpha=0.15) + qm = np.median(ds["q_halt"][m], axis=0) + axes[1].plot(steps, qm, "o-", ms=3, color=CELL_COLORS[nm], label=nm) + axes[0].set_yscale("log"); axes[0].set_xlabel("ACT step"); axes[0].set_ylabel("drift_zH (median, IQR)") + axes[0].legend(fontsize=8); axes[0].set_title(f"{tag}: drift profiles per cell") + axes[1].set_xlabel("ACT step"); axes[1].set_ylabel("q_halt (median)"); axes[1].axhline(0, color="gray", lw=0.6) + axes[1].set_title(f"{tag}: q_halt per cell"); axes[1].legend(fontsize=8) + fig.tight_layout(); fig.savefig(OUT / f"fig_{tag}_profiles.png", dpi=150); plt.close(fig) + + # end-of-window slope: log10 mean(drift[13:16]) - log10 mean(drift[9:12]) + slope = (np.log10(np.clip(ds["drift_zH"][:, 12:16].mean(1), 1e-12, None)) + - np.log10(np.clip(ds["drift_zH"][:, 8:12].mean(1), 1e-12, None))) + lines.append(f"\n### {tag}: end-of-window drift slope (log10 steps13-16 vs 9-12; <0 = still descending)") + for nm, m in ds["cells"].items(): + if m.sum() == 0: + lines.append(f"- {nm}: n=0") + continue + lines.append(f"- {nm}: n={int(m.sum())}, slope median {np.median(slope[m]):+.4f}, " + f"IQR [{np.percentile(slope[m],25):+.4f}, {np.percentile(slope[m],75):+.4f}], " + f"frac still descending (<-0.01): {float((slope[m] < -0.01).mean()):.2f}") + return slope + + +def main() -> None: + lines = ["# Offline follow-ups (no GPU) — 2026-06-11", + "", + "Strict in-band thresholds: HRM pct45 of pooled log10 late-drift; TRM pct60 (band edge; B=0 regardless).", + "All numbers observational; within-dataset comparisons only."] + + # ---------- HRM diag_8k ---------- + hrm = load(FLOSS / "diag_8k.npz", strict_pct=45) + tag = "hrm26040_n8192_strict" + lines.append(f"\n## HRM @26040 (n=8192), strict tau(log10)={hrm['tau_strict']:.4f}") + lines.append("| cell | n | lam1 med | lam8 med | token_acc med | halted_at med | q_halt_final med | givens med |") + lines.append("|---|---|---|---|---|---|---|---|") + g_hrm = givens_for(hrm["idx"]) + for nm, m in hrm["cells"].items(): + if m.sum() == 0: + lines.append(f"| {nm} | 0 | | | | | | |") + continue + lines.append( + f"| {nm} | {int(m.sum())} | {np.median(hrm['lyap_spec'][m,0]):+.4f} | {np.median(hrm['lyap_spec'][m,-1]):+.4f} " + f"| {np.median(hrm['token_acc'][m]):.3f} | {np.median(hrm['halted_at'][m]):.0f} " + f"| {np.median(hrm['q_halt'][m,-1]):+.2f} | {np.median(g_hrm[m]):.0f} |") + + slope = drift_profiles_fig(hrm, tag, lines) + + # residual signal within unsettled stratum: per-drift-decile AUC + uns = ~(hrm["cells"]["A"] | hrm["cells"]["B"]) + c = hrm["exact_correct"].astype(int) + lines.append("\n### HRM unsettled stratum: AUC(-lam1 -> correct) per log-drift decile") + lines.append("| decile | drift range (log10) | n | n_correct | AUC |") + lines.append("|---|---|---|---|---|") + ld_u, l1_u, c_u = hrm["logd_late"][uns], hrm["lyap_spec"][uns, 0], c[uns] + qs = np.percentile(ld_u, np.arange(0, 101, 10)) + aucs, ws = [], [] + for i in range(10): + m = (ld_u >= qs[i]) & (ld_u <= qs[i + 1] if i == 9 else ld_u < qs[i + 1]) + a = auc_rank(-l1_u[m], c_u[m]) + if not np.isnan(a) and c_u[m].sum() >= 5: + aucs.append(a); ws.append(m.sum()) + lines.append(f"| {i+1} | [{qs[i]:.2f}, {qs[i+1]:.2f}] | {int(m.sum())} | {int(c_u[m].sum())} | " + f"{a:.3f} |" if not np.isnan(a) else f"| {i+1} | [{qs[i]:.2f}, {qs[i+1]:.2f}] | {int(m.sum())} | {int(c_u[m].sum())} | n/a |") + if aucs: + lines.append(f"- weighted mean within-decile AUC = {np.average(aucs, weights=ws):.3f} " + f"(vs unconditioned within-unsettled AUC {auc_rank(-l1_u, c_u):.3f})") + + # also: does end-slope separate C from D? + lines.append(f"- AUC(-end_slope -> correct | unsettled) = {auc_rank(-slope[uns], c_u):.3f} " + f"(C still-descending fraction vs D, see slope table above)") + + # ---------- strict-B per-example table ---------- + B = hrm["cells"]["B"] + lines.append(f"\n## HRM strict-band settled-but-wrong examples (n={int(B.sum())})") + lines.append("| idx | givens | token_acc | lam1 | drift_final | halted_at | q_halt_final |") + lines.append("|---|---|---|---|---|---|---|") + bi = np.where(B)[0] + order = np.argsort(hrm["token_acc"][bi]) + for j in bi[order]: + lines.append( + f"| {int(hrm['idx'][j])} | {int(g_hrm[j])} | {hrm['token_acc'][j]:.3f} | {hrm['lyap_spec'][j,0]:+.3f} " + f"| {hrm['drift_zH'][j,-1]:.3f} | {int(hrm['halted_at'][j])} | {hrm['q_halt'][j,-1]:+.2f} |") + # B drift profiles vs A band + fig, ax = plt.subplots(figsize=(6.5, 4)) + steps = np.arange(1, 17) + A = hrm["cells"]["A"] + ax.fill_between(steps, np.percentile(hrm["drift_zH"][A], 10, axis=0), + np.percentile(hrm["drift_zH"][A], 90, axis=0), color="tab:green", alpha=0.2, + label=f"A q10-q90 (n={int(A.sum())})") + for j in bi: + ax.plot(steps, hrm["drift_zH"][j], "-", lw=1, alpha=0.8, color="tab:orange") + ax.set_yscale("log"); ax.set_xlabel("ACT step"); ax.set_ylabel("drift_zH") + ax.set_title("HRM: strict-B drift profiles vs A band"); ax.legend(fontsize=8) + fig.tight_layout(); fig.savefig(OUT / "fig_hrm_strictB_profiles.png", dpi=150); plt.close(fig) + + # ---------- difficulty control (HRM) ---------- + l1 = hrm["lyap_spec"][:, 0] + lines.append("\n## HRM difficulty control (#givens, input tokens != 1)") + lines.append(f"- givens: min {g_hrm.min()}, median {np.median(g_hrm):.0f}, max {g_hrm.max()}") + lines.append(f"- Spearman(lam1, givens): overall {spearman(l1, g_hrm):+.3f}; " + f"correct-only {spearman(l1[c==1], g_hrm[c==1]):+.3f}; " + f"wrong-only {spearman(l1[c==0], g_hrm[c==0]):+.3f}") + lines.append(f"- Spearman(correct, givens) = {spearman(c.astype(float), g_hrm):+.3f}") + lines.append("\n| givens bin | n | acc | AUC(-lam1 -> correct) |") + lines.append("|---|---|---|---|") + edges = np.unique(np.percentile(g_hrm, [0, 20, 40, 60, 80, 100])) + bin_aucs, bin_ws = [], [] + for i in range(len(edges) - 1): + m = (g_hrm >= edges[i]) & (g_hrm <= edges[i + 1] if i == len(edges) - 2 else g_hrm < edges[i + 1]) + a = auc_rank(-l1[m], c[m]) + lines.append(f"| [{edges[i]:.0f}, {edges[i+1]:.0f}] | {int(m.sum())} | {c[m].mean():.3f} | {a:.3f} |") + if not np.isnan(a): + bin_aucs.append(a); bin_ws.append(m.sum()) + lines.append(f"- weighted mean within-bin AUC = {np.average(bin_aucs, weights=bin_ws):.3f} (overall 0.984)") + + # ---------- TRM official @58590 (secondary, n=512) ---------- + trm = load(FLOSS / "official_gbs768_spectrum/trm_gbs768_base_step58590_n512_k8_seed20260602.npz", strict_pct=60) + g_trm = givens_for(trm["idx"]) + ct = trm["exact_correct"].astype(int) + lines.append(f"\n## TRM official @58590 (n=512), strict tau(log10)={trm['tau_strict']:.4f}") + lines.append("| cell | n | lam1 med | token_acc med | q_halt_final med | givens med |") + lines.append("|---|---|---|---|---|---|") + for nm, m in trm["cells"].items(): + if m.sum() == 0: + lines.append(f"| {nm} | 0 | | | | |") + continue + lines.append(f"| {nm} | {int(m.sum())} | {np.median(trm['lyap_spec'][m,0]):+.4f} " + f"| {np.median(trm['token_acc'][m]):.3f} | {np.median(trm['q_halt'][m,-1]):+.2f} " + f"| {np.median(g_trm[m]):.0f} |") + drift_profiles_fig(trm, "trm_official58590_n512_strict", lines) + l1t = trm["lyap_spec"][:, 0] + lines.append(f"- Spearman(lam1, givens): overall {spearman(l1t, g_trm):+.3f}; " + f"wrong-only {spearman(l1t[ct==0], g_trm[ct==0]):+.3f}") + lines.append(f"- Spearman(correct, givens) = {spearman(ct.astype(float), g_trm):+.3f}") + + (OUT / "followups.md").write_text("\n".join(lines)) + print("\n".join(lines[:6])) + print("wrote", OUT / "followups.md") + + +if __name__ == "__main__": + main() |
