"""Offline follow-ups to the 2x2 analysis (no GPU):

1. Residual outcome signal within the unsettled stratum (HRM diag_8k primary,
   TRM official @58590 secondary): per-cell drift profiles over the 16 ACT
   steps, end-of-window drift slope, q_halt trajectories, halted_at, lambda
   spectra with a STRICT in-band threshold, and per-drift-decile
   AUC(lambda1 -> correct) within the unsettled stratum (does lambda1 add
   signal beyond drift level?).
2. Per-example profile of the strict-band settled-but-wrong examples
   (HRM, n~21).
3. Difficulty control: #givens per puzzle (input tokens != 1) joined via idx;
   lambda1 ~ givens rank correlation overall/within outcome, and
   per-givens-bin AUC(-lambda1 -> correct).

Observational only. Outputs to analysis_2x2/offline_followups/.
"""
from __future__ import annotations

from pathlib import Path

import numpy as np

HERE = Path(__file__).resolve().parent
FLOSS = HERE.parent
OUT = HERE / "offline_followups"
OUT.mkdir(exist_ok=True)
DATA_TEST_INPUTS = Path("/home/yurenh2/rrm/data/sudoku-extreme-1k-aug-1000/test/all__inputs.npy")
CELL_COLORS = {"A": "tab:green", "B": "tab:orange", "C": "tab:blue", "D": "tab:red"}

# Floor applied to mean drift before taking log10, so zero drift stays finite.
_DRIFT_FLOOR = 1e-12


def _pyplot():
    """Import matplotlib on first use, select the Agg backend, return pyplot.

    The import is deferred so that the numeric helpers in this module can be
    imported without matplotlib installed and without switching the caller's
    backend as an import side effect.
    """
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    return plt


def _average_ranks(values: np.ndarray) -> np.ndarray:
    """Return 1-based ranks of *values*; tied values share their mean rank.

    NOTE(review): NaN inputs are grouped however ``np.unique`` groups them;
    callers are expected to pass finite values.
    """
    flat = np.asarray(values, dtype=float).ravel()
    _, inverse, counts = np.unique(flat, return_inverse=True, return_counts=True)
    last = np.cumsum(counts)      # highest ordinal rank inside each tie group
    first = last - counts + 1     # lowest ordinal rank inside each tie group
    return ((first + last) / 2.0)[inverse]


def _log10_mean(window: np.ndarray) -> np.ndarray:
    """Return log10 of the row-wise mean of *window*, floored at _DRIFT_FLOOR."""
    return np.log10(np.clip(window.mean(1), _DRIFT_FLOOR, None))


def auc_rank(score: np.ndarray, label: np.ndarray) -> float:
    """Return the rank-based (Mann-Whitney) AUC of *score* for *label* == 1.

    Entries whose label is neither 0 nor 1 are ignored. Ties in *score* are
    given average ranks, so a fully tied score yields 0.5. Returns NaN when
    either class is empty.
    """
    score = np.asarray(score)
    label = np.asarray(label)
    pos, neg = score[label == 1], score[label == 0]
    n_pos, n_neg = len(pos), len(neg)
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    ranks = _average_ranks(np.concatenate([pos, neg]))
    u_stat = ranks[:n_pos].sum() - n_pos * (n_pos + 1) / 2
    return float(u_stat / (n_pos * n_neg))


def spearman(a: np.ndarray, b: np.ndarray) -> float:
    """Return the Spearman rank correlation of *a* and *b*.

    Ties receive average ranks; this matters here because the function is
    called on integer givens counts and on the binary correctness vector,
    where ordinal (argsort-of-argsort) ranks would depend on sort order.
    Returns NaN for fewer than two points or when either input is constant.

    Raises:
        ValueError: if *a* and *b* have different lengths.
    """
    a = np.asarray(a, dtype=float).ravel()
    b = np.asarray(b, dtype=float).ravel()
    if a.size != b.size:
        raise ValueError(f"length mismatch: {a.size} vs {b.size}")
    if a.size < 2:
        return float("nan")
    ra, rb = _average_ranks(a), _average_ranks(b)
    if ra.min() == ra.max() or rb.min() == rb.max():
        # Zero rank variance: the correlation is undefined.
        return float("nan")
    return float(np.corrcoef(ra, rb)[0, 1])


