diff options
Diffstat (limited to 'ep_run/compile_bench.py')
| -rw-r--r-- | ep_run/compile_bench.py | 44 |
1 files changed, 44 insertions, 0 deletions
"""Time a 150-step relaxation of one EQBlock under several execution modes.

The script loads a validation batch and a checkpoint, then measures the
average time of one full relaxation (``T1`` applications of the update
``z <- z + eps * blk.force(z, xin)``) for:

* plain eager execution (the baseline all speedups are reported against),
* ``torch.compile`` with default settings,
* ``torch.compile(mode="reduce-overhead")``,
* eager execution under BF16 autocast (timing only).

Requires a CUDA device, the project-local ``lt_ep_train`` module, the
dataset metadata under ``data/tinystories_bpe`` and the checkpoint
``runs/redx_traj/s3200.pt``.
"""
import pickle
import time
from pathlib import Path

import torch

import lt_ep_train as L
from lt_ep_train import EQBlock

# Set the dataset directory and vocabulary size on the training module
# before asking it for a batch (presumably get_batch reads them -- confirm
# against lt_ep_train).
L.DD = Path('data/tinystories_bpe')
# NOTE: pickle.load / torch.load below unpickle local artifacts; only run
# this script against files you trust.
with open(L.DD / 'meta.pkl', 'rb') as meta_file:
    L.vocab = pickle.load(meta_file)['vocab_size']

dev = 'cuda'
B = 24
T = 256
eps = 0.1   # step size of one relaxation update
T1 = 150    # number of updates per relaxation

torch.manual_seed(0)
idx, y = L.get_batch('val', B, T)  # y is not used by the benchmark
# Move the batch to the device only if it supports `.to`.
idx = idx.to(dev) if hasattr(idx, 'to') else idx

ck = torch.load('runs/redx_traj/s3200.pt', map_location=dev)
blk = EQBlock(512, 16, 256, 256, s=1.0, c=1.0, attn_mode='thick')
blk.qknorm = True

with torch.no_grad():
    # Copy the checkpointed tensors into the block's parameters pairwise.
    for p, w in zip(blk.allp, ck['allp']):
        p.copy_(w.to(dev))
    # Fixed input shared by every relaxation in the benchmark.
    xin = blk.embed(idx).detach()

print("torch", torch.__version__)


def relax(step, n):
    """Apply ``step`` ``n`` times starting from a fresh clone of ``xin``.

    Args:
        step: callable mapping the current state tensor to the next one.
        n: number of times ``step`` is applied.

    Returns:
        The state after ``n`` applications of ``step``.
    """
    z = xin.clone()
    for _ in range(n):
        z = step(z)
    return z


@torch.no_grad()
def bench(step, label, reps=8):
    """Return the mean time of one ``T1``-step relaxation in milliseconds.

    Three untimed warmup relaxations run first, each followed by a device
    synchronize, so one-off costs are excluded from the measurement. The
    timed region is bracketed by ``torch.cuda.synchronize()`` so that all
    queued GPU work is included.

    Args:
        step: single-update callable handed to ``relax``.
        label: name of the configuration; not used by the timing logic.
        reps: number of timed relaxations to average over.

    Returns:
        Average milliseconds per relaxation over ``reps`` runs.
    """
    for _ in range(3):  # warmup
        relax(step, T1)
        torch.cuda.synchronize()
    torch.cuda.synchronize()
    # perf_counter is monotonic, so the interval cannot be skewed by
    # system clock adjustments the way time.time() can.
    t0 = time.perf_counter()
    for _ in range(reps):
        relax(step, T1)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / reps * 1000


def fstep(z):
    """Perform one update ``z + eps * blk.force(z, xin)`` without autograd."""
    with torch.no_grad():
        return z + eps * blk.force(z, xin).detach()


te = bench(fstep, "eager")
print(f"FP32 eager 150-relax: {te:.1f} ms/relax")

# Each alternative mode is best-effort: a failure is reported (truncated to
# 200 characters) and the script moves on to the next configuration.
try:
    cstep = torch.compile(fstep)
    tc = bench(cstep, "compiled")
    print(f"FP32 compile : {tc:.1f} ms/relax -> {te/tc:.2f}x")
except Exception as e:
    print("compile default ERR:", str(e)[:200])

try:
    cstep2 = torch.compile(fstep, mode="reduce-overhead")
    tc2 = bench(cstep2, "reduce-overhead")
    print(f"FP32 reduce-overhead : {tc2:.1f} ms/relax -> {te/tc2:.2f}x (CUDA graphs)")
except Exception as e:
    print("reduce-overhead ERR:", str(e)[:200])


# bf16 potential (timing only; precision caveat for actual use)
def fstep_bf(z):
    """Same update as ``fstep`` but executed under CUDA BF16 autocast."""
    with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
        return z + eps * blk.force(z, xin).detach()


try:
    tb = bench(fstep_bf, "bf16")
    print(f"BF16 eager (autocast): {tb:.1f} ms/relax -> {te/tb:.2f}x (precision caveat)")
except Exception as e:
    print("bf16 ERR:", str(e)[:200])

print("=== DONE ===")
