summaryrefslogtreecommitdiff
path: root/research/flossing/analysis_2x2/offline_followups.py
blob: 8101acb745d4900ea58fe1b6126cf32f432feb0c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
"""Offline follow-ups to the 2x2 analysis (no GPU):

1. Residual outcome signal within the unsettled stratum (HRM diag_8k primary,
   TRM official @58590 secondary): per-cell drift profiles over the 16 ACT steps,
   end-of-window drift slope, q_halt trajectories, halted_at, lambda spectra with a
   STRICT in-band threshold, and per-drift-decile AUC(lambda1 -> correct) within the
   unsettled stratum (does lambda1 add signal beyond drift level?).
2. Per-example profile of the strict-band settled-but-wrong examples (HRM, n~21).
3. Difficulty control: #givens per puzzle (input tokens != 1) joined via idx;
   lambda1 ~ givens rank correlation overall/within outcome, and per-givens-bin
   AUC(-lambda1 -> correct).

Observational only. Outputs to analysis_2x2/offline_followups/.
"""
from __future__ import annotations

import os
from pathlib import Path

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

HERE = Path(__file__).resolve().parent
FLOSS = HERE.parent
# All tables and figures written by this script go here (created at import time).
OUT = HERE / "offline_followups"
OUT.mkdir(exist_ok=True)

# Test-split puzzle inputs used to count #givens per puzzle. The default is a
# user-specific absolute path; set SUDOKU_TEST_INPUTS to run on another machine.
DATA_TEST_INPUTS = Path(os.environ.get(
    "SUDOKU_TEST_INPUTS",
    "/home/yurenh2/rrm/data/sudoku-extreme-1k-aug-1000/test/all__inputs.npy",
))

# Plot colour per 2x2 cell (A/B/C/D as returned in load()["cells"]).
CELL_COLORS = {"A": "tab:green", "B": "tab:orange", "C": "tab:blue", "D": "tab:red"}


def auc_rank(score: np.ndarray, label: np.ndarray) -> float:
    """Rank-based (Mann-Whitney U) estimate of AUC(score -> label == 1).

    Entries with label 1 are positives and label 0 negatives; any other label
    value is ignored. Tied scores receive the mean of the ranks they span.
    Returns NaN when either class is empty.
    """
    positives = score[label == 1]
    negatives = score[label == 0]
    n_pos, n_neg = len(positives), len(negatives)
    if n_pos == 0 or n_neg == 0:
        return float("nan")

    pooled = np.concatenate([positives, negatives])
    total = len(pooled)
    perm = np.argsort(pooled, kind="mergesort")
    ordered = pooled[perm]

    # Tie groups are maximal runs of equal sorted values: a new run starts
    # wherever a value differs from its predecessor (NaN never compares equal,
    # so every NaN forms its own run).
    breaks = np.flatnonzero(ordered[1:] != ordered[:-1]) + 1
    first = np.concatenate(([0], breaks))
    stop = np.concatenate((breaks, [total]))
    # A run covering 1-based ranks first+1 .. stop gets their midpoint.
    midpoint = (first + 1 + stop) / 2.0

    rank_of = np.empty(total)
    rank_of[perm] = np.repeat(midpoint, stop - first)

    # Positives occupy the first n_pos slots of `pooled`.
    u_stat = rank_of[:n_pos].sum() - n_pos * (n_pos + 1) / 2
    return float(u_stat / (n_pos * n_neg))


def _average_ranks(x: np.ndarray) -> np.ndarray:
    """Return 1-based ranks of x where tied values share the mean of their ranks."""
    x = np.asarray(x)
    n = len(x)
    order = np.argsort(x, kind="mergesort")
    sv = x[order]
    # Start index of every run of equal sorted values, and its exclusive end.
    starts = np.concatenate(([0], np.flatnonzero(sv[1:] != sv[:-1]) + 1))
    ends = np.concatenate((starts[1:], [n]))
    ranks = np.empty(n, dtype=float)
    ranks[order] = np.repeat((starts + 1 + ends) / 2.0, ends - starts)
    return ranks


