summaryrefslogtreecommitdiff
path: root/ep_run/cos_sweep.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/cos_sweep.py')
-rw-r--r--ep_run/cos_sweep.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/ep_run/cos_sweep.py b/ep_run/cos_sweep.py
new file mode 100644
index 0000000..cd628fe
--- /dev/null
+++ b/ep_run/cos_sweep.py
@@ -0,0 +1,35 @@
+import torch, time, math, collections
+import lt_ep_train as LT
+torch.manual_seed(0)
+blk=LT.EQBlock(512,16,256,256,c=1.0,attn_mode='thick')
+blk.qknorm=True; blk.track=True; blk.navg=1; blk.li_avg=0
+ck=torch.load('runs/ep_resreg_warm.pt',map_location='cuda')
+with torch.no_grad():
+ for p,s in zip(blk.allp, ck['allp']): p.copy_(s.to('cuda'))
+print(f"loaded resreg_warm allp (step {ck['step']}, best {ck['best']:.4f})",flush=True)
+def cosine(ga,gb,params):
+ num=da=db=0.0
+ for p in params:
+ a=ga.get(id(p)); b=gb.get(id(p))
+ if a is None or b is None: continue
+ num+=float((a*b).sum()); da+=float((a*a).sum()); db+=float((b*b).sum())
+ return num/(math.sqrt(da*db)+1e-20)
+base=dict(T1=150,T2=20,eps=0.1,beta=0.02,jacreg=0.1,holo=2,hr=0.2,t2sel=160,t1max=300,res_est=1e-4,resreg=0.2)
+settings={'baseline':{}, 't2sel80':dict(t2sel=80), 't2sel40':dict(t2sel=40),
+ 'corr_every2':dict(corr_every=2), 'corr_every4':dict(corr_every=4),
+ 'plain_nudge_holo0':dict(holo=0,t2sel=0), 'no_t1max_refine':dict(t1max=0), 'T1=80':dict(T1=80)}
+cos=collections.defaultdict(list); NB=3; B=8
+for bi in range(NB):
+ idx,y=LT.get_batch('train',B,256)
+ gref=LT.bptt_step(blk,idx,y,300,0.1)
+ for n,kw in settings.items():
+ g,_=LT.ep_step(blk,idx,y,**{**base,**kw}); cos[n].append(cosine(g,gref,blk.block))
+ print(f"batch {bi} done",flush=True)
+idx,y=LT.get_batch('train',B,256)
+def T(fn):
+ fn(); torch.cuda.synchronize(); t0=time.time(); fn(); torch.cuda.synchronize(); return round((time.time()-t0)*1000)
+print(f"\n{'setting':22s}{'cos(block,vs BPTT)':>20s}{'ms/step(B8)':>13s}",flush=True)
+for n,kw in settings.items():
+ c=sum(cos[n])/len(cos[n]); t=T(lambda: LT.ep_step(blk,idx,y,**{**base,**kw}))
+ print(f"{n:22s}{c:>20.4f}{t:>13d}",flush=True)
+print("DONE",flush=True)