diff options
Diffstat (limited to 'ep_run/profile_ep.py')
| -rw-r--r-- | ep_run/profile_ep.py | 40 |
1 files changed, 40 insertions, 0 deletions
"""Profiling script for ``LT.ep_step`` / ``LT.relax`` (ep_run/profile_ep.py).

Prints millisecond timings for a full ``ep_step``, for the same step with
individual keyword arguments overridden, for ``relax`` on its own, for a
batch-size sweep, for a ``torch.compile``-d step function, and for the full
step under bf16 autocast.  Requires a CUDA device and the project module
``lt_ep_train``.
"""
import torch, time, math
import lt_ep_train as LT

torch.manual_seed(0)


def mk():
    """Build the ``LT.EQBlock`` under test and set its runtime flags."""
    blk = LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
    # NOTE(review): meaning of these attributes is defined in lt_ep_train,
    # not visible here.
    blk.qknorm = True
    blk.track = True
    blk.navg = 1
    blk.li_avg = 0
    return blk


def T(fn, reps=3, warm=1):
    """Return the mean run time of ``fn`` in whole milliseconds.

    ``fn`` is called ``warm`` times untimed, then ``reps`` times between two
    ``torch.cuda.synchronize()`` calls so queued GPU work is included in the
    measurement.

    Returns:
        int: mean milliseconds per call, rounded; or
        str: ``"ERR <ExceptionType>: <message[:70]>"`` if anything raised, so
        one failing configuration does not abort the whole profiling run.
    """
    try:
        torch.cuda.empty_cache()
        for _ in range(warm):
            fn()
        torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # if the system clock is adjusted mid-measurement.
        t0 = time.perf_counter()
        for _ in range(reps):
            fn()
        torch.cuda.synchronize()
        return round((time.perf_counter() - t0) / reps * 1000)
    except Exception as e:
        return f"ERR {type(e).__name__}: {str(e)[:70]}"


blk = mk()
idx, y = LT.get_batch('train', 24, 256)

# Baseline keyword arguments forwarded to LT.ep_step; the toggles below
# override single entries.
base = dict(T1=150, T2=20, eps=0.1, beta=0.02, jacreg=0.1, holo=2, hr=0.2,
            t2sel=160, t1max=300, res_est=1e-4, resreg=0.2)


def S(**kw):
    """Return a zero-arg callable running ``LT.ep_step`` with ``base`` overridden by ``kw``."""
    return lambda: LT.ep_step(blk, idx, y, **{**base, **kw})


print("=== full + component toggles (ms/step, B=24, C512) ===", flush=True)
full = T(S())
print(f"FULL ep_step: {full}", flush=True)
for n, kw in [("-jacreg", dict(jacreg=0)), ("-resreg", dict(resreg=0)), ("-t1max(no refine)", dict(t1max=0)),
              ("t2sel=80", dict(t2sel=80)), ("t2sel=40", dict(t2sel=40)), ("plain nudge holo=0 T2=20", dict(holo=0, t2sel=0))]:
    print(f" {n}: {T(S(**kw))}", flush=True)

xin = blk.embed(idx).detach()
print(f" free relax T1=150 alone: {T(lambda: LT.relax(blk,xin.clone(),xin,150,0.1))}", flush=True)
print(f" free relax T1=300 alone: {T(lambda: LT.relax(blk,xin.clone(),xin,300,0.1))}", flush=True)

print("=== batch sweep (full) ===", flush=True)
for B in [8, 24, 48]:
    ib, yb = LT.get_batch('train', B, 256)
    # The lambda is invoked inside T() during this iteration, so it sees the
    # current ib/yb.
    t = T(lambda: LT.ep_step(blk, ib, yb, **base))
    # T() returns an error string on failure; only compute per-sample cost
    # for numeric results.
    print(f" B={B}: {t} ms" + (f" ({t/B:.1f}/sample)" if isinstance(t, (int, float)) else ""), flush=True)

print("=== compile free relax ===", flush=True)
_cstep_installed = False
try:
    # NOTE(review): presumably LT.relax uses blk._cstep when it is present --
    # confirm in lt_ep_train.
    blk._cstep = torch.compile(lambda z, xn: z + 0.1 * blk.tforce(z, xn))
    _cstep_installed = True
    print(f" free relax T1=150 COMPILED: {T(lambda: LT.relax(blk,xin.clone(),xin,150,0.1))}", flush=True)
except Exception as e:
    print(f" compile ERR {e}", flush=True)
finally:
    # Always remove the temporary attribute so the bf16 measurement below
    # runs against the same block state as the earlier measurements.
    if _cstep_installed:
        del blk._cstep

print("=== bf16 full ===", flush=True)


def bf():
    """Run one full ``LT.ep_step`` with baseline settings under CUDA bf16 autocast."""
    with torch.autocast('cuda', dtype=torch.bfloat16):
        LT.ep_step(blk, idx, y, **base)


print(f" full bf16: {T(bf)}", flush=True)
print("DONE", flush=True)
