summaryrefslogtreecommitdiff
path: root/ep_run/compile_bench.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/compile_bench.py')
-rw-r--r--ep_run/compile_bench.py44
1 files changed, 44 insertions, 0 deletions
diff --git a/ep_run/compile_bench.py b/ep_run/compile_bench.py
new file mode 100644
index 0000000..dc730dd
--- /dev/null
+++ b/ep_run/compile_bench.py
@@ -0,0 +1,44 @@
+import torch, pickle, time
+from pathlib import Path
+import lt_ep_train as L
+from lt_ep_train import EQBlock
+L.DD=Path('data/tinystories_bpe'); L.vocab=pickle.load(open(L.DD/'meta.pkl','rb'))['vocab_size']
+dev='cuda'; B=24; T=256; eps=0.1; T1=150
+torch.manual_seed(0); idx,y=L.get_batch('val',B,T); idx=idx.to(dev) if hasattr(idx,'to') else idx
+ck=torch.load('runs/redx_traj/s3200.pt',map_location=dev)
+blk=EQBlock(512,16,256,256,s=1.0,c=1.0,attn_mode='thick'); blk.qknorm=True
+with torch.no_grad():
+ for p,w in zip(blk.allp,ck['allp']): p.copy_(w.to(dev))
+ xin=blk.embed(idx).detach()
+print("torch", torch.__version__)
+def relax(step,n):
+ z=xin.clone()
+ for _ in range(n): z=step(z)
+ return z
+@torch.no_grad()
+def bench(step,label,reps=8):
+ for _ in range(3): relax(step,T1); torch.cuda.synchronize() # warmup
+ torch.cuda.synchronize(); t0=time.time()
+ for _ in range(reps): relax(step,T1)
+ torch.cuda.synchronize(); return (time.time()-t0)/reps*1000
+def fstep(z):
+ with torch.no_grad(): return z+eps*blk.force(z,xin).detach()
+te=bench(fstep,"eager")  # FP32 eager baseline; every speedup below is reported as te/<variant time>
+print(f"FP32 eager 150-relax: {te:.1f} ms/relax")
+try:  # default torch.compile of the same step; failures are reported, not fatal
+ cstep=torch.compile(fstep)
+ tc=bench(cstep,"compiled")
+ print(f"FP32 compile : {tc:.1f} ms/relax -> {te/tc:.2f}x")
+except Exception as e: print("compile default ERR:", str(e)[:200])  # best-effort: print the first 200 chars of the error and carry on
+try:  # same step compiled with mode="reduce-overhead"
+ cstep2=torch.compile(fstep, mode="reduce-overhead")
+ tc2=bench(cstep2,"reduce-overhead")
+ print(f"FP32 reduce-overhead : {tc2:.1f} ms/relax -> {te/tc2:.2f}x (CUDA graphs)")
+except Exception as e: print("reduce-overhead ERR:", str(e)[:200])  # best-effort: print the first 200 chars of the error and carry on
+# bf16 potential (timing only; precision caveat for actual use)
+def fstep_bf(z):
+ with torch.no_grad(), torch.autocast('cuda',dtype=torch.bfloat16): return z+eps*blk.force(z,xin).detach()
+try:  # time the autocast step without torch.compile; the ratio printed is relative to the FP32 eager baseline te
+ tb=bench(fstep_bf,"bf16"); print(f"BF16 eager (autocast): {tb:.1f} ms/relax -> {te/tb:.2f}x (precision caveat)")
+except Exception as e: print("bf16 ERR:", str(e)[:200])  # best-effort: print the first 200 chars of the error and carry on
+print("=== DONE ===")