diff options
Diffstat (limited to 'ep_run/speed_probe.py')
| -rw-r--r-- | ep_run/speed_probe.py | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/ep_run/speed_probe.py b/ep_run/speed_probe.py new file mode 100644 index 0000000..0f0e3ba --- /dev/null +++ b/ep_run/speed_probe.py @@ -0,0 +1,63 @@ +"""Speed-package probe for the 50M demo. Run on a free GPU (A6000 preferred). +(1) torch.compile speedup on the relax loop (exact math, free speed). +(2) bf16 force evals at r=0.2 with the TRACKING estimator: does the contrast survive low + precision when the nudge is large and the common mode cancels? (tf32 died at r=0.02+frozen; + this is the missing measurement that decides the 50M/1B cost sheet.) +Outputs: it/s-equivalents + gradient cosine vs fp32 reference. +""" +import time, torch +import lt_ep_train as M +from pathlib import Path +import pickle +M.DD = Path('/tmp/lt_ep/data/tinystories') +M.vocab = pickle.load(open(M.DD / 'meta.pkl', 'rb'))['vocab_size'] +from lt_ep_train import EQBlock, get_batch, bptt_step, relax +from holo_ep import holo_a_track, holo_a_select2 + +dev = 'cuda' +torch.manual_seed(0) +B, T, C, H = 8, 256, 256, 8 +blk = EQBlock(C, H, 256, T, attn_mode='thick') +ck = torch.load('/tmp/lt_ep/ts_s1_ep_v4b.pt') +for p, w in zip(blk.allp, ck['allp']): + with torch.no_grad(): + p.copy_(w.to(dev)) +idx, y = get_batch('train', B, T) +xin = blk.embed(idx).detach() + +# --- (1) compile speedup on relax --- +t0 = time.time(); zs = relax(blk, xin.clone(), xin, 300, 0.1); torch.cuda.synchronize() +base = time.time() - t0 +cfun = torch.compile(lambda z: z + 0.1 * blk.force(z, xin).detach(), mode='max-autotune-no-cudagraphs') +z = xin.clone() +for _ in range(10): + z = cfun(z) # warmup/compile +torch.cuda.synchronize() +t0 = time.time() +z = xin.clone() +for _ in range(300): + z = cfun(z) +torch.cuda.synchronize() +comp = time.time() - t0 +print(f"[compile] relax300: eager {base:.2f}s -> compiled {comp:.2f}s ({base/comp:.2f}x)", flush=True) + +# --- (2) bf16 @ r=0.2 + tracking --- +aref, _ = holo_a_track(blk, zs, xin, y, 0.2, 120, 0.1) +def cos(a, b): + return (a.flatten() @ b.flatten() / (a.norm() * b.norm() + 1e-12)).item() +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True +atf, _ = holo_a_track(blk, zs, xin, y, 0.2, 120, 0.1) +print(f"[tf32 + track + r=0.2] cos vs fp32 = {cos(atf, aref):.3f}", flush=True) +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False +with torch.autocast('cuda', dtype=torch.bfloat16): + abf, _ = holo_a_track(blk, zs, xin, y, 0.2, 120, 0.1) +abf = abf.float() +print(f"[bf16 + track + r=0.2] cos vs fp32 = {cos(abf, aref):.3f}", flush=True) +# also the old failure case for reference +a_old_ref, _ = holo_a_select2(blk, zs, xin, y, 0.02, 120, 0.1) +torch.backends.cuda.matmul.allow_tf32 = True +a_old_tf, _ = holo_a_select2(blk, zs, xin, y, 0.02, 120, 0.1) +torch.backends.cuda.matmul.allow_tf32 = False +print(f"[tf32 + frozen + r=0.02 (known-dead control)] cos = {cos(a_old_tf, a_old_ref):.3f}", flush=True) |
