From 66e0d8b9fd4d0f7a2231d689c055e26fdf1cf04a Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Sat, 13 Jun 2026 12:35:36 -0500 Subject: rrm workspace: TRM/HRM/SRM code, Maze dataset, dynamical-analysis pipeline Curated export for clone-and-run Maze training (2x A6000) + diagnostics. trm/hrm pretrain.py carry trajectory-augmentation code (backward-compatible). Heavy artifacts (checkpoints/wandb/npz) gitignored; see PROVENANCE.md. Co-Authored-By: Claude Fable 5 --- .../flossing/analysis_2x2/offline_phase1_e1.py | 123 +++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 research/flossing/analysis_2x2/offline_phase1_e1.py (limited to 'research/flossing/analysis_2x2/offline_phase1_e1.py') diff --git a/research/flossing/analysis_2x2/offline_phase1_e1.py b/research/flossing/analysis_2x2/offline_phase1_e1.py new file mode 100644 index 0000000..7c42825 --- /dev/null +++ b/research/flossing/analysis_2x2/offline_phase1_e1.py @@ -0,0 +1,123 @@ +"""E1 offline batch (experiment_framework.md): +(1) bootstrap CIs for headline quantities; (2) settling-criterion robustness (z_L / combined); +(3) TRM official multi4 vs baseline matched-pipeline 2x2; (4) provenance note for hrm_multi4. +Outputs: offline_followups/phase1_e1.md +""" +from __future__ import annotations + +from pathlib import Path + +import numpy as np + +HERE = Path(__file__).resolve().parent +FLOSS = HERE.parent +OUT = HERE / "offline_followups" +RNG = np.random.default_rng(0) + + +def auc_rank(score, label): + pos, neg = score[label == 1], score[label == 0] + if len(pos) == 0 or len(neg) == 0: + return float("nan") + allv = np.concatenate([pos, neg]) + order = np.argsort(allv, kind="mergesort") + ranks = np.empty(len(allv)); ranks[order] = np.arange(1, len(allv) + 1) + sv = allv[order]; i = 0 + while i < len(sv): + j = i + while j + 1 < len(sv) and sv[j + 1] == sv[i]: + j += 1 + if j > i: + ranks[order[i:j + 1]] = ranks[order[i:j + 1]].mean() + i = j + 1 + return float((ranks[:len(pos)].sum() - len(pos) * (len(pos) + 1) / 2) / (len(pos) * len(neg))) + + +def boot_ci(stat_fn, n, B=10000): + vals = [] + for _ in range(B): + idx = RNG.integers(0, n, n) + v = stat_fn(idx) + if not np.isnan(v): + vals.append(v) + return float(np.percentile(vals, 2.5)), float(np.percentile(vals, 97.5)) + + +def late(d, key="drift_zH"): + return np.log10(np.clip(d[key][:, -4:].mean(1), 1e-12, None)) + + +def otsu(x, nbins=256): + h, e = np.histogram(x, bins=nbins); h = h.astype(float) + c = (e[:-1] + e[1:]) / 2 + p = h / h.sum(); om = np.cumsum(p); mu = np.cumsum(p * c); mt = mu[-1] + den = om * (1 - om); den[den <= 0] = np.nan + return float(c[np.nanargmax((mt * om - mu) ** 2 / den)]) + + +lines = ["# E1 offline batch — bootstrap CIs, settling robustness, TRM multi4 pair", ""] + +# ---------- (1) bootstrap CIs ---------- +trm = np.load(FLOSS / "analysis_2x2/retest/trm_gbs768_step58590_full_n2048.npz") +c_t = trm["exact_correct"].astype(int); l1_t = trm["lyap_spec"][:, 0]; ld_t = late(trm) +n_wrong = int((c_t == 0).sum()) +cp_upper = 1 - 0.05 ** (1 / n_wrong) # Clopper-Pearson upper for 0 events +lines += ["## Bootstrap / exact CIs (TRM official @58590, n=2048)", + f"- settled-wrong fraction: observed 0/{n_wrong}; exact 95% upper bound {cp_upper:.4f} " + f"({cp_upper*100:.2f}% of failures)", + f"- AUC(-lam1->correct) = {auc_rank(-l1_t, c_t):.4f}, bootstrap 95% CI " + f"{boot_ci(lambda i: auc_rank(-l1_t[i], c_t[i]), len(c_t))}", + f"- lam1(wrong) median 95% CI {boot_ci(lambda i: float(np.median(l1_t[i][c_t[i]==0])), len(c_t))}", + f"- lam1(correct) median 95% CI {boot_ci(lambda i: float(np.median(l1_t[i][c_t[i]==1])), len(c_t))}"] + +hrm = np.load(FLOSS / "diag_8k.npz") +c_h = hrm["exact_correct"].astype(int); l1_h = hrm["lyap_spec"][:, 0]; ld_h = late(hrm) +tau_strict = float(np.percentile(ld_h, 45)) +def strictB_frac(idx): + c, ld = c_h[idx], ld_h[idx] + w = (c == 0).sum() + return float(((ld < tau_strict) & (c == 0)).sum() / max(w, 1)) +lines += ["", "## Bootstrap CIs (HRM @26040, n=8192, strict band)", + f"- strict settled-wrong fraction of failures: observed {21/3894:.4f}, bootstrap 95% CI " + f"{boot_ci(strictB_frac, len(c_h))}", + f"- AUC(-lam1->correct) = {auc_rank(-l1_h, c_h):.4f}, bootstrap 95% CI " + f"{boot_ci(lambda i: auc_rank(-l1_h[i], c_h[i]), len(c_h))}"] + +# ---------- (2) settling-criterion robustness ---------- +lines += ["", "## Settling-criterion robustness (B-cell counts under alternative drift definitions)"] +for tag, d, c in [("TRM official n=2048", trm, c_t), ("HRM n=8192", hrm, c_h)]: + row = [tag] + for key, nm in [("drift_zH", "zH"), ("drift_zL", "zL")]: + ld = late(d, key); tau = otsu(ld); conv = ld < tau + row.append(f"{nm}: B={int((conv & (c==0)).sum())}/A={int((conv & (c==1)).sum())} (tau={tau:.2f})") + comb = np.log10(np.clip(np.sqrt(d["drift_zH"][:, -4:].mean(1) ** 2 + d["drift_zL"][:, -4:].mean(1) ** 2), 1e-12, None)) + tau = otsu(comb); conv = comb < tau + row.append(f"combined: B={int((conv & (c==0)).sum())}/A={int((conv & (c==1)).sum())} (tau={tau:.2f})") + lines.append("- " + " | ".join(row)) + +# ---------- (3) TRM official multi4 vs baseline (matched pipeline) ---------- +lines += ["", "## TRM official-pipeline multi4 vs baseline (matched objective, n=512 each)"] +spec = FLOSS / "official_gbs768_spectrum" +for nm, f in [("baseline @58590", spec / "trm_gbs768_base_step58590_n512_k8_seed20260602.npz"), + ("multi4 @35805 (best)", spec / "trm_gbs768_multi4_step35805_n512_k8_seed20260602.npz"), + ("multi4 @65100 (final)", spec / "trm_gbs768_multi4_step65100_n512_k8_seed20260602.npz")]: + if not f.exists(): + lines.append(f"- {nm}: npz missing"); continue + d = np.load(f); c = d["exact_correct"].astype(int); l1 = d["lyap_spec"][:, 0] + ld = late(d); tau = otsu(ld); conv = ld < tau + A = conv & (c == 1); B = conv & (c == 0); C = (~conv) & (c == 1); D = (~conv) & (c == 0) + lines.append(f"- {nm}: acc={c.mean():.3f}; A/B/C/D={int(A.sum())}/{int(B.sum())}/{int(C.sum())}/{int(D.sum())}; " + f"fD={float(D.mean()):.3f}; lam1(D)={np.median(l1[D]) if D.sum()>0 else float('nan'):+.4f}; " + f"lam1(A)={np.median(l1[A]) if A.sum()>0 else float('nan'):+.4f}") + +# ---------- (4) provenance note ---------- +lines += ["", "## hrm_multi4 provenance (E6a)", + "- diag_hrm_multi4_step_{20832,23436,26040}_512.npz step grid matches HRM pretrain numbering;", + " multi4_eval_compare/logs should contain the eval invocations — checked manually below.", + "- ACTION: if the hrm_multi4 run is pretrain-pipeline (ACT-streaming + perturbation), then the", + " May-28 multi4 vs righteous baseline comparison IS matched-pipeline and Sec 3.4's caveat is", + " narrower than written; step9 E-vs-F pair (queued) covers the fixed-unroll objective regardless."] + +OUT.mkdir(exist_ok=True) +(OUT / "phase1_e1.md").write_text("\n".join(lines)) +print("\n".join(lines)) -- cgit v1.2.3