def spearman(a: np.ndarray, b: np.ndarray) -> float:
    """Spearman rank correlation between a and b, with average ranks for ties.

    Ties must be averaged because callers pass heavily tied inputs (a binary
    correct/wrong vector, integer #givens counts); ordinal argsort ranks would
    break those ties arbitrarily and bias the coefficient.

    Returns NaN when there are fewer than two pairs or either input is constant
    (the correlation is undefined). Raises ValueError on length mismatch.
    """
    ra = _average_ranks(a)
    rb = _average_ranks(b)
    if len(ra) != len(rb):
        raise ValueError(f"length mismatch: {len(ra)} vs {len(rb)}")
    if len(ra) < 2 or ra.max() == ra.min() or rb.max() == rb.min():
        return float("nan")
    return float(np.corrcoef(ra, rb)[0, 1])


def load(npz_path: Path, strict_pct: float):
    """Load a diagnostics .npz and derive the strict in-band 2x2 cell masks.

    Every array stored in the file is returned under its own key, plus:
      - "logd_late": log10 of the mean drift_zH over the last 4 steps
        (clipped below at 1e-12 so zero drift does not give -inf),
      - "tau_strict": the `strict_pct` percentile of logd_late over all rows,
      - "cells": boolean masks, with settled := logd_late < tau_strict,
          A = settled & correct,   B = settled & wrong,
          C = unsettled & correct, D = unsettled & wrong.

    Args:
        npz_path: .npz file with at least "drift_zH" (rows x steps) and
            "exact_correct".
        strict_pct: percentile (0-100) of logd_late used as the threshold.
    """
    # The context manager closes the zip handle; d[k] reads each array fully
    # into memory, so the returned arrays stay valid after closing.
    with np.load(npz_path) as d:
        out = {k: d[k] for k in d.files}
    out["logd_late"] = np.log10(np.clip(out["drift_zH"][:, -4:].mean(1), 1e-12, None))
    out["tau_strict"] = float(np.percentile(out["logd_late"], strict_pct))
    conv = out["logd_late"] < out["tau_strict"]
    c = out["exact_correct"].astype(int)
    out["cells"] = {
        "A": conv & (c == 1), "B": conv & (c == 0),
        "C": (~conv) & (c == 1), "D": (~conv) & (c == 0),
    }
    return out


def givens_for(idx: np.ndarray, inputs_path: "Path | None" = None) -> np.ndarray:
    """Count the givens (input tokens != 1) of each selected test puzzle.

    Args:
        idx: integer row positions into the test-inputs array.
        inputs_path: .npy file of puzzle inputs; defaults to DATA_TEST_INPUTS.

    Returns:
        Integer array aligned with `idx`: the number of tokens != 1 in each
        selected row (all trailing dimensions of a row are counted).
    """
    idx = np.asarray(idx)
    if idx.size == 0:
        return np.zeros(0, dtype=int)
    path = DATA_TEST_INPUTS if inputs_path is None else Path(inputs_path)
    inputs = np.load(path, mmap_mode="r")
    # One fancy-indexed read instead of a Python loop over rows; flatten the
    # trailing dims so each count spans the whole row.
    rows = np.asarray(inputs[idx]).reshape(len(idx), -1)
    return (rows != 1).sum(axis=1)


