"""Time the EQBlock relaxation loop under several execution modes.

Loads a checkpoint into an ``EQBlock``, then measures the wall time of
``T1`` successive relaxation steps for FP32 eager, ``torch.compile``
(default and ``reduce-overhead``) and BF16 autocast, printing ms per
relaxation and the speed-up relative to FP32 eager.
"""
import pickle
import time
from pathlib import Path

import torch

import lt_ep_train as L
from lt_ep_train import EQBlock

# Configure the training module's globals before asking it for a batch.
L.DD = Path('data/tinystories_bpe')
# NOTE(review): pickle can execute arbitrary code on load; meta.pkl must
# come from a trusted source.
with open(L.DD / 'meta.pkl', 'rb') as meta_file:
    L.vocab = pickle.load(meta_file)['vocab_size']

dev = 'cuda'
B = 24      # batch size passed to get_batch
T = 256     # sequence length passed to get_batch
eps = 0.1   # step size of one relaxation update
T1 = 150    # number of relaxation steps per timed run

torch.manual_seed(0)
idx, y = L.get_batch('val', B, T)  # y is not used by the benchmark
idx = idx.to(dev) if hasattr(idx, 'to') else idx

ck = torch.load('runs/redx_traj/s3200.pt', map_location=dev)
blk = EQBlock(512, 16, 256, 256, s=1.0, c=1.0, attn_mode='thick')
blk.qknorm = True
with torch.no_grad():
    # Overwrite the freshly initialised parameters with the checkpoint's.
    for p, w in zip(blk.allp, ck['allp']):
        p.copy_(w.to(dev))
    xin = blk.embed(idx).detach()
print("torch", torch.__version__)


def relax(step, n):
    """Apply ``step`` ``n`` times starting from a copy of ``xin``.

    Args:
        step: callable mapping the current state tensor to the next one.
        n: number of times to apply ``step``.

    Returns:
        The state after ``n`` applications of ``step``.
    """
    z = xin.clone()
    for _ in range(n):
        z = step(z)
    return z


@torch.no_grad()
def bench(step, label, reps=8):
    """Return the mean time in milliseconds of one ``T1``-step relaxation.

    Args:
        step: single-step update function handed to ``relax``.
        label: name of the variant being timed (currently unused; kept so
            call sites stay self-describing).
        reps: number of timed relaxations to average over.

    Returns:
        Average milliseconds per relaxation over ``reps`` runs.
    """
    # Warm-up runs so one-time costs (e.g. compilation triggered on the
    # first call of a torch.compile'd step) stay out of the timed region.
    for _ in range(3):
        relax(step, T1)
    # CUDA kernels launch asynchronously: drain the queue before starting
    # the clock and again before stopping it.
    torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    for _ in range(reps):
        relax(step, T1)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / reps * 1000


def fstep(z):
    """Return ``z + eps * blk.force(z, xin)`` computed without autograd."""
    with torch.no_grad():
        return z + eps * blk.force(z, xin).detach()


te = bench(fstep, "eager")
print(f"FP32 eager 150-relax: {te:.1f} ms/relax")

# Each variant below is best-effort: a failure is reported (truncated to
# 200 characters) and the script moves on to the next variant.
try:
    cstep = torch.compile(fstep)
    tc = bench(cstep, "compiled")
except Exception as e:
    print("compile default ERR:", str(e)[:200])
else:
    print(f"FP32 compile : {tc:.1f} ms/relax -> {te/tc:.2f}x")

try:
    cstep2 = torch.compile(fstep, mode="reduce-overhead")
    tc2 = bench(cstep2, "reduce-overhead")
except Exception as e:
    print("reduce-overhead ERR:", str(e)[:200])
else:
    print(f"FP32 reduce-overhead : {tc2:.1f} ms/relax -> {te/tc2:.2f}x (CUDA graphs)")


# bf16 potential (timing only; precision caveat for actual use)
def fstep_bf(z):
    """Same update as ``fstep`` but run under CUDA bfloat16 autocast."""
    with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
        return z + eps * blk.force(z, xin).detach()


try:
    tb = bench(fstep_bf, "bf16")
except Exception as e:
    print("bf16 ERR:", str(e)[:200])
else:
    print(f"BF16 eager (autocast): {tb:.1f} ms/relax -> {te/tb:.2f}x (precision caveat)")

print("=== DONE ===")