1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
|
import torch, time, math, collections
import lt_ep_train as LT
# Deterministic init, then build the block under test and warm-start its
# parameters from the resreg checkpoint.
torch.manual_seed(0)
# NOTE(review): meaning of the positional EQBlock args (512, 16, 256, 256) is
# defined in lt_ep_train -- confirm there before changing them.
blk = LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
blk.qknorm = True
blk.track = True
blk.navg = 1
blk.li_avg = 0
ck = torch.load('runs/ep_resreg_warm.pt', map_location='cuda')
# zip() stops at the shorter sequence, so a checkpoint saved from a different
# architecture would be only partially loaded with no error. Check up front.
dst_params = list(blk.allp)
src_params = list(ck['allp'])
if len(dst_params) != len(src_params):
    raise ValueError(
        f"checkpoint has {len(src_params)} tensors in 'allp' but the block "
        f"has {len(dst_params)}"
    )
with torch.no_grad():
    # copy_ writes in place, so the tensors held by blk are updated rather
    # than replaced.
    for p, s in zip(dst_params, src_params):
        p.copy_(s.to('cuda'))
print(f"loaded resreg_warm allp (step {ck['step']}, best {ck['best']:.4f})", flush=True)
def cosine(ga, gb, params):
    """Return the cosine similarity of two gradient mappings over ``params``.

    ``ga`` and ``gb`` map ``id(param)`` to a gradient tensor. A parameter whose
    entry is missing (or ``None``) in either mapping is skipped, so the dot
    product and both norms are taken over the shared subset only. The 1e-20
    term keeps the division finite; with no shared parameters the result is 0.0.
    """
    dot = norm_a = norm_b = 0.0
    for param in params:
        u = ga.get(id(param))
        v = gb.get(id(param))
        if u is not None and v is not None:
            dot += float((u * v).sum())
            norm_a += float((u * u).sum())
            norm_b += float((v * v).sum())
    denom = math.sqrt(norm_a * norm_b) + 1e-20
    return dot / denom
# Keyword arguments shared by every ep_step call below; each entry of
# `settings` overrides a subset of them.
base = {
    'T1': 150, 'T2': 20, 'eps': 0.1, 'beta': 0.02, 'jacreg': 0.1,
    'holo': 2, 'hr': 0.2, 't2sel': 160, 't1max': 300,
    'res_est': 1e-4, 'resreg': 0.2,
}
# Variants to compare, in report order; keys are the labels printed in the table.
settings = {
    'baseline': {},
    't2sel80': {'t2sel': 80},
    't2sel40': {'t2sel': 40},
    'corr_every2': {'corr_every': 2},
    'corr_every4': {'corr_every': 4},
    'plain_nudge_holo0': {'holo': 0, 't2sel': 0},
    'no_t1max_refine': {'t1max': 0},
    'T1=80': {'T1': 80},
}
cos = collections.defaultdict(list)  # setting label -> list of per-batch cosines
NB = 3  # number of batches the cosines are averaged over
B = 8   # batch size passed to LT.get_batch
# For each of NB training batches: take a BPTT reference gradient, then the EP
# gradient under every setting, and record their cosine over blk.block.
for batch_i in range(NB):
    idx, y = LT.get_batch('train', B, 256)
    gref = LT.bptt_step(blk, idx, y, 300, 0.1)
    for name, overrides in settings.items():
        cfg = {**base, **overrides}
        grads, _ = LT.ep_step(blk, idx, y, **cfg)
        cos[name].append(cosine(grads, gref, blk.block))
    print(f"batch {batch_i} done", flush=True)
# A fresh batch, reused for every timing measurement that follows.
idx, y = LT.get_batch('train', B, 256)
def T(fn, reps=1):
    """Return the mean time of ``fn()`` in whole milliseconds.

    ``fn`` is called once untimed as a warm-up, then ``reps`` times inside the
    timed region. When CUDA is available the device is synchronized after the
    warm-up and after the timed calls, so asynchronously queued kernels are
    attributed to the right region.

    Args:
        fn: zero-argument callable to time.
        reps: number of timed calls to average over (default 1, matching the
            original single-shot behaviour).

    Returns:
        int: mean milliseconds per timed call, rounded to the nearest integer.

    Raises:
        ValueError: if ``reps`` is less than 1.
    """
    if reps < 1:
        raise ValueError(f"reps must be >= 1, got {reps}")

    def _sync():
        # torch.cuda.synchronize() raises on builds/hosts without CUDA.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    fn()
    _sync()
    # perf_counter is monotonic and high-resolution; time.time() is neither
    # guaranteed, which matters for short benchmark intervals.
    t0 = time.perf_counter()
    for _ in range(reps):
        fn()
    _sync()
    return round((time.perf_counter() - t0) * 1000 / reps)
# Summary table: mean EP-vs-BPTT cosine over blk.block and per-step time for
# each setting, timed on the batch currently bound to (idx, y).
# The header is derived from B so it stays correct if the batch size changes.
t_header = f"ms/step(B{B})"
print(f"\n{'setting':22s}{'cos(block,vs BPTT)':>20s}{t_header:>13s}", flush=True)
for n, kw in settings.items():
    vals = cos[n]
    # NaN rather than ZeroDivisionError when no batches were collected (NB == 0).
    c = sum(vals) / len(vals) if vals else float('nan')
    # The lambda is invoked inside T within this iteration, so capturing the
    # loop variable kw by reference is safe here.
    t = T(lambda: LT.ep_step(blk, idx, y, **{**base, **kw}))
    print(f"{n:22s}{c:>20.4f}{t:>13d}", flush=True)
print("DONE", flush=True)
|