def drift_profiles_fig(ds, tag, lines):
    """Plot per-cell drift_zH / q_halt profiles and report the end-of-window drift slope.

    Writes OUT/fig_{tag}_profiles.png. The left panel is the per-cell median
    drift_zH with its interquartile band on a log y-axis. The right panel is the
    per-cell median q_halt. Both are plotted against ACT step, and cells with
    fewer than 3 examples are not plotted. A markdown slope summary per cell is
    appended to `lines` in place.

    Args:
        ds: dict as returned by load() (uses "drift_zH", "q_halt", "cells").
        tag: label used in plot titles, the figure filename and the markdown header.
        lines: list of markdown lines; appended to.

    Returns:
        Per-example slope = log10 mean drift over the last 4 steps minus log10
        mean drift over the 4 steps before them (<0 = still descending).

    Raises:
        ValueError: if drift_zH has fewer than 8 steps (the two windows need 8).
    """
    drift = ds["drift_zH"]
    n_steps = drift.shape[1]
    if n_steps < 8:
        raise ValueError(f"end-of-window slope needs >= 8 ACT steps, got {n_steps}")
    steps = np.arange(1, n_steps + 1)

    fig, axes = plt.subplots(1, 2, figsize=(11, 4))
    for nm, m in ds["cells"].items():
        n = int(m.sum())
        if n < 3:  # too few examples for a meaningful median / IQR
            continue
        color = CELL_COLORS[nm]
        cell_drift = drift[m]
        axes[0].plot(steps, np.median(cell_drift, axis=0), "o-", ms=3, color=color, label=f"{nm} (n={n})")
        axes[0].fill_between(steps, np.percentile(cell_drift, 25, axis=0),
                             np.percentile(cell_drift, 75, axis=0), color=color, alpha=0.15)
        axes[1].plot(steps, np.median(ds["q_halt"][m], axis=0), "o-", ms=3, color=color, label=nm)
    axes[0].set_yscale("log")
    axes[0].set_xlabel("ACT step")
    axes[0].set_ylabel("drift_zH (median, IQR)")
    axes[0].legend(fontsize=8)
    axes[0].set_title(f"{tag}: drift profiles per cell")
    axes[1].set_xlabel("ACT step")
    axes[1].set_ylabel("q_halt (median)")
    axes[1].axhline(0, color="gray", lw=0.6)
    axes[1].set_title(f"{tag}: q_halt per cell")
    axes[1].legend(fontsize=8)
    fig.tight_layout()
    fig.savefig(OUT / f"fig_{tag}_profiles.png", dpi=150)
    plt.close(fig)

    def log10_window_mean(window):
        # Clip before log10 so an all-zero window yields -12 instead of -inf.
        return np.log10(np.clip(window.mean(1), 1e-12, None))

    # Last 4 steps vs the 4 before them (steps 13-16 vs 9-12 when there are 16),
    # measured from the end like the late window of logd_late in load().
    slope = log10_window_mean(drift[:, -4:]) - log10_window_mean(drift[:, -8:-4])
    lines.append(f"\n### {tag}: end-of-window drift slope (log10 steps{n_steps - 3}-{n_steps} "
                 f"vs {n_steps - 7}-{n_steps - 4}; <0 = still descending)")
    for nm, m in ds["cells"].items():
        n = int(m.sum())
        if n == 0:
            lines.append(f"- {nm}: n=0")
            continue
        cell_slope = slope[m]
        lines.append(f"- {nm}: n={n}, slope median {np.median(cell_slope):+.4f}, "
                     f"IQR [{np.percentile(cell_slope, 25):+.4f}, {np.percentile(cell_slope, 75):+.4f}], "
                     f"frac still descending (<-0.01): {float((cell_slope < -0.01).mean()):.2f}")
    return slope