def load(npz_path: Path, strict_pct: float) -> dict:
    """Load an .npz archive and attach the derived 2x2 cell masks.

    Every array in the archive is copied into the returned dict, plus:
      - "logd_late": log10 of the mean of the last four "drift_zH" columns;
      - "tau_strict": the *strict_pct*-th percentile of "logd_late";
      - "cells": boolean masks A/B/C/D = (below tau?, exact_correct == 1?)
        with A = below & correct, B = below & wrong, C = not-below & correct,
        D = not-below & wrong.

    The archive must contain "drift_zH" (2-D) and "exact_correct".
    """
    # Arrays are materialised inside the ``with`` so the file handle is released.
    with np.load(npz_path) as archive:
        out = {key: archive[key] for key in archive.files}
    out["logd_late"] = _log10_mean(out["drift_zH"][:, -4:])
    out["tau_strict"] = float(np.percentile(out["logd_late"], strict_pct))
    conv = out["logd_late"] < out["tau_strict"]
    c = out["exact_correct"].astype(int)
    out["cells"] = {
        "A": conv & (c == 1),
        "B": conv & (c == 0),
        "C": (~conv) & (c == 1),
        "D": (~conv) & (c == 0),
    }
    return out


def givens_for(idx: np.ndarray, inputs_path: Path = DATA_TEST_INPUTS) -> np.ndarray:
    """Return, for each row index in *idx*, the count of input entries != 1.

    *inputs_path* is a .npy file opened memory-mapped and read-only; it
    defaults to the module-level DATA_TEST_INPUTS.
    """
    inputs = np.load(inputs_path, mmap_mode="r")
    rows = np.asarray(idx)
    if rows.size == 0:
        return np.zeros(0, dtype=np.intp)
    picked = inputs[rows]
    counts = (picked != 1).reshape(len(rows), -1).sum(axis=1)
    # Fancy-indexing a memmap can yield a memmap subclass; hand back a plain array.
    return np.asarray(counts)


def drift_profiles_fig(ds: dict, tag: str, lines: list[str]) -> np.ndarray:
    """Plot per-cell drift / q_halt profiles and report the end-of-window slope.

    Saves ``fig_{tag}_profiles.png`` under OUT, appends a markdown slope
    summary to *lines* (mutated in place), and returns the per-example slope.

    Cells with fewer than three members are left out of the figure but are
    still listed in the slope summary.
    # NOTE(review): assumes ds["q_halt"] has as many columns as ds["drift_zH"].
    """
    plt = _pyplot()
    drift = ds["drift_zH"]
    fig, axes = plt.subplots(1, 2, figsize=(11, 4))
    steps = np.arange(1, drift.shape[1] + 1)
    for nm, m in ds["cells"].items():
        if m.sum() < 3:
            continue
        colour = CELL_COLORS[nm]
        med = np.median(drift[m], axis=0)
        q1 = np.percentile(drift[m], 25, axis=0)
        q3 = np.percentile(drift[m], 75, axis=0)
        axes[0].plot(steps, med, "o-", ms=3, color=colour, label=f"{nm} (n={int(m.sum())})")
        axes[0].fill_between(steps, q1, q3, color=colour, alpha=0.15)
        qm = np.median(ds["q_halt"][m], axis=0)
        axes[1].plot(steps, qm, "o-", ms=3, color=colour, label=nm)
    axes[0].set_yscale("log")
    axes[0].set_xlabel("ACT step")
    axes[0].set_ylabel("drift_zH (median, IQR)")
    axes[0].legend(fontsize=8)
    axes[0].set_title(f"{tag}: drift profiles per cell")
    axes[1].set_xlabel("ACT step")
    axes[1].set_ylabel("q_halt (median)")
    axes[1].axhline(0, color="gray", lw=0.6)
    axes[1].set_title(f"{tag}: q_halt per cell")
    axes[1].legend(fontsize=8)
    fig.tight_layout()
    fig.savefig(OUT / f"fig_{tag}_profiles.png", dpi=150)
    plt.close(fig)

    # End-of-window slope: log10 mean drift over ACT steps 13-16 minus log10
    # mean drift over steps 9-12 (1-based steps, i.e. columns 12:16 and 8:12).
    # The fixed columns assume a 16-step window.
    slope = _log10_mean(drift[:, 12:16]) - _log10_mean(drift[:, 8:12])
    lines.append(f"\n### {tag}: end-of-window drift slope (log10 steps13-16 vs 9-12; <0 = still descending)")
    for nm, m in ds["cells"].items():
        if m.sum() == 0:
            lines.append(f"- {nm}: n=0")
            continue
        cell_slope = slope[m]
        lines.append(f"- {nm}: n={int(m.sum())}, slope median {np.median(cell_slope):+.4f}, "
                     f"IQR [{np.percentile(cell_slope, 25):+.4f}, {np.percentile(cell_slope, 75):+.4f}], "
                     f"frac still descending (<-0.01): {float((cell_slope < -0.01).mean()):.2f}")
    return slope


