diff options
Diffstat (limited to 'ep_run/bf16_dbg2.py')
| -rw-r--r-- | ep_run/bf16_dbg2.py | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/ep_run/bf16_dbg2.py b/ep_run/bf16_dbg2.py new file mode 100644 index 0000000..517642d --- /dev/null +++ b/ep_run/bf16_dbg2.py @@ -0,0 +1,30 @@ +import torch, time, math +import lt_ep_train as LT +torch.manual_seed(0) +blk=LT.EQBlock(512,16,256,256,c=1.0,attn_mode='thick'); blk.qknorm=True; blk.track=True; blk.navg=1; blk.li_avg=0 +ck=torch.load('runs/ep_resreg_warm.pt',map_location='cuda') +with torch.no_grad(): + for p,s in zip(blk.allp,ck['allp']): p.copy_(s.to('cuda')) +idx,y=LT.get_batch('train',8,256) +base=dict(T1=150,T2=20,eps=0.1,beta=0.02,jacreg=0.1,holo=2,hr=0.2,t2sel=80,t1max=150,res_est=1e-4,resreg=0.2) +g32,_=LT.ep_step(blk,idx,y,**base) +def cos(ga): + n=da=db=0.0 + for p in blk.block: + a=ga.get(id(p)); b=g32.get(id(p)) + if a is None or b is None: continue + a=a.float(); b=b.float(); n+=float((a*b).sum()); da+=float((a*a).sum()); db+=float((b*b).sum()) + return round(n/(math.sqrt(da*db)+1e-20),4) +def T(fn,reps=2): + fn(); torch.cuda.synchronize(); t0=time.time() + for _ in range(reps): fn() + torch.cuda.synchronize(); return round((time.time()-t0)/reps*1000) +fp32ms=T(lambda: LT.ep_step(blk,idx,y,**base)); print("fp32:",fp32ms,"ms",flush=True) +def trial(name,**ac): + try: + def run(): + with torch.autocast('cuda',dtype=torch.bfloat16,**ac): return LT.ep_step(blk,idx,y,**base)[0] + g=run(); print(f"{name:22s} OK cos={cos(g)} ms={T(run)} (fp32={fp32ms})",flush=True) + except Exception as e: print(f"{name:22s} FAIL: {type(e).__name__}: {str(e)[:80]}",flush=True) +trial("cache_enabled=False", cache_enabled=False) +print("DONE",flush=True) |
