summaryrefslogtreecommitdiff
path: root/ep_run/test_compile_aselect.py
blob: a2601e409739f01a55bfe20d41cee9d00606a3cc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch, time, math
import lt_ep_train as LT, holo_ep as H
torch.manual_seed(0)

# Build the block under test and switch on the flags this probe relies on.
# NOTE(review): the meaning of the positional constructor args
# (512, 16, 256, 256) is defined in lt_ep_train.EQBlock, not visible here.
blk = LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
blk.qknorm = True
blk.track = True
blk.navg = 1
blk.li_avg = 0

# Load warm-started parameters into blk.allp, matched by position.
ck = torch.load('runs/ep_resreg_warm.pt', map_location='cuda')
params = list(blk.allp)
saved = list(ck['allp'])
# zip() would silently stop at the shorter list and leave some parameters at
# their fresh initialisation, so fail loudly on a count mismatch instead.
if len(saved) != len(params):
    raise ValueError(
        f"checkpoint 'allp' has {len(saved)} tensors, "
        f"block expects {len(params)}"
    )
with torch.no_grad():
    for i, (p, s) in enumerate(zip(params, saved)):
        # copy_ broadcasts, so a wrong-but-broadcastable shape would
        # otherwise be accepted silently.
        if p.shape != s.shape:
            raise ValueError(
                f"allp[{i}]: checkpoint shape {tuple(s.shape)} "
                f"!= parameter shape {tuple(p.shape)}"
            )
        p.copy_(s.to('cuda'))

# Sample a training batch and embed it. The embedding is detached so it is
# used as a fixed input below.
# NOTE(review): get_batch(split, 24, 256) argument meanings (batch size /
# sequence length?) are defined in lt_ep_train; confirm.
idx, y = LT.get_batch('train', 24, 256)
xin = blk.embed(idx).detach()
# Run LT.relax starting from a copy of xin. The resulting zs is passed to
# H.holo_a_track below.
# NOTE(review): 150 and 0.1 presumably are step count / step size; confirm.
zs = LT.relax(blk, xin.clone(), xin, 150, 0.1)
def T(fn, warmup=1):
    """Time one call of ``fn`` after ``warmup`` untimed warm-up calls.

    ``fn`` takes no arguments and must return a 2-tuple. Only the first
    element of the timed call's result is kept. When CUDA is available, the
    device is synchronized before and after the timed call so that
    asynchronously launched kernels are counted.

    Args:
        fn: zero-argument callable returning a 2-tuple.
        warmup: number of untimed calls made first (default 1, the original
            behaviour). Raise it for compile modes that need several runs
            before reaching steady state.

    Returns:
        (ms, a): elapsed time of the timed call, rounded to an int number of
        milliseconds, and the first element of its result.
    """
    for _ in range(warmup):
        fn()
    # Synchronize only when a GPU exists, so the helper also works on CPU.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    sync()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    a, _ = fn()
    sync()
    return round((time.perf_counter() - t0) * 1000), a
# Baseline: time H.holo_a_track with the eager (uncompiled) nc_force.
ms0, a0 = T(lambda: H.holo_a_track(blk, zs, xin, y, 0.02, 80, 0.1))
print(f"uncompiled a-select (t2sel=80): {ms0} ms", flush=True)

# Swap in a torch.compile'd nc_force, re-time the same call, and compare its
# result with the baseline by cosine similarity. The original method is
# restored in `finally` whether or not compilation succeeds.
orig = blk.nc_force
try:
    # NOTE(review): mode='reduce-overhead' uses CUDA graphs. T makes only
    # one untimed warm-up call, so ms1 may still include graph capture or
    # recompilation; confirm it reflects steady state.
    blk.nc_force = torch.compile(blk.nc_force, mode='reduce-overhead')
    ms1, a1 = T(lambda: H.holo_a_track(blk, zs, xin, y, 0.02, 80, 0.1))
except Exception as e:
    # Best-effort probe: report the failure instead of aborting the script.
    print("compile FAIL:", type(e).__name__, str(e)[:120], flush=True)
else:
    # Kept outside the try so that an error here is not misreported as a
    # compile failure. The 1e-20 term avoids division by a zero norm.
    cos = float((a0.flatten() @ a1.flatten()) / (a0.norm() * a1.norm() + 1e-20))
    print(f"compiled nc_force: {ms1} ms  cos(a vs uncompiled)={cos:.4f}", flush=True)
finally:
    blk.nc_force = orig
print("DONE", flush=True)