summaryrefslogtreecommitdiff
path: root/ep_run/bias_var.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
committerYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
commitb83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
treeb9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/bias_var.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/bias_var.py')
-rw-r--r--ep_run/bias_var.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/ep_run/bias_var.py b/ep_run/bias_var.py
new file mode 100644
index 0000000..a616ce4
--- /dev/null
+++ b/ep_run/bias_var.py
@@ -0,0 +1,63 @@
+"""Decisive 'why is EP far at S1' diagnostic: separate estimator BIAS from VARIANCE at the
+converged v4b checkpoint. Over N batches compute EP grad, BPTT-400 grad, BPTT-150 control.
+ mean-cos = mean_b cos(g_EP^b, g_BPTT^b) -> per-step quality (noisy)
+ cos-means = cos(sum_b g_EP, sum_b g_BPTT) -> if >> mean-cos: errors AVERAGE OUT = VARIANCE
+ if ~ mean-cos: systematic = BIAS (the real wall)
+BPTT-150-vs-400 gives the same two metrics as the slow-mixing horizon baseline."""
+import torch
+import lt_ep_train as M
+from pathlib import Path
+import pickle
+M.DD = Path('/tmp/lt_ep/data/tinystories')
+M.vocab = pickle.load(open(M.DD / 'meta.pkl', 'rb'))['vocab_size']
+from lt_ep_train import EQBlock, get_batch, bptt_step, ep_step
+dev = 'cuda'
+torch.manual_seed(0)
+B, T, C, H = 8, 256, 256, 8
+blk = EQBlock(C, H, 256, T, attn_mode='thick')
+blk.qknorm = False; blk.track = False; blk.li_avg = 0; blk.navg = 1; blk.fnoise = 0.0; blk.nbrake = 0.0; blk._cstep = None
+ck = torch.load('/tmp/lt_ep/ts_s1_ep_v4b.pt')
+for p, w in zip(blk.allp, ck['allp']):
+ with torch.no_grad():
+ p.copy_(w.to(dev))
+print(f"v4b ckpt best {ck['best']:.4f}", flush=True)
+groups = {'all': blk.block, 'attn': [blk.WQ, blk.WK, blk.WV, blk.WO],
+ 'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb], 'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b],
+ 'emb': [blk.tok, blk.pos]}
+N = 16
+sEP, s400, s150 = {}, {}, {}
+cos_b = {k: [] for k in groups}
+bctl_b = {k: [] for k in groups}
+
+def flat(g, ps):
+ v = [g[id(p)].reshape(-1) for p in ps if g.get(id(p)) is not None]
+ return torch.cat(v) if v else None
+
+def cos(a, b):
+ return (a @ b / (a.norm() * b.norm() + 1e-12)).item()
+
+for i in range(N):
+ idx, y = get_batch('train', B, T)
+ gE, _ = ep_step(blk, idx, y, 150, 20, 0.1, 0.02, 0.0, holo=2, hr=0.02, t1max=500, res_est=1e-4, t2sel=120)
+ g4 = bptt_step(blk, idx, y, 400, 0.1)
+ g1 = bptt_step(blk, idx, y, 150, 0.1)
+ for k, ps in groups.items():
+ a, b, c = flat(gE, ps), flat(g4, ps), flat(g1, ps)
+ if a is not None and b is not None:
+ cos_b[k].append(cos(a, b))
+ if c is not None and b is not None:
+ bctl_b[k].append(cos(c, b))
+ for src, acc in ((gE, sEP), (g4, s400), (g1, s150)):
+ for p in blk.block:
+ if src.get(id(p)) is not None:
+ acc[id(p)] = src[id(p)].detach().clone() if id(p) not in acc else acc[id(p)] + src[id(p)].detach()
+ print(f" batch {i+1}/{N} done", flush=True)
+
+print(f"\n{'group':>5} {'EP mean-cos':>12} {'EP cos-means':>13} {'BPTT mean-cos':>14} {'BPTT cos-means':>15}")
+for k, ps in groups.items():
+ mc = sum(cos_b[k]) / len(cos_b[k])
+ bmc = sum(bctl_b[k]) / len(bctl_b[k])
+ aE, a4, a1 = flat(sEP, ps), flat(s400, ps), flat(s150, ps)
+ cm = cos(aE, a4)
+ bcm = cos(a1, a4)
+ print(f"{k:>5} {mc:>12.3f} {cm:>13.3f} {bmc:>14.3f} {bcm:>15.3f}", flush=True)