diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/adaptive_eps_calib.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/adaptive_eps_calib.py')
| -rw-r--r-- | ep_run/adaptive_eps_calib.py | 40 |
1 files changed, 40 insertions, 0 deletions
# ep_run/adaptive_eps_calib.py
# Reconstructed from a whitespace-collapsed "new file" diff (index 0000000..475e296).
# The original indentation was lost, so block nesting below is inferred from syntax.
"""Calibrate an adaptive step-size (eps) controller on the s3200 checkpoint.

The script iterates ``z <- z + eps * force(z, xin)`` on one fixed validation
batch and reports the relative force norm ``g = |f| / |z|`` over the tail of
the trajectory, first for fixed-eps baselines and then for several adaptive
controller configurations.
"""
import math
import pickle
from pathlib import Path

import torch

dev = 'cuda'
B = 8
T = 256

# Populated by _load_inputs() on first use so that importing this module does
# not touch the dataset, the checkpoint, or the project-local trainer module.
idx = None
ck = None


def _load_inputs():
    """Load the validation batch and the checkpoint once; later calls are no-ops."""
    global idx, ck
    if ck is not None:
        return
    import lt_ep_train as L

    L.DD = Path('data/tinystories_bpe')
    # NOTE(review): pickle can execute arbitrary code on load; meta.pkl must
    # come from a trusted source.
    with open(L.DD / 'meta.pkl', 'rb') as fh:
        L.vocab = pickle.load(fh)['vocab_size']
    # Seed before drawing the batch so every invocation sees the same batch.
    torch.manual_seed(1234)
    batch, _ = L.get_batch('val', B, T)
    idx = batch.to(dev) if hasattr(batch, 'to') else batch
    ck = torch.load('runs/redx_traj/s3200.pt', map_location=dev)


def mkblk():
    """Return a fresh ``EQBlock`` whose parameters are copied from ``ck['allp']``."""
    from lt_ep_train import EQBlock

    _load_inputs()
    blk = EQBlock(512, 16, 256, 256, s=1.0, c=1.0, attn_mode='thick')
    blk.qknorm = True
    with torch.no_grad():
        for p, w in zip(blk.allp, ck['allp']):
            p.copy_(w.to(dev))
    return blk


def force_g(blk, z, xin):
    """Return ``(f, g)``: the detached force at ``z`` and ``|f| / (|z| + 1e-9)``."""
    f = blk.force(z, xin).detach()
    return f, f.norm().item() / (z.norm().item() + 1e-9)


def adapt_eps(eps, g, prev, emin=0.005, emax=0.1, down=0.5, up=1.1,
              theta=0.98, grow=0.9):
    """Return the next step size given the current and previous force ratios.

    Args:
        eps: current step size.
        g: relative force norm measured at the current iterate.
        prev: relative force norm from the previous iteration, or ``None`` on
            the first iteration (``eps`` is then returned unchanged).
        emin, emax: clamp bounds applied when shrinking / growing.
        down, up: multiplicative shrink / growth factors.
        theta: shrink when ``g > theta * prev``.
        grow: grow when ``g < grow * prev``.

    Returns:
        The updated step size; unchanged when ``g`` falls between the two
        thresholds.
    """
    if prev is None:
        return eps
    if g > theta * prev:
        return max(emin, eps * down)
    if g < grow * prev:
        return min(emax, eps * up)
    return eps


def run(adaptive, e0=0.1, emin=0.005, emax=0.1, down=0.5, up=1.1, theta=0.98,
        N=10000, grow=0.9, tail=500):
    """Iterate the block's force dynamics for ``N`` steps and summarize the tail.

    Args:
        adaptive: when true, update eps each step via ``adapt_eps``; otherwise
            eps stays at ``e0``.
        e0: initial step size.
        emin, emax, down, up, theta: controller settings (see ``adapt_eps``).
        N: number of iterations; must be >= 1.
        grow: growth threshold (previously hard-coded to 0.9).
        tail: number of trailing iterations summarized (previously hard-coded
            to 500); must be >= 1.

    Returns:
        dict with ``gmin`` / ``gmean`` over the last ``tail`` force ratios,
        ``avg_eps`` over the recorded eps history, and ``final_eps``.

    Raises:
        ValueError: if ``N`` or ``tail`` is smaller than 1.
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")
    if tail < 1:
        # gs[-0:] would silently select the whole history instead of nothing.
        raise ValueError(f"tail must be >= 1, got {tail}")

    blk = mkblk()
    # NOTE(review): the collapsed source does not show whether the loop sat
    # inside no_grad; it is assumed to. Confirm blk.force works without autograd.
    with torch.no_grad():
        xin = blk.embed(idx).detach()
        z = xin.clone()
        eps = e0
        prev = None
        gs = []
        eh = []
        for _ in range(N):
            f, g = force_g(blk, z, xin)
            gs.append(g)
            # NOTE(review): eps is recorded before the controller update, so
            # avg_eps / final_eps lag the applied step size by one iteration.
            # Kept as-is so previously recorded calibration numbers reproduce.
            eh.append(eps)
            if adaptive:
                eps = adapt_eps(eps, g, prev, emin=emin, emax=emax, down=down,
                                up=up, theta=theta, grow=grow)
            prev = g
            z = z + eps * f
    window = gs[-tail:]
    return dict(gmin=min(window), gmean=sum(window) / len(window),
                avg_eps=sum(eh) / len(eh), final_eps=eh[-1])


def main():
    """Run the fixed-eps benchmarks and the adaptive configurations, printing results."""
    _load_inputs()
    print("=== adaptive-eps controller CALIBRATION on s3200 (cycling op) ===")
    print("ground truth: fixed eps=0.1 cycles (g~0.23); fixed eps=0.01 converges (g~0.09)")
    print("-- benchmarks (fixed eps) --")
    for e in (0.1, 0.01):
        r = run(False, e0=e, emin=e, emax=e)
        print(f" fixed eps={e}: g_tail[min={r['gmin']:.4f} mean={r['gmean']:.4f}]")
    print("-- adaptive configs (want: g <= 0.01-benchmark, avg_eps as HIGH as possible = fewer effective steps) --")
    configs = [
        ("C1 cons", dict(down=0.5, up=1.1, theta=0.98)),
        ("C2 mod", dict(down=0.7, up=1.2, theta=0.98)),
        ("C3 caut", dict(down=0.5, up=1.05, theta=0.99)),
        ("C4 aggr", dict(down=0.6, up=1.3, theta=0.95)),
    ]
    for name, kw in configs:
        r = run(True, **kw)
        print(f" {name} {kw}: g_tail[min={r['gmin']:.4f} mean={r['gmean']:.4f}] avg_eps={r['avg_eps']:.4f} final_eps={r['final_eps']:.4f}")
    print("=== DONE ===")


if __name__ == "__main__":
    main()
