summary | refs | log | tree | commit | diff
path: root/ep_run/test_compile_aselect.py
diff options
context:
space:
mode:
author Yuren Hao <yurenh2@illinois.edu> 2026-07-03 05:56:50 -0500
committer Yuren Hao <yurenh2@illinois.edu> 2026-07-03 05:56:50 -0500
commit b83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
tree b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/test_compile_aselect.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/test_compile_aselect.py')
-rw-r--r-- ep_run/test_compile_aselect.py 23
1 files changed, 23 insertions, 0 deletions
diff --git a/ep_run/test_compile_aselect.py b/ep_run/test_compile_aselect.py
new file mode 100644
index 0000000..a2601e4
--- /dev/null
+++ b/ep_run/test_compile_aselect.py
@@ -0,0 +1,23 @@
+import torch, time, math
+import lt_ep_train as LT, holo_ep as H
+torch.manual_seed(0)  # fix the global torch RNG seed so repeated runs of this script are comparable
+blk=LT.EQBlock(512,16,256,256,c=1.0,attn_mode='thick'); blk.qknorm=True; blk.track=True; blk.navg=1; blk.li_avg=0  # build the block, then set run flags by attribute (their meaning is defined in lt_ep_train, not visible here)
+ck=torch.load('runs/ep_resreg_warm.pt',map_location='cuda')  # checkpoint mapped onto the GPU; NOTE(review): weights_only is not passed, so loading follows the installed torch default
+with torch.no_grad():  # the in-place copies below must not be recorded by autograd
+ for p,s in zip(blk.allp,ck['allp']): p.copy_(s.to('cuda'))  # overwrite each entry of blk.allp in place; zip stops at the shorter list, so a length mismatch is silently ignored
+idx,y=LT.get_batch('train',24,256); xin=blk.embed(idx).detach()  # one 'train' batch; the embedding is detached from the autograd graph
+zs=LT.relax(blk,xin.clone(),xin,150,0.1)  # presumably the relaxed state that the timed calls start from — confirm argument meanings in lt_ep_train
+def T(fn):  # timing helper: one untimed warm-up call, then one timed call bracketed by CUDA syncs; fn must return a 2-tuple; returns (elapsed ms rounded to int, first element of the timed result)
+ fn(); torch.cuda.synchronize(); t0=time.perf_counter(); a,_=fn(); torch.cuda.synchronize(); return round((time.perf_counter()-t0)*1000),a
+ms0,a0=T(lambda: H.holo_a_track(blk,zs,xin,y,0.02,80,0.1))  # baseline: time the call with the original (uncompiled) blk.nc_force
+print(f"uncompiled a-select (t2sel=80): {ms0} ms",flush=True)
+orig=blk.nc_force  # keep the uncompiled callable so it can be put back afterwards
+try:
+ blk.nc_force=torch.compile(blk.nc_force, mode='reduce-overhead')  # replace the attribute on this instance with a compiled wrapper
+ ms1,a1=T(lambda: H.holo_a_track(blk,zs,xin,y,0.02,80,0.1))  # same call and arguments as the baseline; presumably holo_a_track reaches blk.nc_force internally — confirm in holo_ep
+ cos=float((a0.flatten()@a1.flatten())/(a0.norm()*a1.norm()+1e-20))  # cosine similarity of the two flattened results; 1e-20 guards against a zero denominator
+ print(f"compiled nc_force: {ms1} ms cos(a vs uncompiled)={cos:.4f}",flush=True)
+except Exception as e:  # best-effort: any compile or run failure is reported and the script carries on
+ print("compile FAIL:", type(e).__name__, str(e)[:120],flush=True)
+blk.nc_force=orig  # runs after both the success path and the caught-failure path, restoring the original callable
+print("DONE",flush=True)