summaryrefslogtreecommitdiff
path: root/research/flossing/analyze_halt_bucket.py
blob: 2ad80fc9908a1b81ffe6bc814d7487a5491b76dd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Validate codex's λ*(h, β) = (1/h) log β framework using halted_at buckets.

If the framework is right, then across halt buckets:
- success samples should all satisfy λ_1 × h_micro ≈ log(β_target) (a constant ≪ 0),
  where h_micro = halted_at × cycles_per_act is the micro-step horizon
- equivalently: success λ_1 should scale as ~1/halted_at

We use halted_at as a proxy for "effective computation budget" (h in codex's notation).
Each ACT step contains cycles_per_act = H × (L + 1) micro-Lyapunov-steps:
  HRM (H=2, L=2): 2*(2+1) = 6
  TRM (H=3, L=6): 3*(6+1) = 21
"""
import os

import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Directory holding the diagnostic .npz dumps; the output plot is written here too.
# Can be overridden with the FLOSSING_ROOT environment variable; the default is
# the original hard-coded checkout path, so existing runs behave unchanged.
ROOT = os.environ.get("FLOSSING_ROOT", "/home/yurenh2/rrm/research/flossing")

# One entry per model checkpoint:
#   npz            - diagnostic dump read by the analysis (arrays used:
#                    lyap_spec, exact_correct, halted_at)
#   cycles_per_act - micro-Lyapunov-steps per ACT step, H * (L + 1)
cases = {
    "HRM_step26040": dict(npz=f"{ROOT}/diag_joint_1k.npz",
                          cycles_per_act=2 * (2 + 1)),    # HRM: H=2, L=2 -> 6
    "TRM_step13020": dict(npz=f"{ROOT}/diag_trm_step13020_512.npz",
                          cycles_per_act=3 * (6 + 1)),    # TRM: H=3, L=6 -> 21
}

# ---------------------------------------------------------------------------
# For each checkpoint in `cases`: print summary and per-halt-bucket stats,
# then draw three panels (λ_1 vs halt step, implied log β vs halt step,
# log β histogram). Success vs failure is split on exact_correct > 0.5.
# ---------------------------------------------------------------------------

# halted_at == 0 marks samples that never halted; they are treated as having
# used the full ACT budget. NOTE(review): 16 is assumed to be the max ACT steps
# for both checkpoints - confirm against the training configs.
MAX_ACT_STEPS = 16
MIN_TABLE_BUCKET = 5   # min samples (succ + fail) for a bucket row in the printed table
MIN_PLOT_BUCKET = 3    # min samples in a group for its bucket mean to be plotted

# Seeded generator so the scatter-plot x-jitter (and hence the PNG) is reproducible.
_rng = np.random.default_rng(0)


def _safe_mean(x):
    """Return mean(x) as a float, or NaN (without a RuntimeWarning) when x is empty."""
    return float(np.mean(x)) if np.size(x) else float("nan")


def _jitter(x, width=0.15):
    """Return x plus uniform noise in [-width, width) to spread out integer halt buckets."""
    return x + _rng.uniform(-width, width, np.shape(x))


def _hist_bins(values, lo=-6.0, hi=4.0, n=50):
    """Return n bin edges spanning [lo, hi], widened to include every finite value.

    A fixed [-6, 4] range silently drops out-of-range samples from the histogram,
    and log β = λ_1 × halted_at × cycles_per_act can fall well below -6 for long
    horizons (e.g. 16 × 21 = 336 micro-steps).
    """
    values = np.asarray(values)
    finite = values[np.isfinite(values)]
    if finite.size:
        lo = min(lo, float(finite.min()))
        hi = max(hi, float(finite.max()))
    return np.linspace(lo, hi, n)


fig, axes = plt.subplots(2, 3, figsize=(15, 9))

for row, (name, meta) in enumerate(cases.items()):
    cpa = meta["cycles_per_act"]
    # Context manager closes the underlying .npz file; copy out the arrays we use.
    with np.load(meta["npz"]) as d:
        lam = np.array(d["lyap_spec"][:, 0])      # top-1 Lyap per sample (per micro step)
        succ = d["exact_correct"] > 0.5
        halted = d["halted_at"].copy()
    halted[halted == 0] = MAX_ACT_STEPS           # never-halt -> used full ACT budget

    h_micro = halted * cpa                        # effective micro-step horizon
    logbeta = lam * h_micro                       # log of cumulative decay over h_micro
    beta = np.exp(np.clip(logbeta, -20, 20))      # implied β (clipped to avoid overflow)

    print(f"\n=== {name} ===")
    print(f"  N={len(lam)}  acc={_safe_mean(succ):.3f}  cycles/ACT={cpa}")
    print(f"  λ_1 mean:     succ={_safe_mean(lam[succ]):+.4f}   fail={_safe_mean(lam[~succ]):+.4f}")
    print(f"  halted_at:    succ={_safe_mean(halted[succ]):.2f}    fail={_safe_mean(halted[~succ]):.2f}")
    print(f"  log β implied succ={_safe_mean(logbeta[succ]):+.3f}  fail={_safe_mean(logbeta[~succ]):+.3f}")
    print(f"  β implied:    succ={_safe_mean(beta[succ]):.3f}   fail={_safe_mean(beta[~succ]):.3f}")

    # Per-bucket (succ, fail) masks, computed once and shared by the table and panels A/B.
    buckets = sorted(set(halted.tolist()))
    bucket_masks = {b: ((halted == b) & succ, (halted == b) & ~succ) for b in buckets}

    # By halt bucket. h is constant within a bucket, so mean(λ) × h × cpa = mean(log β).
    print("  bucket    n_succ n_fail   λ_succ      λ_fail      logβ_succ   logβ_fail")
    for b, (ms, mf) in bucket_masks.items():
        n_s, n_f = int(ms.sum()), int(mf.sum())
        if n_s + n_f < MIN_TABLE_BUCKET:
            continue
        lam_s = _safe_mean(lam[ms])
        lam_f = _safe_mean(lam[mf])
        print(f"  h={b:>3}      {n_s:>4}  {n_f:>4}    "
              f"{lam_s:+7.4f}    {lam_f:+7.4f}    {lam_s * b * cpa:+8.3f}    {lam_f * b * cpa:+8.3f}")

    # Bucket means of λ_1 for buckets with enough samples to be worth plotting.
    means_s = [(b, lam[ms].mean()) for b, (ms, _) in bucket_masks.items()
               if ms.sum() >= MIN_PLOT_BUCKET]
    means_f = [(b, lam[mf].mean()) for b, (_, mf) in bucket_masks.items()
               if mf.sum() >= MIN_PLOT_BUCKET]

    # ----- Panel A: λ vs halted_at, succ vs fail -----
    ax = axes[row, 0]
    ax.scatter(_jitter(halted[succ]), lam[succ],
               s=5, alpha=0.35, c="C2", label=f"succ (n={succ.sum()})")
    ax.scatter(_jitter(halted[~succ]), lam[~succ],
               s=5, alpha=0.35, c="C3", label=f"fail (n={(~succ).sum()})")
    if means_s:
        xs, ys = zip(*means_s)
        ax.plot(xs, ys, "C2o-", lw=2, ms=8, label="succ mean")
    if means_f:
        xf, yf = zip(*means_f)
        ax.plot(xf, yf, "C3o-", lw=2, ms=8, label="fail mean")
    # 1/h reference: codex prediction λ ∝ 1/h, with the constant fitted as the
    # mean log β over all successful samples.
    if means_s:
        h_ref = np.array([b for b, _ in means_s])
        c = logbeta[succ].mean()                  # == mean(λ × halted × cpa) over succ
        ref = c / (h_ref * cpa)
        ax.plot(h_ref, ref, "k--", lw=1, alpha=0.5, label=f"λ=1/h·log β  (log β={c:.2f})")
    ax.axhline(0, color="k", lw=0.5, ls=":")
    ax.set_xlabel("halted_at (ACT steps)")
    ax.set_ylabel("λ_1 (per micro-step)")
    ax.set_title(f"{name}: λ_1 vs halt step")
    ax.legend(fontsize=7)
    ax.grid(alpha=0.3)

    # ----- Panel B: log β = λ × h_micro by halt bucket -----
    ax = axes[row, 1]
    ax.scatter(_jitter(halted[succ]), logbeta[succ],
               s=5, alpha=0.35, c="C2", label="succ")
    ax.scatter(_jitter(halted[~succ]), logbeta[~succ],
               s=5, alpha=0.35, c="C3", label="fail")
    if means_s:
        ax.plot([b for b, _ in means_s], [m * b * cpa for b, m in means_s],
                "C2o-", lw=2, ms=8, label="succ mean")
    if means_f:
        ax.plot([b for b, _ in means_f], [m * b * cpa for b, m in means_f],
                "C3o-", lw=2, ms=8, label="fail mean")
    ax.axhline(0, color="k", lw=0.5, ls=":")
    ax.set_xlabel("halted_at")
    ax.set_ylabel("log β = λ_1 × halt × cycles/ACT")
    ax.set_title(f"{name}: implied log β (constant ⇒ 1/h scaling holds)")
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)

    # ----- Panel C: histogram of implied β by succ/fail -----
    ax = axes[row, 2]
    bins = _hist_bins(logbeta)                    # covers all data; defaults to [-6, 4]
    ax.hist(logbeta[succ], bins=bins, alpha=0.6, color="C2",
            label=f"succ μ={_safe_mean(logbeta[succ]):+.2f}")
    ax.hist(logbeta[~succ], bins=bins, alpha=0.6, color="C3",
            label=f"fail μ={_safe_mean(logbeta[~succ]):+.2f}")
    ax.axvline(0, color="k", lw=0.5, ls=":")
    ax.set_xlabel("log β implied")
    ax.set_ylabel("count")
    ax.set_title(f"{name}: log β distribution")
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)

fig.suptitle("Codex's λ*(h, β) = (1/h) log β prediction: across halt buckets, "
             "λ should scale as 1/h with constant log β", fontsize=11)
fig.tight_layout()
out = f"{ROOT}/plots_halt_bucket.png"
fig.savefig(out, dpi=130)
plt.close(fig)
print(f"\n→ {out}")