1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
|
"""E1 offline batch (experiment_framework.md):
(1) bootstrap CIs for headline quantities; (2) settling-criterion robustness (z_L / combined);
(3) TRM official multi4 vs baseline matched-pipeline 2x2; (4) provenance note for hrm_multi4.
Outputs: offline_followups/phase1_e1.md
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
# Directory that contains this script.
HERE = Path(__file__).resolve().parent
# Parent of the script directory; the .npz inputs loaded below are resolved
# relative to it.
FLOSS = HERE.parent
# Output directory for the generated Markdown report (phase1_e1.md).
OUT = HERE / "offline_followups"
# Shared NumPy Generator with fixed seed 0; boot_ci draws its resampling
# indices from it.
RNG = np.random.default_rng(0)
def auc_rank(score, label):
    """Return the ROC AUC of ``score`` for predicting ``label == 1``.

    The AUC is computed as the normalised Mann-Whitney U statistic from
    average (tie-corrected) ranks, i.e. the probability that a randomly
    chosen positive outscores a randomly chosen negative, with ties
    counted as one half.

    Args:
        score: 1-D array-like of real-valued scores.
        label: 1-D array-like aligned with ``score``. Entries equal to 1
            are positives, entries equal to 0 are negatives; any other
            value is ignored.

    Returns:
        The AUC as a float, or NaN when either class is empty.
    """
    score = np.asarray(score)
    label = np.asarray(label)
    pos, neg = score[label == 1], score[label == 0]
    n_pos, n_neg = len(pos), len(neg)
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    # Positives are placed first so their ranks occupy the first n_pos slots.
    allv = np.concatenate([pos, neg])
    order = np.argsort(allv, kind="mergesort")
    sv = allv[order]

    # Vectorised tie handling: this function runs once per bootstrap
    # replicate, where resampling with replacement makes ties the norm, so a
    # Python-level scan over the sorted values dominates the runtime.
    # A run of equal values at sorted positions [start, stop) holds the
    # 1-based ranks start+1 .. stop, whose mean is (start + 1 + stop) / 2.
    is_start = np.ones(len(sv), dtype=bool)
    is_start[1:] = sv[1:] != sv[:-1]
    starts = np.flatnonzero(is_start)
    stops = np.append(starts[1:], len(sv))
    avg_rank = (starts + 1 + stops) / 2.0

    ranks = np.empty(len(sv))
    ranks[order] = avg_rank[np.cumsum(is_start) - 1]

    u_stat = ranks[:n_pos].sum() - n_pos * (n_pos + 1) / 2
    return float(u_stat / (n_pos * n_neg))
def boot_ci(stat_fn, n, B=10000, rng=None):
    """Return a percentile-bootstrap 95% confidence interval.

    Args:
        stat_fn: Callable that receives an integer index array of length
            ``n`` (a resample drawn with replacement from ``range(n)``) and
            returns a scalar statistic. NaN results are discarded.
        n: Number of observations to resample.
        B: Number of bootstrap replicates.
        rng: Optional ``numpy.random.Generator`` to draw the resamples from.
            Defaults to the module-level ``RNG`` so existing callers keep
            consuming the shared seeded stream.

    Returns:
        ``(lower, upper)``: the 2.5th and 97.5th percentiles of the non-NaN
        replicates, or ``(nan, nan)`` when no replicate produced a usable
        value (every replicate was NaN, or ``B`` is 0).
    """
    if rng is None:
        rng = RNG
    replicates = (stat_fn(rng.integers(0, n, n)) for _ in range(B))
    vals = [v for v in replicates if not np.isnan(v)]
    if not vals:
        # A percentile of an empty sample is undefined; report that explicitly
        # instead of letting np.percentile fail on the empty list.
        return float("nan"), float("nan")
    lower, upper = np.percentile(vals, [2.5, 97.5])
    return float(lower), float(upper)
def late(d, key="drift_zH", window=4, floor=1e-12):
    """Return log10 of the mean of the last ``window`` columns of ``d[key]``.

    Args:
        d: Mapping (e.g. a loaded ``.npz`` archive or a dict) whose entry
            ``d[key]`` is a 2-D array. Presumably rows are samples and
            columns are successive steps -- confirm against the producer
            of the file.
        key: Name of the array to summarise.
        window: Number of trailing columns to average (must be >= 1;
            ``-0:`` would select every column).
        floor: Lower clip applied to the mean before taking the logarithm,
            so zero drift maps to ``log10(floor)`` instead of ``-inf``.

    Returns:
        1-D array with one log10 value per row of ``d[key]``.
    """
    tail_mean = d[key][:, -window:].mean(1)
    return np.log10(np.clip(tail_mean, floor, None))
def otsu(x, nbins=256):
    """Return Otsu's threshold for the sample ``x``.

    The sample is histogrammed into ``nbins`` equal-width bins and the
    returned value is the centre of the bin that maximises the
    between-class variance of the split "bins up to and including this
    one" versus "all later bins".

    Args:
        x: Array-like of finite values (``np.histogram`` flattens it).
        nbins: Number of histogram bins.

    Returns:
        The threshold as a float, or NaN when no split separates two
        non-empty classes (empty input, or all values in a single bin).
    """
    counts, edges = np.histogram(x, bins=nbins)
    counts = counts.astype(float)
    total = counts.sum()
    if total == 0:
        # Empty input: there is no distribution to split.
        return float("nan")

    centers = (edges[:-1] + edges[1:]) / 2
    prob = counts / total
    omega = np.cumsum(prob)            # weight of the lower class per split
    mu = np.cumsum(prob * centers)     # unnormalised mean of the lower class
    mu_total = mu[-1]

    # Splits that leave one class empty have zero denominator; mask them out.
    denom = omega * (1 - omega)
    denom[denom <= 0] = np.nan
    between = (mu_total * omega - mu) ** 2 / denom
    if np.all(np.isnan(between)):
        # Degenerate sample (e.g. constant input): np.nanargmax would raise
        # "All-NaN slice encountered", so report an undefined threshold.
        return float("nan")
    return float(centers[np.nanargmax(between)])
# Accumulates the Markdown report, one entry per output line.
lines = ["# E1 offline batch — bootstrap CIs, settling robustness, TRM multi4 pair", ""]

# ---------- (1) bootstrap CIs ----------
trm = np.load(FLOSS / "analysis_2x2/retest/trm_gbs768_step58590_full_n2048.npz")
c_t = trm["exact_correct"].astype(int)
l1_t = trm["lyap_spec"][:, 0]
ld_t = late(trm)  # NOTE(review): not used by the visible code; kept as-is.
n_wrong = int((c_t == 0).sum())
# One-sided exact (Clopper-Pearson) 95% upper bound for 0 events in n_wrong
# trials: solve (1 - p) ** n_wrong = 0.05 for p. With no failures at all the
# data carry no information, so fall back to the vacuous bound 1.0 instead of
# dividing by zero.
cp_upper = 1 - 0.05 ** (1 / n_wrong) if n_wrong > 0 else 1.0
# NOTE(review): the "observed 0" below is a literal, not computed from the
# data -- confirm it still holds for this file.
lines += ["## Bootstrap / exact CIs (TRM official @58590, n=2048)",
          f"- settled-wrong fraction: observed 0/{n_wrong}; exact 95% upper bound {cp_upper:.4f} "
          f"({cp_upper*100:.2f}% of failures)",
          f"- AUC(-lam1->correct) = {auc_rank(-l1_t, c_t):.4f}, bootstrap 95% CI "
          f"{boot_ci(lambda i: auc_rank(-l1_t[i], c_t[i]), len(c_t))}",
          f"- lam1(wrong) median 95% CI {boot_ci(lambda i: float(np.median(l1_t[i][c_t[i]==0])), len(c_t))}",
          f"- lam1(correct) median 95% CI {boot_ci(lambda i: float(np.median(l1_t[i][c_t[i]==1])), len(c_t))}"]
hrm = np.load(FLOSS / "diag_8k.npz")
c_h = hrm["exact_correct"].astype(int)
l1_h = hrm["lyap_spec"][:, 0]
ld_h = late(hrm)
# "Strict" settling threshold: the 45th percentile of the late log-drift.
tau_strict = float(np.percentile(ld_h, 45))


def strictB_frac(idx):
    """Return the fraction of failures that fall below ``tau_strict``.

    Args:
        idx: Integer index array selecting the (re)sample from the HRM data.

    Returns:
        (# wrong samples with late log-drift < tau_strict) / (# wrong
        samples), with the denominator floored at 1 so a resample without
        failures yields 0.0 instead of dividing by zero.
    """
    c, ld = c_h[idx], ld_h[idx]
    wrong = c == 0
    return float(((ld < tau_strict) & wrong).sum() / max(wrong.sum(), 1))


# Point estimate taken from the loaded data with the same statistic that is
# bootstrapped, so the reported value and its CI cannot drift apart. (This was
# previously the hard-coded literal 21/3894.)
observed_strict = strictB_frac(np.arange(len(c_h)))
lines += ["", "## Bootstrap CIs (HRM @26040, n=8192, strict band)",
          f"- strict settled-wrong fraction of failures: observed {observed_strict:.4f}, bootstrap 95% CI "
          f"{boot_ci(strictB_frac, len(c_h))}",
          f"- AUC(-lam1->correct) = {auc_rank(-l1_h, c_h):.4f}, bootstrap 95% CI "
          f"{boot_ci(lambda i: auc_rank(-l1_h[i], c_h[i]), len(c_h))}"]
# ---------- (2) settling-criterion robustness ----------
# For each dataset, recount the settled cells (A = settled & correct,
# B = settled & wrong) with the settling threshold re-derived by Otsu from
# three drift definitions: zH, zL, and their Euclidean combination.
lines += ["", "## Settling-criterion robustness (B-cell counts under alternative drift definitions)"]
robustness_inputs = (("TRM official n=2048", trm, c_t), ("HRM n=8192", hrm, c_h))
for label, data, correct in robustness_inputs:
    is_wrong = correct == 0
    is_right = correct == 1
    cells = [label]
    for drift_key, short in (("drift_zH", "zH"), ("drift_zL", "zL")):
        log_drift = late(data, drift_key)
        thresh = otsu(log_drift)
        settled = log_drift < thresh
        n_b = int((settled & is_wrong).sum())
        n_a = int((settled & is_right).sum())
        cells.append(f"{short}: B={n_b}/A={n_a} (tau={thresh:.2f})")
    # Combined drift: Euclidean norm of the two late-window means, then the
    # same clip-and-log10 transform used for the single-channel drifts.
    mean_h = data["drift_zH"][:, -4:].mean(1)
    mean_l = data["drift_zL"][:, -4:].mean(1)
    log_comb = np.log10(np.clip(np.sqrt(mean_h ** 2 + mean_l ** 2), 1e-12, None))
    thresh = otsu(log_comb)
    settled = log_comb < thresh
    n_b = int((settled & is_wrong).sum())
    n_a = int((settled & is_right).sum())
    cells.append(f"combined: B={n_b}/A={n_a} (tau={thresh:.2f})")
    lines.append("- " + " | ".join(cells))
# ---------- (3) TRM official multi4 vs baseline (matched pipeline) ----------
# One report line per checkpoint: accuracy, the 2x2 settled/correct cell
# counts (A = settled & correct, B = settled & wrong, C = unsettled & correct,
# D = unsettled & wrong), the D fraction, and median lam1 in cells D and A.
lines += ["", "## TRM official-pipeline multi4 vs baseline (matched objective, n=512 each)"]
spec = FLOSS / "official_gbs768_spectrum"
checkpoints = (
    ("baseline @58590", "trm_gbs768_base_step58590_n512_k8_seed20260602.npz"),
    ("multi4 @35805 (best)", "trm_gbs768_multi4_step35805_n512_k8_seed20260602.npz"),
    ("multi4 @65100 (final)", "trm_gbs768_multi4_step65100_n512_k8_seed20260602.npz"),
)
for run_name, fname in checkpoints:
    npz_path = spec / fname
    if not npz_path.exists():
        # Record the gap in the report and move on to the next checkpoint.
        lines.append(f"- {run_name}: npz missing")
        continue
    run = np.load(npz_path)
    correct = run["exact_correct"].astype(int)
    lam1 = run["lyap_spec"][:, 0]
    log_drift = late(run)
    thresh = otsu(log_drift)
    settled = log_drift < thresh
    is_right = correct == 1
    is_wrong = correct == 0
    cell_a = settled & is_right
    cell_b = settled & is_wrong
    cell_c = ~settled & is_right
    cell_d = ~settled & is_wrong
    # The median of an empty cell is reported as NaN.
    med_d = np.median(lam1[cell_d]) if cell_d.sum() > 0 else float("nan")
    med_a = np.median(lam1[cell_a]) if cell_a.sum() > 0 else float("nan")
    lines.append(
        f"- {run_name}: acc={correct.mean():.3f}; "
        f"A/B/C/D={int(cell_a.sum())}/{int(cell_b.sum())}/{int(cell_c.sum())}/{int(cell_d.sum())}; "
        f"fD={float(cell_d.mean()):.3f}; lam1(D)={med_d:+.4f}; "
        f"lam1(A)={med_a:+.4f}"
    )
# ---------- (4) provenance note ----------
lines += ["", "## hrm_multi4 provenance (E6a)",
          "- diag_hrm_multi4_step_{20832,23436,26040}_512.npz step grid matches HRM pretrain numbering;",
          " multi4_eval_compare/logs should contain the eval invocations — checked manually below.",
          "- ACTION: if the hrm_multi4 run is pretrain-pipeline (ACT-streaming + perturbation), then the",
          " May-28 multi4 vs righteous baseline comparison IS matched-pipeline and Sec 3.4's caveat is",
          " narrower than written; step9 E-vs-F pair (queued) covers the fixed-unroll objective regardless."]

# Write the report to OUT/phase1_e1.md and echo it to stdout.
report = "\n".join(lines)
OUT.mkdir(exist_ok=True)
# The report contains non-ASCII characters (em dashes), so pin the encoding
# rather than relying on the platform's locale default.
(OUT / "phase1_e1.md").write_text(report, encoding="utf-8")
print(report)
|