summaryrefslogtreecommitdiff
path: root/research/flossing/analysis_2x2/offline_followups.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/analysis_2x2/offline_followups.py')
-rw-r--r--research/flossing/analysis_2x2/offline_followups.py229
1 files changed, 229 insertions, 0 deletions
diff --git a/research/flossing/analysis_2x2/offline_followups.py b/research/flossing/analysis_2x2/offline_followups.py
new file mode 100644
index 0000000..8101acb
--- /dev/null
+++ b/research/flossing/analysis_2x2/offline_followups.py
@@ -0,0 +1,229 @@
+"""Offline follow-ups to the 2x2 analysis (no GPU):
+
+1. Residual outcome signal within the unsettled stratum (HRM diag_8k primary,
+ TRM official @58590 secondary): per-cell drift profiles over the 16 ACT steps,
+ end-of-window drift slope, q_halt trajectories, halted_at, lambda spectra with a
+ STRICT in-band threshold, and per-drift-decile AUC(-lambda1 -> correct) within the
+ unsettled stratum (does lambda1 add signal beyond drift level?).
+2. Per-example profile of the strict-band settled-but-wrong examples (HRM, n~21).
+3. Difficulty control: #givens per puzzle (input tokens != 1) joined via idx;
+ lambda1 ~ givens rank correlation overall/within outcome, and per-givens-bin
+ AUC(-lambda1 -> correct).
+
+Observational only. Outputs to analysis_2x2/offline_followups/.
+"""
+from __future__ import annotations
+
+from pathlib import Path
+
+import matplotlib
+
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import numpy as np
+
# Directory holding this script, and its parent directory (where the input .npz files are read from).
HERE = Path(__file__).resolve().parent
FLOSS = HERE.parent

# All figures and the markdown report are written here; created on import if missing.
OUT = HERE.joinpath("offline_followups")
OUT.mkdir(exist_ok=True)

# Test-split puzzle inputs, indexed by the "idx" array stored in each .npz.
DATA_TEST_INPUTS = Path("/home/yurenh2/rrm/data/sudoku-extreme-1k-aug-1000/test/all__inputs.npy")

# Plot colour for each of the four 2x2 cells (A-D).
CELL_COLORS = dict(A="tab:green", B="tab:orange", C="tab:blue", D="tab:red")
+
+
def auc_rank(score: np.ndarray, label: np.ndarray) -> float:
    """Return the rank-based AUC of *score* for separating label 1 from label 0.

    Scores are pooled and ranked (1-based); tied scores share the average of the
    ranks they occupy. The AUC is the Mann-Whitney U statistic of the label-1
    group divided by ``n_pos * n_neg``. Entries whose label is neither 0 nor 1
    are ignored.

    Returns NaN when either class is empty.
    """
    positives = score[label == 1]
    negatives = score[label == 0]
    n_pos, n_neg = len(positives), len(negatives)
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    # Positives first, so their ranks are the first n_pos entries of `ranks`.
    pooled = np.concatenate([positives, negatives])
    total = len(pooled)
    sort_idx = np.argsort(pooled, kind="mergesort")
    sorted_vals = pooled[sort_idx]

    # A tie group begins wherever a sorted value differs from its predecessor.
    starts_group = np.ones(total, dtype=bool)
    starts_group[1:] = sorted_vals[1:] != sorted_vals[:-1]
    group_start = np.flatnonzero(starts_group)
    group_stop = np.append(group_start[1:], total)  # exclusive end of each group

    # 1-based ranks (start + 1) .. stop average to (start + 1 + stop) / 2.
    group_rank = (group_start + 1 + group_stop) / 2.0
    ranks = np.empty(total)
    ranks[sort_idx] = np.repeat(group_rank, group_stop - group_start)

    u_stat = ranks[:n_pos].sum() - n_pos * (n_pos + 1) / 2
    return float(u_stat / (n_pos * n_neg))
+
+
def _average_ranks(values: np.ndarray) -> np.ndarray:
    """Return 1-based ranks of *values*, giving tied values their average rank."""
    flat = np.asarray(values).ravel()
    _, inverse, counts = np.unique(flat, return_inverse=True, return_counts=True)
    last_rank = np.cumsum(counts)  # rank of the last member of each tie group
    return (last_rank - (counts - 1) / 2.0)[np.ravel(inverse)]


