diff options
Diffstat (limited to 'ep_run/test_compile_aselect.py')
| -rw-r--r-- | ep_run/test_compile_aselect.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/ep_run/test_compile_aselect.py b/ep_run/test_compile_aselect.py new file mode 100644 index 0000000..a2601e4 --- /dev/null +++ b/ep_run/test_compile_aselect.py @@ -0,0 +1,23 @@ +import torch, time, math +import lt_ep_train as LT, holo_ep as H +torch.manual_seed(0) +blk=LT.EQBlock(512,16,256,256,c=1.0,attn_mode='thick'); blk.qknorm=True; blk.track=True; blk.navg=1; blk.li_avg=0 +ck=torch.load('runs/ep_resreg_warm.pt',map_location='cuda') +with torch.no_grad(): + for p,s in zip(blk.allp,ck['allp']): p.copy_(s.to('cuda')) +idx,y=LT.get_batch('train',24,256); xin=blk.embed(idx).detach() +zs=LT.relax(blk,xin.clone(),xin,150,0.1) +def T(fn): + fn(); torch.cuda.synchronize(); t0=time.time(); a,_=fn(); torch.cuda.synchronize(); return round((time.time()-t0)*1000),a +ms0,a0=T(lambda: H.holo_a_track(blk,zs,xin,y,0.02,80,0.1)) +print(f"uncompiled a-select (t2sel=80): {ms0} ms",flush=True) +orig=blk.nc_force +try: + blk.nc_force=torch.compile(blk.nc_force, mode='reduce-overhead') + ms1,a1=T(lambda: H.holo_a_track(blk,zs,xin,y,0.02,80,0.1)) + cos=float((a0.flatten()@a1.flatten())/(a0.norm()*a1.norm()+1e-20)) + print(f"compiled nc_force: {ms1} ms cos(a vs uncompiled)={cos:.4f}",flush=True) +except Exception as e: + print("compile FAIL:", type(e).__name__, str(e)[:120],flush=True) +blk.nc_force=orig +print("DONE",flush=True) |
