diff options
Diffstat (limited to 'ep_run/cos_sweep.py')
| -rw-r--r-- | ep_run/cos_sweep.py | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/ep_run/cos_sweep.py b/ep_run/cos_sweep.py new file mode 100644 index 0000000..cd628fe --- /dev/null +++ b/ep_run/cos_sweep.py @@ -0,0 +1,35 @@ +import torch, time, math, collections +import lt_ep_train as LT +torch.manual_seed(0) +blk=LT.EQBlock(512,16,256,256,c=1.0,attn_mode='thick') +blk.qknorm=True; blk.track=True; blk.navg=1; blk.li_avg=0 +ck=torch.load('runs/ep_resreg_warm.pt',map_location='cuda') +with torch.no_grad(): + for p,s in zip(blk.allp, ck['allp']): p.copy_(s.to('cuda')) +print(f"loaded resreg_warm allp (step {ck['step']}, best {ck['best']:.4f})",flush=True) +def cosine(ga,gb,params): + num=da=db=0.0 + for p in params: + a=ga.get(id(p)); b=gb.get(id(p)) + if a is None or b is None: continue + num+=float((a*b).sum()); da+=float((a*a).sum()); db+=float((b*b).sum()) + return num/(math.sqrt(da*db)+1e-20) +base=dict(T1=150,T2=20,eps=0.1,beta=0.02,jacreg=0.1,holo=2,hr=0.2,t2sel=160,t1max=300,res_est=1e-4,resreg=0.2) +settings={'baseline':{}, 't2sel80':dict(t2sel=80), 't2sel40':dict(t2sel=40), + 'corr_every2':dict(corr_every=2), 'corr_every4':dict(corr_every=4), + 'plain_nudge_holo0':dict(holo=0,t2sel=0), 'no_t1max_refine':dict(t1max=0), 'T1=80':dict(T1=80)} +cos=collections.defaultdict(list); NB=3; B=8 +for bi in range(NB): + idx,y=LT.get_batch('train',B,256) + gref=LT.bptt_step(blk,idx,y,300,0.1) + for n,kw in settings.items(): + g,_=LT.ep_step(blk,idx,y,**{**base,**kw}); cos[n].append(cosine(g,gref,blk.block)) + print(f"batch {bi} done",flush=True) +idx,y=LT.get_batch('train',B,256) +def T(fn): + fn(); torch.cuda.synchronize(); t0=time.time(); fn(); torch.cuda.synchronize(); return round((time.time()-t0)*1000) +print(f"\n{'setting':22s}{'cos(block,vs BPTT)':>20s}{'ms/step(B8)':>13s}",flush=True) +for n,kw in settings.items(): + c=sum(cos[n])/len(cos[n]); t=T(lambda: LT.ep_step(blk,idx,y,**{**base,**kw})) + print(f"{n:22s}{c:>20.4f}{t:>13d}",flush=True) +print("DONE",flush=True) |