def spearman(a: np.ndarray, b: np.ndarray) -> float:
    """Return the Spearman rank correlation between *a* and *b*.

    Ranks are computed with average-rank tie handling, so the result does not
    depend on the order in which tied values appear. This matters for
    discrete inputs such as integer counts or 0/1 outcome indicators.

    Returns NaN when fewer than two observations are given or when either
    input is constant (the correlation is undefined in both cases).
    """
    rank_a = _average_ranks(a)
    rank_b = _average_ranks(b)
    if rank_a.size < 2 or np.ptp(rank_a) == 0 or np.ptp(rank_b) == 0:
        return float("nan")
    return float(np.corrcoef(rank_a, rank_b)[0, 1])
+
+
def load(npz_path: Path, strict_pct: float) -> dict:
    """Load one analysis .npz and attach the strict-threshold 2x2 cell masks.

    Every array in the archive is copied into a plain dict, then three derived
    entries are added:

    - ``logd_late``: log10 of the mean of the last four columns of
      ``drift_zH`` per example (clipped below at 1e-12 before the log).
    - ``tau_strict``: the *strict_pct*-th percentile of ``logd_late``.
    - ``cells``: boolean masks keyed "A".."D", where "converged" means
      ``logd_late < tau_strict``:
      A = converged & correct, B = converged & wrong,
      C = not converged & correct, D = not converged & wrong.

    The archive must contain ``drift_zH`` and ``exact_correct``.
    """
    # Read everything eagerly and close the archive so no file handle is leaked.
    with np.load(npz_path) as archive:
        out = {key: archive[key] for key in archive.files}

    late_drift = out["drift_zH"][:, -4:].mean(1)
    out["logd_late"] = np.log10(np.clip(late_drift, 1e-12, None))
    out["tau_strict"] = float(np.percentile(out["logd_late"], strict_pct))

    conv = out["logd_late"] < out["tau_strict"]
    c = out["exact_correct"].astype(int)
    out["cells"] = {
        "A": conv & (c == 1),
        "B": conv & (c == 0),
        "C": (~conv) & (c == 1),
        "D": (~conv) & (c == 0),
    }
    return out
+
+
def givens_for(idx: np.ndarray) -> np.ndarray:
    """Return, for each index in *idx*, how many entries of that test input differ from 1.

    The inputs file at ``DATA_TEST_INPUTS`` is opened memory-mapped and
    read-only, so only the requested rows are touched.
    """
    puzzles = np.load(DATA_TEST_INPUTS, mmap_mode="r")

    def count_givens(row_index):
        return (puzzles[row_index] != 1).sum()

    return np.array(list(map(count_givens, idx)))
