summaryrefslogtreecommitdiff
path: root/research/flossing/analysis_2x2/offline_phase1_e1.py
diff options
context:
space:
mode:
Diffstat (limited to 'research/flossing/analysis_2x2/offline_phase1_e1.py')
-rw-r--r--research/flossing/analysis_2x2/offline_phase1_e1.py123
1 files changed, 123 insertions, 0 deletions
diff --git a/research/flossing/analysis_2x2/offline_phase1_e1.py b/research/flossing/analysis_2x2/offline_phase1_e1.py
new file mode 100644
index 0000000..7c42825
--- /dev/null
+++ b/research/flossing/analysis_2x2/offline_phase1_e1.py
@@ -0,0 +1,123 @@
+"""E1 offline batch (experiment_framework.md):
+(1) bootstrap CIs for headline quantities; (2) settling-criterion robustness (z_L / combined);
+(3) TRM official multi4 vs baseline matched-pipeline 2x2; (4) provenance note for hrm_multi4.
+Outputs: offline_followups/phase1_e1.md
+"""
+from __future__ import annotations
+
+from pathlib import Path
+
+import numpy as np
+
+HERE = Path(__file__).resolve().parent
+FLOSS = HERE.parent
+OUT = HERE / "offline_followups"
+RNG = np.random.default_rng(0)
+
+
+def auc_rank(score, label):
+ pos, neg = score[label == 1], score[label == 0]
+ if len(pos) == 0 or len(neg) == 0:
+ return float("nan")
+ allv = np.concatenate([pos, neg])
+ order = np.argsort(allv, kind="mergesort")
+ ranks = np.empty(len(allv)); ranks[order] = np.arange(1, len(allv) + 1)
+ sv = allv[order]; i = 0
+ while i < len(sv):
+ j = i
+ while j + 1 < len(sv) and sv[j + 1] == sv[i]:
+ j += 1
+ if j > i:
+ ranks[order[i:j + 1]] = ranks[order[i:j + 1]].mean()
+ i = j + 1
+ return float((ranks[:len(pos)].sum() - len(pos) * (len(pos) + 1) / 2) / (len(pos) * len(neg)))
+
+
+def boot_ci(stat_fn, n, B=10000):
+ vals = []
+ for _ in range(B):
+ idx = RNG.integers(0, n, n)
+ v = stat_fn(idx)
+ if not np.isnan(v):
+ vals.append(v)
+ return float(np.percentile(vals, 2.5)), float(np.percentile(vals, 97.5))
+
+
+def late(d, key="drift_zH"):
+ return np.log10(np.clip(d[key][:, -4:].mean(1), 1e-12, None))
+
+
+def otsu(x, nbins=256):
+ h, e = np.histogram(x, bins=nbins); h = h.astype(float)
+ c = (e[:-1] + e[1:]) / 2
+ p = h / h.sum(); om = np.cumsum(p); mu = np.cumsum(p * c); mt = mu[-1]
+ den = om * (1 - om); den[den <= 0] = np.nan
+ return float(c[np.nanargmax((mt * om - mu) ** 2 / den)])
+
+
+lines = ["# E1 offline batch — bootstrap CIs, settling robustness, TRM multi4 pair", ""]
+
+# ---------- (1) bootstrap CIs ----------
+trm = np.load(FLOSS / "analysis_2x2/retest/trm_gbs768_step58590_full_n2048.npz")
+c_t = trm["exact_correct"].astype(int); l1_t = trm["lyap_spec"][:, 0]; ld_t = late(trm)
+n_wrong = int((c_t == 0).sum())
+cp_upper = 1 - 0.05 ** (1 / n_wrong) # Clopper-Pearson upper for 0 events
+lines += ["## Bootstrap / exact CIs (TRM official @58590, n=2048)",
+ f"- settled-wrong fraction: observed 0/{n_wrong}; exact 95% upper bound {cp_upper:.4f} "
+ f"({cp_upper*100:.2f}% of failures)",
+ f"- AUC(-lam1->correct) = {auc_rank(-l1_t, c_t):.4f}, bootstrap 95% CI "
+ f"{boot_ci(lambda i: auc_rank(-l1_t[i], c_t[i]), len(c_t))}",
+ f"- lam1(wrong) median 95% CI {boot_ci(lambda i: float(np.median(l1_t[i][c_t[i]==0])), len(c_t))}",
+ f"- lam1(correct) median 95% CI {boot_ci(lambda i: float(np.median(l1_t[i][c_t[i]==1])), len(c_t))}"]
+
+hrm = np.load(FLOSS / "diag_8k.npz")
+c_h = hrm["exact_correct"].astype(int); l1_h = hrm["lyap_spec"][:, 0]; ld_h = late(hrm)
+tau_strict = float(np.percentile(ld_h, 45))
+def strictB_frac(idx):
+ c, ld = c_h[idx], ld_h[idx]
+ w = (c == 0).sum()
+ return float(((ld < tau_strict) & (c == 0)).sum() / max(w, 1))
+lines += ["", "## Bootstrap CIs (HRM @26040, n=8192, strict band)",
+ f"- strict settled-wrong fraction of failures: observed {21/3894:.4f}, bootstrap 95% CI "
+ f"{boot_ci(strictB_frac, len(c_h))}",
+ f"- AUC(-lam1->correct) = {auc_rank(-l1_h, c_h):.4f}, bootstrap 95% CI "
+ f"{boot_ci(lambda i: auc_rank(-l1_h[i], c_h[i]), len(c_h))}"]
+
+# ---------- (2) settling-criterion robustness ----------
+lines += ["", "## Settling-criterion robustness (B-cell counts under alternative drift definitions)"]
+for tag, d, c in [("TRM official n=2048", trm, c_t), ("HRM n=8192", hrm, c_h)]:
+ row = [tag]
+ for key, nm in [("drift_zH", "zH"), ("drift_zL", "zL")]:
+ ld = late(d, key); tau = otsu(ld); conv = ld < tau
+ row.append(f"{nm}: B={int((conv & (c==0)).sum())}/A={int((conv & (c==1)).sum())} (tau={tau:.2f})")
+ comb = np.log10(np.clip(np.sqrt(d["drift_zH"][:, -4:].mean(1) ** 2 + d["drift_zL"][:, -4:].mean(1) ** 2), 1e-12, None))
+ tau = otsu(comb); conv = comb < tau
+ row.append(f"combined: B={int((conv & (c==0)).sum())}/A={int((conv & (c==1)).sum())} (tau={tau:.2f})")
+ lines.append("- " + " | ".join(row))
+
+# ---------- (3) TRM official multi4 vs baseline (matched pipeline) ----------
+lines += ["", "## TRM official-pipeline multi4 vs baseline (matched objective, n=512 each)"]
+spec = FLOSS / "official_gbs768_spectrum"
+for nm, f in [("baseline @58590", spec / "trm_gbs768_base_step58590_n512_k8_seed20260602.npz"),
+ ("multi4 @35805 (best)", spec / "trm_gbs768_multi4_step35805_n512_k8_seed20260602.npz"),
+ ("multi4 @65100 (final)", spec / "trm_gbs768_multi4_step65100_n512_k8_seed20260602.npz")]:
+ if not f.exists():
+ lines.append(f"- {nm}: npz missing"); continue
+ d = np.load(f); c = d["exact_correct"].astype(int); l1 = d["lyap_spec"][:, 0]
+ ld = late(d); tau = otsu(ld); conv = ld < tau
+ A = conv & (c == 1); B = conv & (c == 0); C = (~conv) & (c == 1); D = (~conv) & (c == 0)
+ lines.append(f"- {nm}: acc={c.mean():.3f}; A/B/C/D={int(A.sum())}/{int(B.sum())}/{int(C.sum())}/{int(D.sum())}; "
+ f"fD={float(D.mean()):.3f}; lam1(D)={np.median(l1[D]) if D.sum()>0 else float('nan'):+.4f}; "
+ f"lam1(A)={np.median(l1[A]) if A.sum()>0 else float('nan'):+.4f}")
+
+# ---------- (4) provenance note ----------
+lines += ["", "## hrm_multi4 provenance (E6a)",
+ "- diag_hrm_multi4_step_{20832,23436,26040}_512.npz step grid matches HRM pretrain numbering;",
+ " multi4_eval_compare/logs should contain the eval invocations — checked manually below.",
+ "- ACTION: if the hrm_multi4 run is pretrain-pipeline (ACT-streaming + perturbation), then the",
+ " May-28 multi4 vs righteous baseline comparison IS matched-pipeline and Sec 3.4's caveat is",
+ " narrower than written; step9 E-vs-F pair (queued) covers the fixed-unroll objective regardless."]
+
+OUT.mkdir(exist_ok=True)
+(OUT / "phase1_e1.md").write_text("\n".join(lines))
+print("\n".join(lines))