summaryrefslogtreecommitdiff
path: root/ep_run/gcalib.py
diff options
context:
space:
mode:
author Yuren Hao <yurenh2@illinois.edu> 2026-07-03 05:56:50 -0500
committer Yuren Hao <yurenh2@illinois.edu> 2026-07-03 05:56:50 -0500
commit b83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
tree b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/gcalib.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/gcalib.py')
-rw-r--r-- ep_run/gcalib.py 38
1 files changed, 38 insertions, 0 deletions
diff --git a/ep_run/gcalib.py b/ep_run/gcalib.py
new file mode 100644
index 0000000..85495c0
--- /dev/null
+++ b/ep_run/gcalib.py
@@ -0,0 +1,38 @@
+"""EP lr theory, step 1: measure k = |g_EP|/|g_BPTT| per param group at a realistic operating point.
+Native reference is BPTT (Ernoult: EP=BPTT as beta->0, converged) — NOT BP. lr_EP = lr_BPTT / k.
+Report magnitude ratio AND cosine (direction) per group so we separate scale (k) from alignment."""
+import torch
+import lt_ep_train as M
+from pathlib import Path
+import pickle
+M.DD = Path('/tmp/lt_ep/data/tinystories_bpe'); M.vocab = pickle.loads((M.DD/'meta.pkl').read_bytes())['vocab_size']  # read_bytes() opens and closes the file (the bare open() never closed it); NOTE: unpickling is only safe because meta.pkl is locally generated data
+from lt_ep_train import EQBlock, get_batch, bptt_step, ep_step
+torch.manual_seed(0)  # fix the torch RNG seed before the block is constructed
+C,H,T,B = 512, 16, 256, 16  # C,H,T are passed to EQBlock and (B,T) to get_batch; presumably width, heads, sequence length, batch size — TODO confirm against lt_ep_train
+blk = EQBlock(C,H,256,T,attn_mode='thick'); blk.qknorm=True; blk.track=False; blk.li_avg=0; blk.navg=1; blk.fnoise=0; blk.nbrake=0; blk._cstep=None  # attributes are overridden after construction; their meaning is defined in lt_ep_train, not visible here
+with torch.no_grad(): blk.WO.mul_(0.1); blk.pj.mul_(0.1)  # in-place scale of WO and pj by 0.1, done under no_grad
+opt = torch.optim.AdamW(blk.allp, lr=5e-4, weight_decay=1e-4)  # optimizer over blk.allp; in this script it is only stepped by the BPTT pretraining loop
+for _ in range(300): # pretrain to a realistic operating point (BPTT)
+ idx,y = get_batch('train',B,T); g = bptt_step(blk,idx,y,150,0.1)  # g is looked up by id(param) below, so it is a mapping keyed on parameter identity
+ opt.zero_grad(set_to_none=True)
+ for p in blk.allp: p.grad = g.get(id(p))  # gradients are installed by hand; a parameter missing from g gets grad=None
+ torch.nn.utils.clip_grad_norm_(blk.allp,5.0); opt.step()  # clip the total gradient norm to 5.0, then apply the AdamW update
+print("pretrained 300 BPTT steps (C=512). k=|g_EP|/|g_BPTT|, cos=direction:", flush=True)
+groups = {'all':blk.block,'attn':[blk.WQ,blk.WK,blk.WV,blk.WO],'ffn':[blk.fc,blk.fcb,blk.pj,blk.pjb],  # group name -> parameters whose gradients are compared as one flattened vector
+ 'ln':[blk.ln1g,blk.ln1b,blk.ln2g,blk.ln2b],'emb':[blk.tok,blk.pos]}  # NOTE(review): 'all' reads blk.block while the optimizer and pretraining loop use blk.allp — confirm which parameter set 'all' is meant to cover
+def cat(g,ps):  # flatten every gradient that g holds for the params in ps into one 1-D tensor; None when g has none of them
+    flat=[t.reshape(-1) for t in (g.get(id(p)) for p in ps) if t is not None]; return torch.cat(flat) if flat else None
+import numpy as np
+acc={k:[] for k in groups}; accc={k:[] for k in groups}  # per-group samples: acc holds the norm ratio k, accc the cosine
+for _ in range(6):  # six training batches; EP and BPTT gradients are taken on the same batch
+    idx,y = get_batch('train',B,T)
+    gE,_ = ep_step(blk,idx,y,150,20,0.1,0.02,0.0,holo=2,hr=0.02,t1max=500,res_est=1e-4,t2sel=120)
+    gB = bptt_step(blk,idx,y,400,0.1)
+    for k,ps in groups.items():
+        a,b = cat(gE,ps),cat(gB,ps)
+        if a is not None and b is not None:  # skip a group that has no gradients in either result
+            acc[k].append((a.norm()/(b.norm()+1e-12)).item())
+            accc[k].append((a@b/(a.norm()*b.norm()+1e-12)).item())
+print(f"{'group':>5} {'k=|gEP|/|gBPTT|':>16} {'cos':>6} -> lr_EP = lr_BPTT / k")
+for k in groups:
+    ks,cs = acc[k] or [float('nan')], accc[k] or [float('nan')]; print(f"{k:>5} {np.mean(ks):>16.3f} {np.mean(cs):>6.3f}", flush=True)  # a group with no samples prints nan without numpy's empty-mean RuntimeWarning