+
+
def drift_profiles_fig(ds, tag, lines):
    """Plot per-cell drift / q_halt profiles and append the end-of-window slope table.

    Args:
        ds: dict as produced by ``load``; reads ``drift_zH``, ``q_halt`` and
            ``cells`` (mapping cell name -> boolean mask over examples).
        tag: label used in the figure titles, the section heading and the
            output filename ``fig_{tag}_profiles.png`` (written under ``OUT``).
        lines: list of markdown lines; the slope section is appended in place.

    Returns:
        Per-example end-of-window slope: log10 of the mean drift over the last
        four steps minus log10 of the mean over the four steps before them
        (negative means drift is still descending).

    Raises:
        ValueError: if ``drift_zH`` has fewer than 8 steps, since the two
            four-step windows would be empty or overlap-truncated.
    """
    drift = ds["drift_zH"]
    n_steps = drift.shape[1]
    if n_steps < 8:
        raise ValueError(
            f"drift_zH has {n_steps} steps; need at least 8 for the end-of-window slope")
    steps = np.arange(1, n_steps + 1)

    fig, axes = plt.subplots(1, 2, figsize=(11, 4))
    for nm, m in ds["cells"].items():
        # Quartiles over fewer than three examples are not worth drawing.
        if m.sum() < 3:
            continue
        med = np.median(drift[m], axis=0)
        q1 = np.percentile(drift[m], 25, axis=0)
        q3 = np.percentile(drift[m], 75, axis=0)
        axes[0].plot(steps, med, "o-", ms=3, color=CELL_COLORS[nm], label=f"{nm} (n={int(m.sum())})")
        axes[0].fill_between(steps, q1, q3, color=CELL_COLORS[nm], alpha=0.15)
        qm = np.median(ds["q_halt"][m], axis=0)
        axes[1].plot(steps, qm, "o-", ms=3, color=CELL_COLORS[nm], label=nm)
    axes[0].set_yscale("log")
    axes[0].set_xlabel("ACT step")
    axes[0].set_ylabel("drift_zH (median, IQR)")
    axes[0].legend(fontsize=8)
    axes[0].set_title(f"{tag}: drift profiles per cell")
    axes[1].set_xlabel("ACT step")
    axes[1].set_ylabel("q_halt (median)")
    axes[1].axhline(0, color="gray", lw=0.6)
    axes[1].set_title(f"{tag}: q_halt per cell")
    axes[1].legend(fontsize=8)
    fig.tight_layout()
    fig.savefig(OUT / f"fig_{tag}_profiles.png", dpi=150)
    plt.close(fig)

    # End-of-window slope: last four steps vs the four before them. The windows
    # are taken relative to the end so they follow the actual number of steps
    # (for 16 steps these are steps 13-16 vs 9-12).
    late = np.log10(np.clip(drift[:, -4:].mean(1), 1e-12, None))
    earlier = np.log10(np.clip(drift[:, -8:-4].mean(1), 1e-12, None))
    slope = late - earlier
    lines.append(
        f"\n### {tag}: end-of-window drift slope "
        f"(log10 steps{n_steps - 3}-{n_steps} vs {n_steps - 7}-{n_steps - 4}; <0 = still descending)")
    for nm, m in ds["cells"].items():
        if m.sum() == 0:
            lines.append(f"- {nm}: n=0")
            continue
        lines.append(f"- {nm}: n={int(m.sum())}, slope median {np.median(slope[m]):+.4f}, "
                     f"IQR [{np.percentile(slope[m],25):+.4f}, {np.percentile(slope[m],75):+.4f}], "
                     f"frac still descending (<-0.01): {float((slope[m] < -0.01).mean()):.2f}")
    return slope
+
+
def _fmt_auc(value: float) -> str:
    """Format an AUC for a table cell; undefined (NaN) values become "n/a"."""
    return "n/a" if np.isnan(value) else f"{value:.3f}"


def _append_hrm_cell_table(hrm: dict, givens: np.ndarray, lines: list[str]) -> None:
    """Append the per-cell median summary table for the HRM dataset."""
    lines.append("| cell | n | lam1 med | lam8 med | token_acc med | halted_at med | q_halt_final med | givens med |")
    lines.append("|---|---|---|---|---|---|---|---|")
    for nm, m in hrm["cells"].items():
        if m.sum() == 0:
            lines.append(f"| {nm} | 0 | | | | | | |")
            continue
        lines.append(
            f"| {nm} | {int(m.sum())} | {np.median(hrm['lyap_spec'][m,0]):+.4f} | {np.median(hrm['lyap_spec'][m,-1]):+.4f} "
            f"| {np.median(hrm['token_acc'][m]):.3f} | {np.median(hrm['halted_at'][m]):.0f} "
            f"| {np.median(hrm['q_halt'][m,-1]):+.2f} | {np.median(givens[m]):.0f} |")


def _append_unsettled_decile_auc(hrm: dict, slope: np.ndarray, lines: list[str]) -> None:
    """Append AUC(-lam1 -> correct) per late-drift decile within the unsettled stratum (cells C and D)."""
    uns = ~(hrm["cells"]["A"] | hrm["cells"]["B"])
    c = hrm["exact_correct"].astype(int)
    lines.append("\n### HRM unsettled stratum: AUC(-lam1 -> correct) per log-drift decile")
    lines.append("| decile | drift range (log10) | n | n_correct | AUC |")
    lines.append("|---|---|---|---|---|")
    ld_u, l1_u, c_u = hrm["logd_late"][uns], hrm["lyap_spec"][uns, 0], c[uns]
    qs = np.percentile(ld_u, np.arange(0, 101, 10))
    aucs, ws = [], []
    for i in range(10):
        # Bins are half-open except the last, which includes its upper edge.
        upper = ld_u <= qs[i + 1] if i == 9 else ld_u < qs[i + 1]
        m = (ld_u >= qs[i]) & upper
        a = auc_rank(-l1_u[m], c_u[m])
        # Only bins with at least 5 correct examples enter the weighted mean.
        if not np.isnan(a) and c_u[m].sum() >= 5:
            aucs.append(a)
            ws.append(m.sum())
        lines.append(f"| {i+1} | [{qs[i]:.2f}, {qs[i+1]:.2f}] | {int(m.sum())} | {int(c_u[m].sum())} | "
                     f"{_fmt_auc(a)} |")
    if aucs:
        lines.append(f"- weighted mean within-decile AUC = {np.average(aucs, weights=ws):.3f} "
                     f"(vs unconditioned within-unsettled AUC {auc_rank(-l1_u, c_u):.3f})")

    # Does the end-of-window slope separate C from D?
    lines.append(f"- AUC(-end_slope -> correct | unsettled) = {auc_rank(-slope[uns], c_u):.3f} "
                 f"(C still-descending fraction vs D, see slope table above)")


