"""Time ``H.holo_a_track`` before and after wrapping ``blk.nc_force`` in ``torch.compile``.

The script builds an ``LT.EQBlock``, copies parameters from a checkpoint into
it, relaxes one training batch, and then times ``H.holo_a_track`` twice: once
with the block as-is and once with ``blk.nc_force`` replaced by its
``torch.compile(..., mode='reduce-overhead')`` wrapper.  The first value
returned by the two runs is compared with a cosine similarity so a numerical
divergence introduced by compilation is visible in the output.
"""
import math
import time

import torch

import lt_ep_train as LT
import holo_ep as H

torch.manual_seed(0)

# Block configuration; the attribute flags are set after construction exactly
# as the original script did (their meaning is defined in lt_ep_train).
blk = LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
blk.qknorm = True
blk.track = True
blk.navg = 1
blk.li_avg = 0

# NOTE(review): torch.load unpickles the file; only point this at checkpoints
# from a trusted source (consider weights_only=True if the checkpoint layout
# allows it -- not changed here to avoid breaking loading).
ck = torch.load('runs/ep_resreg_warm.pt', map_location='cuda')
with torch.no_grad():
    # In-place copy so the block keeps its own parameter tensors.
    for p, s in zip(blk.allp, ck['allp']):
        p.copy_(s.to('cuda'))

idx, y = LT.get_batch('train', 24, 256)
xin = blk.embed(idx).detach()
zs = LT.relax(blk, xin.clone(), xin, 150, 0.1)


def T(fn):
    """Time one call of *fn* on the GPU, after one untimed warm-up call.

    Args:
        fn: Zero-argument callable returning a 2-tuple; only the first
            element is kept.

    Returns:
        ``(elapsed_ms, first_result)`` where ``elapsed_ms`` is the rounded
        duration of the second call in milliseconds.
    """
    fn()  # warm-up call; its result is discarded
    torch.cuda.synchronize()  # make sure warm-up work is finished before timing
    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    a, _ = fn()
    torch.cuda.synchronize()  # wait for queued kernels so they are included
    return round((time.perf_counter() - t0) * 1000), a


def _run_a_track():
    """Run the benchmarked call; reads ``blk`` at call time, so it picks up
    whatever ``blk.nc_force`` is currently installed."""
    # presumably the 80 is the "t2sel" value named in the log line -- confirm
    return H.holo_a_track(blk, zs, xin, y, 0.02, 80, 0.1)


ms0, a0 = T(_run_a_track)
print(f"uncompiled a-select (t2sel=80): {ms0} ms", flush=True)

orig = blk.nc_force
try:
    blk.nc_force = torch.compile(blk.nc_force, mode='reduce-overhead')
    ms1, a1 = T(_run_a_track)
    # Cosine similarity between the uncompiled and compiled results; the tiny
    # epsilon guards against division by zero when a norm is zero.
    cos = float((a0.flatten() @ a1.flatten()) / (a0.norm() * a1.norm() + 1e-20))
    print(f"compiled nc_force: {ms1} ms cos(a vs uncompiled)={cos:.4f}", flush=True)
except Exception as e:
    # Best-effort benchmark: report a compile/runtime failure instead of crashing.
    print("compile FAIL:", type(e).__name__, str(e)[:120], flush=True)
finally:
    # Always put the original callable back, whether compilation worked or not.
    blk.nc_force = orig
print("DONE", flush=True)