summaryrefslogtreecommitdiff
path: root/ep_run/profile_ep.py
blob: f76c4f2ca652fc22401b0bce86aa7b52585c5e1c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch, time, math
import lt_ep_train as LT
# Fix the global torch RNG seed so every run builds the same model and batches.
torch.random.manual_seed(0)
def mk():
    """Build the EQBlock instance benchmarked by this script.

    Constructs ``LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')`` and
    sets the ``qknorm``, ``track``, ``navg`` and ``li_avg`` attributes to the
    values the profiling runs use.

    Returns:
        The configured EQBlock.
    """
    model = LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
    # Applied in the same order as plain attribute assignments would be.
    overrides = (('qknorm', True), ('track', True), ('navg', 1), ('li_avg', 0))
    for attr_name, attr_value in overrides:
        setattr(model, attr_name, attr_value)
    return model
def T(fn, reps=3, warm=1):
    """Time ``fn`` and return its mean latency per call in milliseconds.

    ``fn`` is called ``warm`` times untimed, then ``reps`` times timed. When
    CUDA is available the device is synchronized after the warm-up and after
    the timed loop, so asynchronously queued GPU work is inside the window.
    Without CUDA the synchronization is skipped and plain host time is measured.

    Args:
        fn: zero-argument callable to benchmark.
        reps: number of timed calls; must be >= 1.
        warm: number of untimed warm-up calls (e.g. to absorb one-off costs
            such as lazy compilation).

    Returns:
        Mean milliseconds per call rounded to an int, or a string
        ``"ERR <ExceptionType>: <first 70 chars of message>"`` if ``fn`` or a
        CUDA call raised. Errors are returned rather than raised so a sweep of
        configurations keeps going past a failing one.

    Raises:
        ValueError: if ``reps < 1`` (a caller bug, not a runtime failure).
    """
    if reps < 1:
        raise ValueError(f"reps must be >= 1, got {reps}")
    use_cuda = torch.cuda.is_available()

    def sync():
        # GPU kernels run asynchronously; wait for them so they are timed.
        if use_cuda:
            torch.cuda.synchronize()

    try:
        if use_cuda:
            torch.cuda.empty_cache()
        for _ in range(warm):
            fn()
        sync()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        t0 = time.perf_counter()
        for _ in range(reps):
            fn()
        sync()
        return round((time.perf_counter() - t0) / reps * 1000)
    except Exception as e:
        return f"ERR {type(e).__name__}: {str(e)[:70]}"
# Shared fixture reused by every benchmark below: one model instance and one
# training batch of 24 samples drawn with LT.get_batch('train', 24, 256).
blk = mk()
idx, y = LT.get_batch('train', 24, 256)

# Default ep_step keyword arguments; each variant overrides a subset of them.
base = {
    'T1': 150, 'T2': 20, 'eps': 0.1, 'beta': 0.02, 'jacreg': 0.1, 'holo': 2,
    'hr': 0.2, 't2sel': 160, 't1max': 300, 'res_est': 1e-4, 'resreg': 0.2,
}


def S(**kw):
    """Return a zero-arg thunk running one LT.ep_step with ``base`` updated by ``kw``."""
    return lambda: LT.ep_step(blk, idx, y, **{**base, **kw})
# Section 1: time the full ep_step with `base`, then variants that each
# override one or two settings, then LT.relax on its own.
print("=== full + component toggles (ms/step, B=24, C512) ===",flush=True)
full=T(S()); print(f"FULL ep_step: {full}",flush=True)
# Each entry is (label to print, kwargs overriding `base` for that run).
for n,kw in [("-jacreg",dict(jacreg=0)),("-resreg",dict(resreg=0)),("-t1max(no refine)",dict(t1max=0)),
             ("t2sel=80",dict(t2sel=80)),("t2sel=40",dict(t2sel=40)),("plain nudge holo=0 T2=20",dict(holo=0,t2sel=0))]:
    print(f"  {n}: {T(S(**kw))}",flush=True)
# Detached embedding of the shared batch; input for the standalone relax timings.
xin=blk.embed(idx).detach()
# NOTE(review): the labels suggest the 150/300 argument is the T1 step count and
# 0.1 matches base eps — confirm against LT.relax's signature.
print(f"  free relax T1=150 alone: {T(lambda: LT.relax(blk,xin.clone(),xin,150,0.1))}",flush=True)
print(f"  free relax T1=300 alone: {T(lambda: LT.relax(blk,xin.clone(),xin,300,0.1))}",flush=True)
# Section 2: full ep_step (`base` settings) at several batch sizes.
print("=== batch sweep (full) ===",flush=True)
for B in [8,24,48]:
    # The thunk closes over ib/yb; T calls it before the next iteration rebinds them.
    ib,yb=LT.get_batch('train',B,256); t=T(lambda: LT.ep_step(blk,ib,yb,**base))
    # T returns an "ERR ..." string on failure, so only compute ms/sample for numbers.
    print(f"  B={B}: {t} ms"+(f"  ({t/B:.1f}/sample)" if isinstance(t,(int,float)) else ""),flush=True)
# Section 3: standalone relax with a torch.compile'd step attached to the block.
print("=== compile free relax ===",flush=True)
try:
    # Temporarily attach a compiled z + 0.1 * tforce(z, xn) step as blk._cstep.
    # NOTE(review): presumably LT.relax uses blk._cstep when it exists — confirm.
    blk._cstep = torch.compile(lambda z, xn: z + 0.1 * blk.tforce(z, xn))
    # torch.compile compiles lazily on first call, so compile failures show up
    # as T's "ERR ..." string rather than in the except clause below.
    print(f"  free relax T1=150 COMPILED: {T(lambda: LT.relax(blk,xin.clone(),xin,150,0.1))}",flush=True)
    # Detach the compiled step so later sections time the uncompiled path.
    del blk._cstep
except Exception as e:
    print(f"  compile ERR {e}",flush=True)
# Section 4: full ep_step under bfloat16 autocast.
print("=== bf16 full ===", flush=True)


def bf():
    """Run one full LT.ep_step (``base`` settings) under CUDA bfloat16 autocast."""
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        LT.ep_step(blk, idx, y, **base)


print(f"  full bf16: {T(bf)}", flush=True)
print("DONE", flush=True)