def main() -> None:
    """Run all follow-up analyses and write OUT/followups.md plus figures."""
    plt = _pyplot()
    lines = ["# Offline follow-ups (no GPU) — 2026-06-11", "",
             "Strict in-band thresholds: HRM pct45 of pooled log10 late-drift; TRM pct60 (band edge; B=0 regardless).",
             "All numbers observational; within-dataset comparisons only."]

    # ---------- HRM diag_8k ----------
    hrm = load(FLOSS / "diag_8k.npz", strict_pct=45)
    tag = "hrm26040_n8192_strict"
    lines.append(f"\n## HRM @26040 (n=8192), strict tau(log10)={hrm['tau_strict']:.4f}")
    lines.append("| cell | n | lam1 med | lam8 med | token_acc med | halted_at med | q_halt_final med | givens med |")
    lines.append("|---|---|---|---|---|---|---|---|")
    g_hrm = givens_for(hrm["idx"])
    lyap = hrm["lyap_spec"]
    for nm, m in hrm["cells"].items():
        if m.sum() == 0:
            lines.append(f"| {nm} | 0 | | | | | | |")
            continue
        lines.append(
            f"| {nm} | {int(m.sum())} | {np.median(lyap[m, 0]):+.4f} | {np.median(lyap[m, -1]):+.4f} "
            f"| {np.median(hrm['token_acc'][m]):.3f} | {np.median(hrm['halted_at'][m]):.0f} "
            f"| {np.median(hrm['q_halt'][m, -1]):+.2f} | {np.median(g_hrm[m]):.0f} |")
    slope = drift_profiles_fig(hrm, tag, lines)

    # residual signal within unsettled stratum: per-drift-decile AUC
    uns = ~(hrm["cells"]["A"] | hrm["cells"]["B"])
    c = hrm["exact_correct"].astype(int)
    lines.append("\n### HRM unsettled stratum: AUC(-lam1 -> correct) per log-drift decile")
    lines.append("| decile | drift range (log10) | n | n_correct | AUC |")
    lines.append("|---|---|---|---|---|")
    ld_u, l1_u, c_u = hrm["logd_late"][uns], lyap[uns, 0], c[uns]
    qs = np.percentile(ld_u, np.arange(0, 101, 10))
    aucs, ws = [], []
    for i in range(10):
        lo, hi = qs[i], qs[i + 1]
        # Half-open bins, except the last one which also includes its upper edge.
        upper = (ld_u <= hi) if i == 9 else (ld_u < hi)
        m = (ld_u >= lo) & upper
        a = auc_rank(-l1_u[m], c_u[m])
        n_correct = int(c_u[m].sum())
        # Only deciles with at least 5 correct examples enter the weighted mean.
        if not np.isnan(a) and n_correct >= 5:
            aucs.append(a)
            ws.append(m.sum())
        auc_txt = "n/a" if np.isnan(a) else f"{a:.3f}"
        lines.append(f"| {i+1} | [{lo:.2f}, {hi:.2f}] | {int(m.sum())} | {n_correct} | {auc_txt} |")
    if aucs:
        lines.append(f"- weighted mean within-decile AUC = {np.average(aucs, weights=ws):.3f} "
                     f"(vs unconditioned within-unsettled AUC {auc_rank(-l1_u, c_u):.3f})")
    # also: does end-slope separate C from D?
    lines.append(f"- AUC(-end_slope -> correct | unsettled) = {auc_rank(-slope[uns], c_u):.3f} "
                 f"(C still-descending fraction vs D, see slope table above)")

    # ---------- strict-B per-example table ----------
    B = hrm["cells"]["B"]
    lines.append(f"\n## HRM strict-band settled-but-wrong examples (n={int(B.sum())})")
    lines.append("| idx | givens | token_acc | lam1 | drift_final | halted_at | q_halt_final |")
    lines.append("|---|---|---|---|---|---|---|")
    bi = np.where(B)[0]
    order = np.argsort(hrm["token_acc"][bi])  # rows listed by ascending token accuracy
    for j in bi[order]:
        lines.append(
            f"| {int(hrm['idx'][j])} | {int(g_hrm[j])} | {hrm['token_acc'][j]:.3f} | {lyap[j, 0]:+.3f} "
            f"| {hrm['drift_zH'][j, -1]:.3f} | {int(hrm['halted_at'][j])} | {hrm['q_halt'][j, -1]:+.2f} |")

    # B drift profiles vs A band
    fig, ax = plt.subplots(figsize=(6.5, 4))
    # Step axis is taken from the data, as in drift_profiles_fig.
    steps = np.arange(1, hrm["drift_zH"].shape[1] + 1)
    A = hrm["cells"]["A"]
    if A.any():
        # Percentiles of an empty selection are undefined, so the band is skipped then.
        ax.fill_between(steps,
                        np.percentile(hrm["drift_zH"][A], 10, axis=0),
                        np.percentile(hrm["drift_zH"][A], 90, axis=0),
                        color="tab:green", alpha=0.2, label=f"A q10-q90 (n={int(A.sum())})")
    for j in bi:
        ax.plot(steps, hrm["drift_zH"][j], "-", lw=1, alpha=0.8, color="tab:orange")
    ax.set_yscale("log")
    ax.set_xlabel("ACT step")
    ax.set_ylabel("drift_zH")
    ax.set_title("HRM: strict-B drift profiles vs A band")
    ax.legend(fontsize=8)
    fig.tight_layout()
    fig.savefig(OUT / "fig_hrm_strictB_profiles.png", dpi=150)
    plt.close(fig)

    # ---------- difficulty control (HRM) ----------
    l1 = lyap[:, 0]
    lines.append("\n## HRM difficulty control (#givens, input tokens != 1)")
    lines.append(f"- givens: min {g_hrm.min()}, median {np.median(g_hrm):.0f}, max {g_hrm.max()}")
    lines.append(f"- Spearman(lam1, givens): overall {spearman(l1, g_hrm):+.3f}; "
                 f"correct-only {spearman(l1[c==1], g_hrm[c==1]):+.3f}; "
                 f"wrong-only {spearman(l1[c==0], g_hrm[c==0]):+.3f}")
    lines.append(f"- Spearman(correct, givens) = {spearman(c.astype(float), g_hrm):+.3f}")
    lines.append("\n| givens bin | n | acc | AUC(-lam1 -> correct) |")
    lines.append("|---|---|---|---|")
    # Quintile edges; np.unique merges duplicate edges, so there may be fewer than 5 bins.
    edges = np.unique(np.percentile(g_hrm, [0, 20, 40, 60, 80, 100]))
    bin_aucs, bin_ws = [], []
    last_bin = len(edges) - 2
    for i in range(len(edges) - 1):
        lo, hi = edges[i], edges[i + 1]
        upper = (g_hrm <= hi) if i == last_bin else (g_hrm < hi)
        m = (g_hrm >= lo) & upper
        a = auc_rank(-l1[m], c[m])
        lines.append(f"| [{lo:.0f}, {hi:.0f}] | {int(m.sum())} | {c[m].mean():.3f} | {a:.3f} |")
        if not np.isnan(a):
            bin_aucs.append(a)
            bin_ws.append(m.sum())
    # The overall AUC is computed from the loaded data rather than quoted as a constant.
    overall_auc = auc_rank(-l1, c)
    if bin_aucs:
        within_txt = f"{np.average(bin_aucs, weights=bin_ws):.3f}"
    else:
        within_txt = "n/a"  # no bin had both outcomes; np.average would raise on empty input
    lines.append(f"- weighted mean within-bin AUC = {within_txt} (overall {overall_auc:.3f})")

    # ---------- TRM official @58590 (secondary, n=512) ----------
    trm = load(FLOSS / "official_gbs768_spectrum/trm_gbs768_base_step58590_n512_k8_seed20260602.npz",
               strict_pct=60)
    g_trm = givens_for(trm["idx"])
    ct = trm["exact_correct"].astype(int)
    lines.append(f"\n## TRM official @58590 (n=512), strict tau(log10)={trm['tau_strict']:.4f}")
    lines.append("| cell | n | lam1 med | token_acc med | q_halt_final med | givens med |")
    lines.append("|---|---|---|---|---|---|")
    for nm, m in trm["cells"].items():
        if m.sum() == 0:
            lines.append(f"| {nm} | 0 | | | | |")
            continue
        lines.append(f"| {nm} | {int(m.sum())} | {np.median(trm['lyap_spec'][m, 0]):+.4f} "
                     f"| {np.median(trm['token_acc'][m]):.3f} | {np.median(trm['q_halt'][m, -1]):+.2f} "
                     f"| {np.median(g_trm[m]):.0f} |")
    drift_profiles_fig(trm, "trm_official58590_n512_strict", lines)
    l1t = trm["lyap_spec"][:, 0]
    lines.append(f"- Spearman(lam1, givens): overall {spearman(l1t, g_trm):+.3f}; "
                 f"wrong-only {spearman(l1t[ct==0], g_trm[ct==0]):+.3f}")
    lines.append(f"- Spearman(correct, givens) = {spearman(ct.astype(float), g_trm):+.3f}")

    # The report contains non-ASCII characters, so the encoding is pinned.
    (OUT / "followups.md").write_text("\n".join(lines), encoding="utf-8")
    print("\n".join(lines[:6]))
    print("wrote", OUT / "followups.md")


if __name__ == "__main__":
    main()