summaryrefslogtreecommitdiff
path: root/ep_run/speed_probe.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
committerYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
commitb83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
treeb9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/speed_probe.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/speed_probe.py')
-rw-r--r--ep_run/speed_probe.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/ep_run/speed_probe.py b/ep_run/speed_probe.py
new file mode 100644
index 0000000..0f0e3ba
--- /dev/null
+++ b/ep_run/speed_probe.py
@@ -0,0 +1,63 @@
+"""Speed-package probe for the 50M demo. Run on a free GPU (A6000 preferred).
+(1) torch.compile speedup on the relax loop (exact math, free speed).
+(2) bf16 force evals at r=0.2 with the TRACKING estimator: does the contrast survive low
+ precision when the nudge is large and the common mode cancels? (tf32 died at r=0.02+frozen;
+ this is the missing measurement that decides the 50M/1B cost sheet.)
+Outputs: it/s-equivalents + gradient cosine vs fp32 reference.
+"""
+import time, torch
+import lt_ep_train as M
+from pathlib import Path
+import pickle
+M.DD = Path('/tmp/lt_ep/data/tinystories')
+M.vocab = pickle.load(open(M.DD / 'meta.pkl', 'rb'))['vocab_size']
+from lt_ep_train import EQBlock, get_batch, bptt_step, relax
+from holo_ep import holo_a_track, holo_a_select2
+
+dev = 'cuda'
+torch.manual_seed(0)
+B, T, C, H = 8, 256, 256, 8
+blk = EQBlock(C, H, 256, T, attn_mode='thick')
+ck = torch.load('/tmp/lt_ep/ts_s1_ep_v4b.pt')
+for p, w in zip(blk.allp, ck['allp']):
+ with torch.no_grad():
+ p.copy_(w.to(dev))
+idx, y = get_batch('train', B, T)
+xin = blk.embed(idx).detach()
+
+# --- (1) compile speedup on relax ---
+t0 = time.time(); zs = relax(blk, xin.clone(), xin, 300, 0.1); torch.cuda.synchronize()
+base = time.time() - t0
+cfun = torch.compile(lambda z: z + 0.1 * blk.force(z, xin).detach(), mode='max-autotune-no-cudagraphs')
+z = xin.clone()
+for _ in range(10):
+ z = cfun(z) # warmup/compile
+torch.cuda.synchronize()
+t0 = time.time()
+z = xin.clone()
+for _ in range(300):
+ z = cfun(z)
+torch.cuda.synchronize()
+comp = time.time() - t0
+print(f"[compile] relax300: eager {base:.2f}s -> compiled {comp:.2f}s ({base/comp:.2f}x)", flush=True)
+
+# --- (2) bf16 @ r=0.2 + tracking ---
+aref, _ = holo_a_track(blk, zs, xin, y, 0.2, 120, 0.1)
+def cos(a, b):
+ return (a.flatten() @ b.flatten() / (a.norm() * b.norm() + 1e-12)).item()
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+atf, _ = holo_a_track(blk, zs, xin, y, 0.2, 120, 0.1)
+print(f"[tf32 + track + r=0.2] cos vs fp32 = {cos(atf, aref):.3f}", flush=True)
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cudnn.allow_tf32 = False
+with torch.autocast('cuda', dtype=torch.bfloat16):
+ abf, _ = holo_a_track(blk, zs, xin, y, 0.2, 120, 0.1)
+abf = abf.float()
+print(f"[bf16 + track + r=0.2] cos vs fp32 = {cos(abf, aref):.3f}", flush=True)
+# also the old failure case for reference
+a_old_ref, _ = holo_a_select2(blk, zs, xin, y, 0.02, 120, 0.1)
+torch.backends.cuda.matmul.allow_tf32 = True
+a_old_tf, _ = holo_a_select2(blk, zs, xin, y, 0.02, 120, 0.1)
+torch.backends.cuda.matmul.allow_tf32 = False
+print(f"[tf32 + frozen + r=0.02 (known-dead control)] cos = {cos(a_old_tf, a_old_ref):.3f}", flush=True)