def _append_strict_b_section(hrm: dict, givens: np.ndarray, lines: list[str]) -> None:
    """Append the per-example table for cell B and save its drift-profile figure."""
    B = hrm["cells"]["B"]
    lines.append(f"\n## HRM strict-band settled-but-wrong examples (n={int(B.sum())})")
    lines.append("| idx | givens | token_acc | lam1 | drift_final | halted_at | q_halt_final |")
    lines.append("|---|---|---|---|---|---|---|")
    bi = np.where(B)[0]
    order = np.argsort(hrm["token_acc"][bi])  # rows sorted by ascending token accuracy
    for j in bi[order]:
        lines.append(
            f"| {int(hrm['idx'][j])} | {int(givens[j])} | {hrm['token_acc'][j]:.3f} | {hrm['lyap_spec'][j,0]:+.3f} "
            f"| {hrm['drift_zH'][j,-1]:.3f} | {int(hrm['halted_at'][j])} | {hrm['q_halt'][j,-1]:+.2f} |")

    # Individual B drift profiles drawn over the q10-q90 band of cell A.
    fig, ax = plt.subplots(figsize=(6.5, 4))
    steps = np.arange(1, hrm["drift_zH"].shape[1] + 1)
    A = hrm["cells"]["A"]
    if A.any():
        ax.fill_between(steps, np.percentile(hrm["drift_zH"][A], 10, axis=0),
                        np.percentile(hrm["drift_zH"][A], 90, axis=0), color="tab:green", alpha=0.2,
                        label=f"A q10-q90 (n={int(A.sum())})")
    for j in bi:
        ax.plot(steps, hrm["drift_zH"][j], "-", lw=1, alpha=0.8, color="tab:orange")
    ax.set_yscale("log")
    ax.set_xlabel("ACT step")
    ax.set_ylabel("drift_zH")
    ax.set_title("HRM: strict-B drift profiles vs A band")
    if A.any():
        ax.legend(fontsize=8)
    fig.tight_layout()
    fig.savefig(OUT / "fig_hrm_strictB_profiles.png", dpi=150)
    plt.close(fig)


def _append_difficulty_control(hrm: dict, givens: np.ndarray, lines: list[str]) -> None:
    """Append the #givens difficulty control: rank correlations and per-givens-bin AUC."""
    l1 = hrm["lyap_spec"][:, 0]
    c = hrm["exact_correct"].astype(int)
    lines.append("\n## HRM difficulty control (#givens, input tokens != 1)")
    lines.append(f"- givens: min {givens.min()}, median {np.median(givens):.0f}, max {givens.max()}")
    lines.append(f"- Spearman(lam1, givens): overall {spearman(l1, givens):+.3f}; "
                 f"correct-only {spearman(l1[c==1], givens[c==1]):+.3f}; "
                 f"wrong-only {spearman(l1[c==0], givens[c==0]):+.3f}")
    lines.append(f"- Spearman(correct, givens) = {spearman(c.astype(float), givens):+.3f}")
    lines.append("\n| givens bin | n | acc | AUC(-lam1 -> correct) |")
    lines.append("|---|---|---|---|")
    # Quintile edges; np.unique drops duplicates, so there may be fewer than five bins.
    edges = np.unique(np.percentile(givens, [0, 20, 40, 60, 80, 100]))
    bin_aucs, bin_ws = [], []
    for i in range(len(edges) - 1):
        is_last = i == len(edges) - 2
        upper = givens <= edges[i + 1] if is_last else givens < edges[i + 1]
        m = (givens >= edges[i]) & upper
        n_bin = int(m.sum())
        a = auc_rank(-l1[m], c[m])
        acc = f"{c[m].mean():.3f}" if n_bin else "n/a"
        lines.append(f"| [{edges[i]:.0f}, {edges[i+1]:.0f}] | {n_bin} | {acc} | {_fmt_auc(a)} |")
        if not np.isnan(a):
            bin_aucs.append(a)
            bin_ws.append(m.sum())
    # The overall AUC is computed from the loaded data so it cannot go stale;
    # the weighted mean is skipped when no bin produced a defined AUC.
    if bin_aucs:
        lines.append(f"- weighted mean within-bin AUC = {np.average(bin_aucs, weights=bin_ws):.3f} "
                     f"(overall {_fmt_auc(auc_rank(-l1, c))})")


