# Source: ep_run/compile_bench.py (blob dc730dd86d3f93af5f7c994ef1711e5c88860c75)
import torch, pickle, time
from pathlib import Path
import lt_ep_train as L
from lt_ep_train import EQBlock
# --- Benchmark setup --------------------------------------------------------
# Point the training module at the dataset directory and read the vocabulary
# size from its metadata file.
L.DD = Path('data/tinystories_bpe')
# NOTE(review): pickle can execute arbitrary code while loading; only run this
# against a trusted meta.pkl.
# The file is opened in a `with` block so the handle is closed deterministically
# instead of being left to the garbage collector.
with open(L.DD / 'meta.pkl', 'rb') as meta_file:
    L.vocab = pickle.load(meta_file)['vocab_size']

dev = 'cuda'
B = 24      # first size argument to L.get_batch (presumably batch size -- confirm)
T = 256     # second size argument to L.get_batch (presumably sequence length -- confirm)
eps = 0.1   # multiplier applied to blk.force(...) in each update step
T1 = 150    # number of update steps per timed relaxation

# Seed before drawing the batch so the sampled batch is reproducible
# (assuming get_batch draws from torch's global RNG -- confirm).
torch.manual_seed(0)
idx, y = L.get_batch('val', B, T)
idx = idx.to(dev) if hasattr(idx, 'to') else idx

# Load the checkpoint and copy its stored parameters into a freshly built block.
ck = torch.load('runs/redx_traj/s3200.pt', map_location=dev)
blk = EQBlock(512, 16, 256, 256, s=1.0, c=1.0, attn_mode='thick')
blk.qknorm = True
with torch.no_grad():
    for p, w in zip(blk.allp, ck['allp']):
        p.copy_(w.to(dev))
    # Embedded input; every relaxation starts from a clone of this tensor.
    xin = blk.embed(idx).detach()
print("torch", torch.__version__)
def relax(step, n):
    """Apply ``step`` ``n`` times in a row and return the final state.

    The iteration starts from a clone of the module-level ``xin``; each call
    to ``step`` receives the current state and its return value becomes the
    next state.
    """
    state = xin.clone()
    for _iteration in range(n):
        state = step(state)
    return state
@torch.no_grad()
def bench(step, label, reps=8):
    """Time ``relax(step, T1)`` and return the mean duration in milliseconds.

    Args:
        step: Callable handed to ``relax``; maps the current state to the next.
        label: Name of the variant being timed. Not used by the measurement
            itself; kept so existing call sites stay valid.
        reps: Number of timed relaxations to average over (must be >= 1).

    Returns:
        Mean wall-clock time of one relaxation, in milliseconds.

    Raises:
        ValueError: If ``reps`` is smaller than 1.
    """
    # Fail fast: reps == 0 used to end in ZeroDivisionError only after the
    # warmup had run, and a negative reps produced a meaningless negative time.
    if reps < 1:
        raise ValueError(f"reps must be >= 1, got {reps!r}")

    for _ in range(3):  # warmup
        relax(step, T1)
        torch.cuda.synchronize()

    # Drain pending GPU work so it is not attributed to the timed region.
    torch.cuda.synchronize()
    # perf_counter is monotonic and high resolution; time.time() can jump when
    # the system clock is adjusted and may be coarse on some platforms.
    t0 = time.perf_counter()
    for _ in range(reps):
        relax(step, T1)
    # Wait for all queued kernels to finish before reading the clock.
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / reps * 1000
def fstep(z):
    """Return ``z`` advanced by one update: ``z + eps * blk.force(z, xin)``.

    The force is evaluated with gradient tracking disabled and detached
    before it is scaled and added.
    """
    with torch.no_grad():
        force = blk.force(z, xin).detach()
        return z + eps * force
# Baseline: eager execution. The speedups printed below are relative to `te`.
te = bench(fstep, "eager")
print(f"FP32 eager  150-relax: {te:.1f} ms/relax")

# torch.compile with default settings. Any failure is reported (truncated to
# 200 characters) and the script moves on to the next variant.
try:
    cstep = torch.compile(fstep)
    tc = bench(cstep, "compiled")
    print(f"FP32 compile         : {tc:.1f} ms/relax  -> {te/tc:.2f}x")
except Exception as err:
    print("compile default ERR:", str(err)[:200])

# torch.compile in "reduce-overhead" mode, handled the same way.
try:
    cstep2 = torch.compile(fstep, mode="reduce-overhead")
    tc2 = bench(cstep2, "reduce-overhead")
    print(f"FP32 reduce-overhead : {tc2:.1f} ms/relax  -> {te/tc2:.2f}x  (CUDA graphs)")
except Exception as err:
    print("reduce-overhead ERR:", str(err)[:200])
# bf16 potential (timing only; precision caveat for actual use)
def fstep_bf(z):
    """Return ``z + eps * blk.force(z, xin)`` computed under bfloat16 autocast.

    Gradient tracking is disabled and CUDA autocast with ``torch.bfloat16``
    is active for the whole update, including the final scale-and-add.
    """
    with torch.no_grad():
        with torch.autocast('cuda', dtype=torch.bfloat16):
            force = blk.force(z, xin).detach()
            return z + eps * force
# Time the bf16 autocast variant against the eager baseline `te`; a failure
# is reported (truncated to 200 characters) rather than aborting the script.
try:
    tb = bench(fstep_bf, "bf16")
    print(f"BF16 eager (autocast): {tb:.1f} ms/relax  -> {te/tb:.2f}x  (precision caveat)")
except Exception as err:
    print("bf16 ERR:", str(err)[:200])

print("=== DONE ===")