1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
|
import torch, time, math, traceback
import lt_ep_train as LT
# Build the block under test, load warm-started weights from a checkpoint, and
# compute the fp32 reference gradients (g32) that later runs are compared to.
torch.manual_seed(0)
blk = LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
blk.qknorm = True
blk.track = True
blk.navg = 1
blk.li_avg = 0

ck = torch.load('runs/ep_resreg_warm.pt', map_location='cuda')
_params = list(blk.allp)
_saved = list(ck['allp'])
# zip() stops silently at the shorter list and Tensor.copy_ broadcasts its
# source, so a stale or mismatched checkpoint could otherwise load partially
# (or be broadcast into the wrong shape) without any error.
if len(_params) != len(_saved):
    raise ValueError(
        f"checkpoint has {len(_saved)} tensors but blk.allp has {len(_params)}"
    )
with torch.no_grad():
    for _i, (p, s) in enumerate(zip(_params, _saved)):
        if p.shape != s.shape:
            raise ValueError(
                f"allp[{_i}]: checkpoint shape {tuple(s.shape)} != "
                f"param shape {tuple(p.shape)}"
            )
        p.copy_(s.to('cuda'))

idx, y = LT.get_batch('train', 8, 256)
# Shared keyword arguments for every LT.ep_step call in this script.
base = dict(T1=150, T2=20, eps=0.1, beta=0.02, jacreg=0.1, holo=2, hr=0.2,
            t2sel=80, t1max=150, res_est=1e-4, resreg=0.2)
# Reference gradients from a plain (non-autocast) step; cos() compares to these.
g32, _ = LT.ep_step(blk, idx, y, **base)
def cos(ga, ref=None, params=None):
    """Return the cosine similarity between two gradient dicts.

    Both dicts map ``id(param)`` to a gradient tensor, the same keying that the
    module-level ``g32`` (returned by ``LT.ep_step``) uses. Parameters missing
    from either dict are skipped. Matched gradients are upcast with
    ``.float()`` before the reductions, so bf16 inputs are reduced in float32.

    Args:
        ga: gradient dict to compare.
        ref: reference gradient dict; defaults to the module-level ``g32``.
        params: iterable of parameters to compare over; defaults to
            ``blk.block``.

    Returns:
        The cosine similarity as a Python float. Returns 0.0 if either side
        has zero norm, and NaN if no parameter was present in both dicts, so
        an empty comparison is not mistaken for orthogonal gradients.
    """
    if ref is None:
        ref = g32
    if params is None:
        params = blk.block
    dot = sq_a = sq_b = 0.0
    matched = 0
    for p in params:
        a = ga.get(id(p))
        b = ref.get(id(p))
        if a is None or b is None:
            continue
        a = a.float()
        b = b.float()
        dot += float((a * b).sum())
        sq_a += float((a * a).sum())
        sq_b += float((b * b).sum())
        matched += 1
    if not matched:
        return float('nan')
    # sqrt(a)*sqrt(b) rather than sqrt(a*b) avoids overflowing the product.
    return dot / (math.sqrt(sq_a) * math.sqrt(sq_b) + 1e-20)
def T(fn, reps=2):
    """Time ``fn`` and return the mean milliseconds per call, rounded to an int.

    ``fn`` is called once as a warm-up, then ``reps`` more times inside the
    timed region. CUDA work is launched asynchronously, so the device is
    synchronized before each clock read when CUDA is available.

    Args:
        fn: zero-argument callable to benchmark.
        reps: number of timed calls; must be >= 1.

    Returns:
        Mean wall time per timed call in milliseconds (``int``).

    Raises:
        ValueError: if ``reps`` < 1.
    """
    if reps < 1:
        raise ValueError(f"reps must be >= 1, got {reps}")
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    fn()  # warm-up: excludes one-time compilation/allocation costs
    sync()
    # perf_counter is monotonic and high-resolution; time.time() is neither.
    t0 = time.perf_counter()
    for _ in range(reps):
        fn()
    sync()
    return round((time.perf_counter() - t0) / reps * 1000)
def _step_bf16():
    """Run one LT.ep_step with all ops inside a bf16 CUDA autocast region."""
    with torch.autocast('cuda', dtype=torch.bfloat16):
        return LT.ep_step(blk, idx, y, **base)


print("fp32 step ms:", T(lambda: LT.ep_step(blk, idx, y, **base)), flush=True)
print("=== A: blanket autocast (locate the break) ===", flush=True)
try:
    gA, _ = _step_bf16()
    # T must receive the autocast-wrapped step: timing a bare ep_step call
    # here would run outside the autocast region and re-measure fp32.
    print("A OK cos", round(cos(gA), 4), "ms", T(_step_bf16), flush=True)
except Exception:
    # Diagnostic boundary: show where the bf16 run breaks, then finish.
    traceback.print_exc()
print("DONE", flush=True)
|