def _append_trm_section(lines: list[str]) -> None:
    """Append the secondary TRM official @58590 summary (cell table, profiles, givens correlations)."""
    trm = load(FLOSS / "official_gbs768_spectrum/trm_gbs768_base_step58590_n512_k8_seed20260602.npz", strict_pct=60)
    g_trm = givens_for(trm["idx"])
    ct = trm["exact_correct"].astype(int)
    lines.append(f"\n## TRM official @58590 (n=512), strict tau(log10)={trm['tau_strict']:.4f}")
    lines.append("| cell | n | lam1 med | token_acc med | q_halt_final med | givens med |")
    lines.append("|---|---|---|---|---|---|")
    for nm, m in trm["cells"].items():
        if m.sum() == 0:
            lines.append(f"| {nm} | 0 | | | | |")
            continue
        lines.append(f"| {nm} | {int(m.sum())} | {np.median(trm['lyap_spec'][m,0]):+.4f} "
                     f"| {np.median(trm['token_acc'][m]):.3f} | {np.median(trm['q_halt'][m,-1]):+.2f} "
                     f"| {np.median(g_trm[m]):.0f} |")
    drift_profiles_fig(trm, "trm_official58590_n512_strict", lines)
    l1t = trm["lyap_spec"][:, 0]
    lines.append(f"- Spearman(lam1, givens): overall {spearman(l1t, g_trm):+.3f}; "
                 f"wrong-only {spearman(l1t[ct==0], g_trm[ct==0]):+.3f}")
    lines.append(f"- Spearman(correct, givens) = {spearman(ct.astype(float), g_trm):+.3f}")


def main() -> None:
    """Run all offline follow-ups and write the markdown report plus figures under ``OUT``.

    Sections, in order: HRM diag_8k cell table and drift/q_halt profiles,
    per-drift-decile AUC within the unsettled stratum, the strict-B
    per-example table and figure, the #givens difficulty control, and the
    secondary TRM summary. The report is written to ``OUT / "followups.md"``
    and its first six lines are echoed to stdout.
    """
    lines = ["# Offline follow-ups (no GPU) — 2026-06-11",
             "",
             "Strict in-band thresholds: HRM pct45 of pooled log10 late-drift; TRM pct60 (band edge; B=0 regardless).",
             "All numbers observational; within-dataset comparisons only."]

    # ---------- HRM diag_8k ----------
    hrm = load(FLOSS / "diag_8k.npz", strict_pct=45)
    tag = "hrm26040_n8192_strict"
    lines.append(f"\n## HRM @26040 (n=8192), strict tau(log10)={hrm['tau_strict']:.4f}")
    g_hrm = givens_for(hrm["idx"])
    _append_hrm_cell_table(hrm, g_hrm, lines)

    slope = drift_profiles_fig(hrm, tag, lines)
    _append_unsettled_decile_auc(hrm, slope, lines)
    _append_strict_b_section(hrm, g_hrm, lines)
    _append_difficulty_control(hrm, g_hrm, lines)

    # ---------- TRM official @58590 (secondary, n=512) ----------
    _append_trm_section(lines)

    # The report contains non-ASCII characters, so pin the encoding rather than
    # relying on the platform default.
    (OUT / "followups.md").write_text("\n".join(lines), encoding="utf-8")
    print("\n".join(lines[:6]))
    print("wrote", OUT / "followups.md")


if __name__ == "__main__":
    main()