diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/bias_var.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/bias_var.py')
| -rw-r--r-- | ep_run/bias_var.py | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/ep_run/bias_var.py b/ep_run/bias_var.py new file mode 100644 index 0000000..a616ce4 --- /dev/null +++ b/ep_run/bias_var.py @@ -0,0 +1,63 @@ +"""Decisive 'why is EP far at S1' diagnostic: separate estimator BIAS from VARIANCE at the +converged v4b checkpoint. Over N batches compute EP grad, BPTT-400 grad, BPTT-150 control. + mean-cos = mean_b cos(g_EP^b, g_BPTT^b) -> per-step quality (noisy) + cos-means = cos(sum_b g_EP, sum_b g_BPTT) -> if >> mean-cos: errors AVERAGE OUT = VARIANCE + if ~ mean-cos: systematic = BIAS (the real wall) +BPTT-150-vs-400 gives the same two metrics as the slow-mixing horizon baseline.""" +import torch +import lt_ep_train as M +from pathlib import Path +import pickle +M.DD = Path('/tmp/lt_ep/data/tinystories') +M.vocab = pickle.load(open(M.DD / 'meta.pkl', 'rb'))['vocab_size'] +from lt_ep_train import EQBlock, get_batch, bptt_step, ep_step +dev = 'cuda' +torch.manual_seed(0) +B, T, C, H = 8, 256, 256, 8 +blk = EQBlock(C, H, 256, T, attn_mode='thick') +blk.qknorm = False; blk.track = False; blk.li_avg = 0; blk.navg = 1; blk.fnoise = 0.0; blk.nbrake = 0.0; blk._cstep = None +ck = torch.load('/tmp/lt_ep/ts_s1_ep_v4b.pt') +for p, w in zip(blk.allp, ck['allp']): + with torch.no_grad(): + p.copy_(w.to(dev)) +print(f"v4b ckpt best {ck['best']:.4f}", flush=True) +groups = {'all': blk.block, 'attn': [blk.WQ, blk.WK, blk.WV, blk.WO], + 'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb], 'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b], + 'emb': [blk.tok, blk.pos]} +N = 16 +sEP, s400, s150 = {}, {}, {} +cos_b = {k: [] for k in groups} +bctl_b = {k: [] for k in groups} + +def flat(g, ps): + v = [g[id(p)].reshape(-1) for p in ps if g.get(id(p)) is not None] + return torch.cat(v) if v else None + +def cos(a, b): + return (a @ b / (a.norm() * b.norm() + 1e-12)).item() + +for i in range(N): + idx, y = get_batch('train', B, T) + gE, _ = ep_step(blk, idx, y, 150, 20, 0.1, 0.02, 0.0, holo=2, hr=0.02, t1max=500, res_est=1e-4, t2sel=120) + g4 = bptt_step(blk, idx, y, 400, 0.1) + g1 = bptt_step(blk, idx, y, 150, 0.1) + for k, ps in groups.items(): + a, b, c = flat(gE, ps), flat(g4, ps), flat(g1, ps) + if a is not None and b is not None: + cos_b[k].append(cos(a, b)) + if c is not None and b is not None: + bctl_b[k].append(cos(c, b)) + for src, acc in ((gE, sEP), (g4, s400), (g1, s150)): + for p in blk.block: + if src.get(id(p)) is not None: + acc[id(p)] = src[id(p)].detach().clone() if id(p) not in acc else acc[id(p)] + src[id(p)].detach() + print(f" batch {i+1}/{N} done", flush=True) + +print(f"\n{'group':>5} {'EP mean-cos':>12} {'EP cos-means':>13} {'BPTT mean-cos':>14} {'BPTT cos-means':>15}") +for k, ps in groups.items(): + mc = sum(cos_b[k]) / len(cos_b[k]) + bmc = sum(bctl_b[k]) / len(bctl_b[k]) + aE, a4, a1 = flat(sEP, ps), flat(s400, ps), flat(s150, ps) + cm = cos(aE, a4) + bcm = cos(a1, a4) + print(f"{k:>5} {mc:>12.3f} {cm:>13.3f} {bmc:>14.3f} {bcm:>15.3f}", flush=True) |
