"""Compare EP gradient estimates against a BPTT reference across EP settings.

Loads the ``runs/ep_resreg_warm.pt`` checkpoint into an ``LT.EQBlock``.

For each of ``NB`` training batches, it does two things:

* computes a reference gradient with ``LT.bptt_step``;
* computes one ``LT.ep_step`` gradient per entry of ``settings`` (each entry
  overrides keys of ``base``).

It records the cosine similarity of each EP gradient with the reference over
``blk.block``. It then times one ``ep_step`` call per setting on a fresh batch
and prints a table of mean cosine and wall-clock milliseconds.
"""
import collections
import math
import time

import torch

import lt_ep_train as LT


def cosine(ga, gb, params):
    """Return the cosine similarity of two gradient dicts restricted to ``params``.

    ``ga`` and ``gb`` are looked up by ``id(param)`` (presumably the gradient
    dicts produced by ``LT.ep_step`` / ``LT.bptt_step`` -- confirm).
    Parameters missing from either dict are skipped. The remaining per-parameter
    gradients are treated as one concatenated vector. Returns 0.0 when nothing
    overlaps; the 1e-20 term only guards the zero denominator.
    """
    num = da = db = 0.0
    for p in params:
        a = ga.get(id(p))
        b = gb.get(id(p))
        if a is None or b is None:
            continue
        num += float((a * b).sum())
        da += float((a * a).sum())
        db += float((b * b).sum())
    return num / (math.sqrt(da * db) + 1e-20)


def T(fn):
    """Return the wall-clock time in ms of one ``fn()`` call.

    An untimed warm-up call runs first. CUDA is synchronized around the timed
    call so that queued kernels are included in the measurement.
    """
    fn()
    torch.cuda.synchronize()
    t0 = time.perf_counter()  # monotonic, high-resolution interval clock
    fn()
    torch.cuda.synchronize()
    return round((time.perf_counter() - t0) * 1000)


torch.manual_seed(0)

blk = LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
blk.qknorm = True
blk.track = True
blk.navg = 1
blk.li_avg = 0

ck = torch.load('runs/ep_resreg_warm.pt', map_location='cuda')
params = list(blk.allp)
saved = list(ck['allp'])
# zip() would silently stop at the shorter list, and copy_() broadcasts, so a
# mismatched checkpoint could otherwise load partially without any error.
if len(params) != len(saved):
    raise ValueError(
        f"checkpoint 'allp' has {len(saved)} tensors but block has {len(params)}"
    )
with torch.no_grad():
    for i, (p, s) in enumerate(zip(params, saved)):
        if p.shape != s.shape:
            raise ValueError(
                f"checkpoint 'allp'[{i}] shape {tuple(s.shape)} "
                f"!= block param shape {tuple(p.shape)}"
            )
        p.copy_(s.to('cuda'))
print(f"loaded resreg_warm allp (step {ck['step']}, best {ck['best']:.4f})", flush=True)

# Shared ep_step keyword arguments; each entry in `settings` overrides a subset.
base = dict(T1=150, T2=20, eps=0.1, beta=0.02, jacreg=0.1, holo=2, hr=0.2,
            t2sel=160, t1max=300, res_est=1e-4, resreg=0.2)
settings = {
    'baseline': {},
    't2sel80': dict(t2sel=80),
    't2sel40': dict(t2sel=40),
    'corr_every2': dict(corr_every=2),
    'corr_every4': dict(corr_every=4),
    'plain_nudge_holo0': dict(holo=0, t2sel=0),
    'no_t1max_refine': dict(t1max=0),
    'T1=80': dict(T1=80),
}

cos = collections.defaultdict(list)
NB = 3
B = 8
for bi in range(NB):
    idx, y = LT.get_batch('train', B, 256)
    # Reference gradient for this batch; every EP setting is compared against it.
    gref = LT.bptt_step(blk, idx, y, 300, 0.1)
    for n, kw in settings.items():
        g, _ = LT.ep_step(blk, idx, y, **{**base, **kw})
        cos[n].append(cosine(g, gref, blk.block))
    print(f"batch {bi} done", flush=True)

# Fresh batch used only for timing.
idx, y = LT.get_batch('train', B, 256)

print(f"\n{'setting':22s}{'cos(block,vs BPTT)':>20s}{f'ms/step(B{B})':>13s}", flush=True)
for n, kw in settings.items():
    c = sum(cos[n]) / len(cos[n])
    kwargs = {**base, **kw}
    t = T(lambda: LT.ep_step(blk, idx, y, **kwargs))
    print(f"{n:22s}{c:>20.4f}{t:>13d}", flush=True)
print("DONE", flush=True)