summary | refs | log | tree | commit | diff
path: root/ep_run/bf16_dbg.py
diff options
context:
space:
mode:
author: Yuren Hao <yurenh2@illinois.edu> 2026-07-03 05:56:50 -0500
committer: Yuren Hao <yurenh2@illinois.edu> 2026-07-03 05:56:50 -0500
commit: b83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
tree: b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/bf16_dbg.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/bf16_dbg.py')
-rw-r--r-- ep_run/bf16_dbg.py 29
1 files changed, 29 insertions, 0 deletions
diff --git a/ep_run/bf16_dbg.py b/ep_run/bf16_dbg.py
new file mode 100644
index 0000000..de32aed
--- /dev/null
+++ b/ep_run/bf16_dbg.py
@@ -0,0 +1,29 @@
+import torch, time, math, traceback
+import lt_ep_train as LT
+# Seed torch's RNG before anything else in the script runs.
+torch.manual_seed(0)
+
+# Build the block under test and set its run-time switches.
+blk = LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
+blk.qknorm = True
+blk.track = True
+blk.navg = 1
+blk.li_avg = 0
+
+# Overwrite the block's parameters in place with the saved checkpoint tensors.
+ck = torch.load('runs/ep_resreg_warm.pt', map_location='cuda')
+with torch.no_grad():
+    # NOTE(review): zip() stops at the shorter sequence, so a length mismatch
+    # between blk.allp and ck['allp'] would go unnoticed -- confirm they match.
+    for param, saved in zip(blk.allp, ck['allp']):
+        param.copy_(saved.to('cuda'))
+
+idx, y = LT.get_batch('train', 8, 256)
+
+# Keyword arguments shared by every LT.ep_step call in this script.
+base = {
+    'T1': 150,
+    'T2': 20,
+    'eps': 0.1,
+    'beta': 0.02,
+    'jacreg': 0.1,
+    'holo': 2,
+    'hr': 0.2,
+    't2sel': 80,
+    't1max': 150,
+    'res_est': 1e-4,
+    'resreg': 0.2,
+}
+
+# First value returned by a step run without autocast; cos() compares against it.
+# Presumably a mapping of gradients keyed by id(parameter) -- verify in lt_ep_train.
+g32, _ = LT.ep_step(blk, idx, y, **base)
+def cos(ga):
+    """Return the cosine similarity between ``ga`` and the module-level ``g32``.
+
+    For every ``p`` in ``blk.block`` both mappings are queried with
+    ``.get(id(p))``. Pairs where either side is ``None`` are skipped; the rest
+    are cast with ``.float()`` and folded into one dot product and two squared
+    norms. ``1e-20`` in the denominator keeps the result finite (0.0) when
+    nothing was accumulated.
+    """
+    dot = sq_a = sq_b = 0.0
+    # Lazy generator: each pair is looked up only when the loop reaches it.
+    pairs = ((ga.get(id(p)), g32.get(id(p))) for p in blk.block)
+    for cand, ref in pairs:
+        if cand is None or ref is None:
+            continue
+        cand, ref = cand.float(), ref.float()
+        dot += float((cand * ref).sum())
+        sq_a += float((cand * cand).sum())
+        sq_b += float((ref * ref).sum())
+    return dot / (math.sqrt(sq_a * sq_b) + 1e-20)
+def T(fn,reps=2):
+    """Return the mean duration of ``fn()`` in whole milliseconds.
+
+    ``fn`` is called once before the clock starts, so that first call is not
+    part of the measurement, and then ``reps`` more times. The CUDA device is
+    synchronized before the clock starts and again before it is read.
+
+    Args:
+        fn: zero-argument callable to time.
+        reps: number of timed calls; must be non-zero (it is the divisor).
+
+    Returns:
+        The per-call average, passed through ``round()``.
+    """
+    fn()  # untimed first call
+    torch.cuda.synchronize()
+    # perf_counter is monotonic and high-resolution; time.time() follows the
+    # system clock and can jump while the measurement is running.
+    start = time.perf_counter()
+    for _ in range(reps):
+        fn()
+    torch.cuda.synchronize()
+    return round((time.perf_counter() - start) / reps * 1000)
+# Baseline timing of one step with no autocast context active.
+print("fp32 step ms:", T(lambda: LT.ep_step(blk,idx,y,**base)),flush=True)
+print("=== A: blanket autocast (locate the break) ===",flush=True)
+
+
+def _bf16_step():
+    """Run one ``LT.ep_step`` with bfloat16 autocast active for the whole call."""
+    with torch.autocast('cuda', dtype=torch.bfloat16):
+        return LT.ep_step(blk, idx, y, **base)
+
+
+try:
+    gA, _ = _bf16_step()
+    # Time the helper that enters autocast itself. Timing a bare ep_step here
+    # would run after the context has exited and measure the default-precision
+    # path a second time.
+    print("A OK cos", round(cos(gA), 4), "ms", T(_bf16_step))
+except Exception:
+    # Debug script: show where the autocast run breaks, then still print DONE.
+    traceback.print_exc()
+print("DONE",flush=True)