"""Timing harness for ``lt_ep_train.ep_step`` and related calls.

Prints one line per measurement (milliseconds per call, averaged over a few
repetitions) for:

* the full ``ep_step`` configuration and variants with single options
  overridden,
* ``LT.relax`` on its own,
* a batch-size sweep of the full configuration,
* ``LT.relax`` with a ``torch.compile``-d step attached to the block,
* the full configuration under bf16 autocast.

Every measurement is bracketed by ``torch.cuda.synchronize()``; a measurement
that raises is reported as an ``ERR ...`` string instead of aborting the run.
"""
import math
import time

import torch

import lt_ep_train as LT

torch.manual_seed(0)


def mk():
    """Build the block under test.

    Returns:
        An ``LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')`` with
        ``qknorm``/``track`` enabled and ``navg``/``li_avg`` initialised.
    """
    blk = LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
    blk.qknorm = True
    blk.track = True
    blk.navg = 1
    blk.li_avg = 0
    return blk


def T(fn, reps=3, warm=1):
    """Time ``fn`` on the GPU and return the mean duration per call.

    Args:
        fn: Zero-argument callable to benchmark.
        reps: Number of timed calls that are averaged.
        warm: Number of untimed warm-up calls made first.

    Returns:
        The mean time per call in whole milliseconds (``int``), or a string
        ``"ERR <ExceptionName>: <message>"`` (message truncated to 70
        characters) if anything raised while measuring.
    """
    # Deliberately best-effort: one failing configuration must not stop the
    # remaining measurements, so every exception becomes a printable string.
    try:
        torch.cuda.empty_cache()
        for _ in range(warm):
            fn()
        # Drain queued GPU work so the timed region starts from an idle device.
        torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # with wall-clock adjustments and is not meant for interval timing.
        t0 = time.perf_counter()
        for _ in range(reps):
            fn()
        # Wait for the asynchronously launched kernels before reading the clock.
        torch.cuda.synchronize()
        return round((time.perf_counter() - t0) / reps * 1000)
    except Exception as e:
        return f"ERR {type(e).__name__}: {str(e)[:70]}"


blk = mk()
idx, y = LT.get_batch('train', 24, 256)

# Keyword arguments of the "full" ep_step configuration; individual runs
# below override single entries.
base = dict(T1=150, T2=20, eps=0.1, beta=0.02, jacreg=0.1, holo=2, hr=0.2,
            t2sel=160, t1max=300, res_est=1e-4, resreg=0.2)


def S(**kw):
    """Return a zero-argument callable running ``LT.ep_step`` with overrides.

    Args:
        **kw: Entries that replace the same-named entries of ``base``.
    """
    return lambda: LT.ep_step(blk, idx, y, **{**base, **kw})


print("=== full + component toggles (ms/step, B=24, C512) ===", flush=True)
full = T(S())
print(f"FULL ep_step: {full}", flush=True)
for n, kw in [
    ("-jacreg", dict(jacreg=0)),
    ("-resreg", dict(resreg=0)),
    ("-t1max(no refine)", dict(t1max=0)),
    ("t2sel=80", dict(t2sel=80)),
    ("t2sel=40", dict(t2sel=40)),
    ("plain nudge holo=0 T2=20", dict(holo=0, t2sel=0)),
]:
    print(f" {n}: {T(S(**kw))}", flush=True)

# Detached embedding of the batch, reused as the input of every relax timing.
xin = blk.embed(idx).detach()
print(f" free relax T1=150 alone: {T(lambda: LT.relax(blk, xin.clone(), xin, 150, 0.1))}", flush=True)
print(f" free relax T1=300 alone: {T(lambda: LT.relax(blk, xin.clone(), xin, 300, 0.1))}", flush=True)

print("=== batch sweep (full) ===", flush=True)
for B in [8, 24, 48]:
    ib, yb = LT.get_batch('train', B, 256)
    # The lambda is executed inside T() during this iteration, so it always
    # sees the ib/yb of the current batch size.
    t = T(lambda: LT.ep_step(blk, ib, yb, **base))
    # t is a string when the measurement failed; only add the per-sample
    # figure when it is a number.
    print(f" B={B}: {t} ms" + (f" ({t/B:.1f}/sample)" if isinstance(t, (int, float)) else ""), flush=True)

print("=== compile free relax ===", flush=True)
try:
    # NOTE(review): presumably LT.relax uses blk._cstep when it is present --
    # confirm against lt_ep_train.
    blk._cstep = torch.compile(lambda z, xn: z + 0.1 * blk.tforce(z, xn))
    try:
        print(f" free relax T1=150 COMPILED: {T(lambda: LT.relax(blk, xin.clone(), xin, 150, 0.1))}", flush=True)
    finally:
        # Always detach the compiled step so the bf16 section below measures
        # the same code path as the earlier sections.
        del blk._cstep
except Exception as e:
    print(f" compile ERR {e}", flush=True)

print("=== bf16 full ===", flush=True)


def bf():
    """Run the full ``ep_step`` configuration under CUDA bf16 autocast."""
    with torch.autocast('cuda', dtype=torch.bfloat16):
        LT.ep_step(blk, idx, y, **base)


print(f" full bf16: {T(bf)}", flush=True)
print("DONE", flush=True)