summary | refs | log | tree | commit | diff
path: root/ep_run/profile_ep.py
diff options
context:
space:
mode:
author: Yuren Hao <yurenh2@illinois.edu> 2026-07-03 05:56:50 -0500
committer: Yuren Hao <yurenh2@illinois.edu> 2026-07-03 05:56:50 -0500
commit: b83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
tree: b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/profile_ep.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/profile_ep.py')
-rw-r--r-- ep_run/profile_ep.py 40
1 files changed, 40 insertions, 0 deletions
diff --git a/ep_run/profile_ep.py b/ep_run/profile_ep.py
new file mode 100644
index 0000000..f76c4f2
--- /dev/null
+++ b/ep_run/profile_ep.py
@@ -0,0 +1,40 @@
+import torch, time, math
+import lt_ep_train as LT
+torch.manual_seed(0)
+def mk():
+    """Build the single EQBlock that every measurement in this script runs on.
+
+    The constructor arguments and the four flags below are fixed profiling
+    settings; their meaning is defined by ``lt_ep_train`` (not visible here).
+    """
+    block = LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
+    # Set after construction, in the same order as before, as plain
+    # attributes on the instance.
+    for attr, value in (("qknorm", True), ("track", True), ("navg", 1), ("li_avg", 0)):
+        setattr(block, attr, value)
+    return block
+def T(fn,reps=3,warm=1):
+    """Time ``fn`` and return the mean duration of one call in milliseconds.
+
+    ``fn`` is called ``warm`` times untimed, then ``reps`` times between two
+    ``torch.cuda.synchronize()`` calls, so work still queued on the GPU is
+    waited for before each clock read.
+
+    Args:
+        fn: zero-argument callable to benchmark.
+        reps: number of timed calls that are averaged.
+        warm: number of untimed warm-up calls made first.
+
+    Returns:
+        The mean milliseconds per call rounded to an ``int``. If anything in
+        the warm-up or timed section raises ``Exception``, the string
+        ``"ERR <ExceptionType>: <first 70 chars of the message>"`` is
+        returned instead, so one failing configuration does not abort the
+        rest of the profile run.
+    """
+    try:
+        torch.cuda.empty_cache()
+        for _ in range(warm):
+            fn()
+        torch.cuda.synchronize()
+        # perf_counter() is monotonic and high-resolution; time.time() is a
+        # wall clock that can jump if the system time is adjusted mid-run.
+        t0 = time.perf_counter()
+        for _ in range(reps):
+            fn()
+        torch.cuda.synchronize()
+        return round((time.perf_counter() - t0) / reps * 1000)
+    except Exception as e:
+        # Deliberate best effort: report the failure inline as a string.
+        return f"ERR {type(e).__name__}: {str(e)[:70]}"
+# Shared fixtures: one block and one training batch fetched with
+# LT.get_batch('train', 24, 256).
+blk = mk()
+idx, y = LT.get_batch('train', 24, 256)
+
+# Keyword arguments of the reference ("full") ep_step configuration.
+base = {
+    "T1": 150, "T2": 20, "eps": 0.1, "beta": 0.02, "jacreg": 0.1,
+    "holo": 2, "hr": 0.2, "t2sel": 160, "t1max": 300,
+    "res_est": 1e-4, "resreg": 0.2,
+}
+
+
+def S(**kw):
+    """Return a zero-arg callable running one ep_step with ``base`` overridden by ``kw``."""
+    def step():
+        # base and kw are merged at call time; kw wins on duplicate keys.
+        return LT.ep_step(blk, idx, y, **{**base, **kw})
+    return step
+
+
+print("=== full + component toggles (ms/step, B=24, C512) ===", flush=True)
+full = T(S())
+print(f"FULL ep_step: {full}", flush=True)
+
+# Each entry is (label, overrides applied on top of base).
+toggles = [
+    ("-jacreg", {"jacreg": 0}),
+    ("-resreg", {"resreg": 0}),
+    ("-t1max(no refine)", {"t1max": 0}),
+    ("t2sel=80", {"t2sel": 80}),
+    ("t2sel=40", {"t2sel": 40}),
+    ("plain nudge holo=0 T2=20", {"holo": 0, "t2sel": 0}),
+]
+for label, overrides in toggles:
+    elapsed = T(S(**overrides))
+    print(f" {label}: {elapsed}", flush=True)
+
+# Detached embedding of the batch, used as input for the relax-only timings.
+xin = blk.embed(idx).detach()
+for steps in (150, 300):
+    # A fresh clone is passed as the first state argument on every call.
+    elapsed = T(lambda: LT.relax(blk, xin.clone(), xin, steps, 0.1))
+    print(f" free relax T1={steps} alone: {elapsed}", flush=True)
+# Time the full ep_step configuration at several batch sizes.
+print("=== batch sweep (full) ===", flush=True)
+for B in (8, 24, 48):
+    ib, yb = LT.get_batch('train', B, 256)
+    t = T(lambda: LT.ep_step(blk, ib, yb, **base))
+    # T returns a string on failure; only a numeric result gets the
+    # per-sample figure appended.
+    if isinstance(t, (int, float)):
+        per_sample = f" ({t/B:.1f}/sample)"
+    else:
+        per_sample = ""
+    print(f" B={B}: {t} ms" + per_sample, flush=True)
+# Re-time the 150-step free relax with a torch.compile'd step installed on
+# the block as ``_cstep``.
+# NOTE(review): presumably LT.relax uses blk._cstep when it is present; that
+# lookup is not visible in this file, so confirm it in lt_ep_train.
+print("=== compile free relax ===", flush=True)
+
+
+def _relax_step(z, xn):
+    """One update ``z + 0.1 * blk.tforce(z, xn)``; this is the function that gets compiled."""
+    return z + 0.1 * blk.tforce(z, xn)
+
+
+try:
+    blk._cstep = torch.compile(_relax_step)
+    compiled_ms = T(lambda: LT.relax(blk, xin.clone(), xin, 150, 0.1))
+    print(f" free relax T1=150 COMPILED: {compiled_ms}", flush=True)
+    # Remove the attribute again so the later timings run without it.
+    del blk._cstep
+except Exception as err:
+    print(f" compile ERR {err}", flush=True)
+# Time the full ep_step configuration under bfloat16 autocast.
+print("=== bf16 full ===", flush=True)
+
+
+def bf():
+    """Run one full-configuration ep_step inside a CUDA bfloat16 autocast region."""
+    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
+        LT.ep_step(blk, idx, y, **base)
+
+
+bf16_ms = T(bf)
+print(f" full bf16: {bf16_ms}", flush=True)
+print("DONE", flush=True)