def main() -> None:
    """Run every offline follow-up; write OUT/followups.md plus PNG figures.

    Sections, in order:
      1. HRM diag_8k: per-cell summary, drift/q_halt profiles with end-of-window
         slope, and per-drift-decile AUC(-lam1 -> correct) inside the unsettled
         stratum (cells C+D).
      2. HRM strict-band settled-but-wrong examples (cell B): per-example table
         and drift profiles against the A band.
      3. HRM difficulty control: #givens correlations and per-givens-bin AUC.
      4. TRM official @58590: per-cell summary, profiles, givens correlations.
    """
    lines = ["# Offline follow-ups (no GPU) — 2026-06-11",
             "",
             "Strict in-band thresholds: HRM pct45 of pooled log10 late-drift; TRM pct60 (band edge; B=0 regardless).",
             "All numbers observational; within-dataset comparisons only."]

    # ---------- HRM diag_8k ----------
    hrm = load(FLOSS / "diag_8k.npz", strict_pct=45)
    tag = "hrm26040_n8192_strict"
    c = hrm["exact_correct"].astype(int)
    lines.append(f"\n## HRM @26040 (n={len(c)}), strict tau(log10)={hrm['tau_strict']:.4f}")
    lines.append("| cell | n | lam1 med | lam8 med | token_acc med | halted_at med | q_halt_final med | givens med |")
    lines.append("|---|---|---|---|---|---|---|---|")
    g_hrm = givens_for(hrm["idx"])
    for nm, m in hrm["cells"].items():
        if m.sum() == 0:
            lines.append(f"| {nm} | 0 | | | | | | |")
            continue
        lines.append(
            f"| {nm} | {int(m.sum())} | {np.median(hrm['lyap_spec'][m,0]):+.4f} | {np.median(hrm['lyap_spec'][m,-1]):+.4f} "
            f"| {np.median(hrm['token_acc'][m]):.3f} | {np.median(hrm['halted_at'][m]):.0f} "
            f"| {np.median(hrm['q_halt'][m,-1]):+.2f} | {np.median(g_hrm[m]):.0f} |")

    slope = drift_profiles_fig(hrm, tag, lines)

    # Residual signal within the unsettled stratum: per-drift-decile AUC.
    uns = ~(hrm["cells"]["A"] | hrm["cells"]["B"])
    lines.append("\n### HRM unsettled stratum: AUC(-lam1 -> correct) per log-drift decile")
    lines.append("| decile | drift range (log10) | n | n_correct | AUC |")
    lines.append("|---|---|---|---|---|")
    ld_u, l1_u, c_u = hrm["logd_late"][uns], hrm["lyap_spec"][uns, 0], c[uns]
    qs = np.percentile(ld_u, np.arange(0, 101, 10))
    aucs, ws = [], []
    for i in range(10):
        lo, hi = qs[i], qs[i + 1]
        # Half-open deciles; the last one is closed so the maximum is included.
        below_hi = ld_u <= hi if i == 9 else ld_u < hi
        m = (ld_u >= lo) & below_hi
        n_dec, n_cor = int(m.sum()), int(c_u[m].sum())
        a = auc_rank(-l1_u[m], c_u[m])
        # Only deciles with a defined AUC and >= 5 correct examples enter the mean.
        if not np.isnan(a) and n_cor >= 5:
            aucs.append(a)
            ws.append(m.sum())
        auc_txt = "n/a" if np.isnan(a) else f"{a:.3f}"
        lines.append(f"| {i+1} | [{lo:.2f}, {hi:.2f}] | {n_dec} | {n_cor} | {auc_txt} |")
    if aucs:
        lines.append(f"- weighted mean within-decile AUC = {np.average(aucs, weights=ws):.3f} "
                     f"(vs unconditioned within-unsettled AUC {auc_rank(-l1_u, c_u):.3f})")

    # Also: does the end-of-window slope separate C from D?
    lines.append(f"- AUC(-end_slope -> correct | unsettled) = {auc_rank(-slope[uns], c_u):.3f} "
                 f"(C still-descending fraction vs D, see slope table above)")

    # ---------- strict-B per-example table ----------
    B = hrm["cells"]["B"]
    lines.append(f"\n## HRM strict-band settled-but-wrong examples (n={int(B.sum())})")
    lines.append("| idx | givens | token_acc | lam1 | drift_final | halted_at | q_halt_final |")
    lines.append("|---|---|---|---|---|---|---|")
    bi = np.where(B)[0]
    order = np.argsort(hrm["token_acc"][bi], kind="stable")
    for j in bi[order]:
        # drift spans orders of magnitude (log-scale plots), so use scientific
        # notation; fixed-point .3f would print small drifts as 0.000.
        lines.append(
            f"| {int(hrm['idx'][j])} | {int(g_hrm[j])} | {hrm['token_acc'][j]:.3f} | {hrm['lyap_spec'][j,0]:+.3f} "
            f"| {hrm['drift_zH'][j,-1]:.3e} | {int(hrm['halted_at'][j])} | {hrm['q_halt'][j,-1]:+.2f} |")
    # B drift profiles vs the A q10-q90 band.
    fig, ax = plt.subplots(figsize=(6.5, 4))
    steps = np.arange(1, hrm["drift_zH"].shape[1] + 1)
    A = hrm["cells"]["A"]
    if A.any():
        ax.fill_between(steps, np.percentile(hrm["drift_zH"][A], 10, axis=0),
                        np.percentile(hrm["drift_zH"][A], 90, axis=0), color="tab:green", alpha=0.2,
                        label=f"A q10-q90 (n={int(A.sum())})")
    for j in bi:
        ax.plot(steps, hrm["drift_zH"][j], "-", lw=1, alpha=0.8, color="tab:orange")
    ax.set_yscale("log")
    ax.set_xlabel("ACT step")
    ax.set_ylabel("drift_zH")
    ax.set_title("HRM: strict-B drift profiles vs A band")
    ax.legend(fontsize=8)
    fig.tight_layout()
    fig.savefig(OUT / "fig_hrm_strictB_profiles.png", dpi=150)
    plt.close(fig)

    # ---------- difficulty control (HRM) ----------
    l1 = hrm["lyap_spec"][:, 0]
    lines.append("\n## HRM difficulty control (#givens, input tokens != 1)")
    lines.append(f"- givens: min {g_hrm.min()}, median {np.median(g_hrm):.0f}, max {g_hrm.max()}")
    lines.append(f"- Spearman(lam1, givens): overall {spearman(l1, g_hrm):+.3f}; "
                 f"correct-only {spearman(l1[c==1], g_hrm[c==1]):+.3f}; "
                 f"wrong-only {spearman(l1[c==0], g_hrm[c==0]):+.3f}")
    lines.append(f"- Spearman(correct, givens) = {spearman(c.astype(float), g_hrm):+.3f}")
    lines.append("\n| givens bin | n | acc | AUC(-lam1 -> correct) |")
    lines.append("|---|---|---|---|")
    # Quintile edges; np.unique drops duplicate edges when givens are heavily tied.
    edges = np.unique(np.percentile(g_hrm, [0, 20, 40, 60, 80, 100]))
    bin_aucs, bin_ws = [], []
    for i in range(len(edges) - 1):
        is_last = i == len(edges) - 2
        m = (g_hrm >= edges[i]) & (g_hrm <= edges[i + 1] if is_last else g_hrm < edges[i + 1])
        a = auc_rank(-l1[m], c[m])
        acc = c[m].mean() if m.any() else float("nan")
        lines.append(f"| [{edges[i]:.0f}, {edges[i+1]:.0f}] | {int(m.sum())} | {acc:.3f} | {a:.3f} |")
        if not np.isnan(a):
            bin_aucs.append(a)
            bin_ws.append(m.sum())
    if bin_aucs:
        lines.append(f"- weighted mean within-bin AUC = {np.average(bin_aucs, weights=bin_ws):.3f} "
                     f"(overall {auc_rank(-l1, c):.3f})")

    # ---------- TRM official @58590 (secondary, n=512) ----------
    trm = load(FLOSS / "official_gbs768_spectrum/trm_gbs768_base_step58590_n512_k8_seed20260602.npz", strict_pct=60)
    g_trm = givens_for(trm["idx"])
    ct = trm["exact_correct"].astype(int)
    lines.append(f"\n## TRM official @58590 (n={len(ct)}), strict tau(log10)={trm['tau_strict']:.4f}")
    lines.append("| cell | n | lam1 med | token_acc med | q_halt_final med | givens med |")
    lines.append("|---|---|---|---|---|---|")
    for nm, m in trm["cells"].items():
        if m.sum() == 0:
            lines.append(f"| {nm} | 0 | | | | |")
            continue
        lines.append(f"| {nm} | {int(m.sum())} | {np.median(trm['lyap_spec'][m,0]):+.4f} "
                     f"| {np.median(trm['token_acc'][m]):.3f} | {np.median(trm['q_halt'][m,-1]):+.2f} "
                     f"| {np.median(g_trm[m]):.0f} |")
    drift_profiles_fig(trm, "trm_official58590_n512_strict", lines)
    l1t = trm["lyap_spec"][:, 0]
    lines.append(f"- Spearman(lam1, givens): overall {spearman(l1t, g_trm):+.3f}; "
                 f"wrong-only {spearman(l1t[ct==0], g_trm[ct==0]):+.3f}")
    lines.append(f"- Spearman(correct, givens) = {spearman(ct.astype(float), g_trm):+.3f}")

    # Explicit UTF-8: the report header contains a non-ASCII em dash.
    (OUT / "followups.md").write_text("\n".join(lines), encoding="utf-8")
    print("\n".join(lines[:6]))
    print("wrote", OUT / "followups.md")


if "__main__" == __name__:
    main()  # regenerate the follow-up report and figures when run as a script