summaryrefslogtreecommitdiff
path: root/ep_run/speed_probe.py
blob: 0f0e3badca57b77ea8d71ea721a65d991760dff6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Speed-package probe for the 50M demo. Run on a free GPU (A6000 preferred).
(1) torch.compile speedup on the relax loop (exact math, free speed).
(2) bf16 force evals at r=0.2 with the TRACKING estimator: does the contrast survive low
    precision when the nudge is large and the common mode cancels? (tf32 died at r=0.02+frozen;
    this is the missing measurement that decides the 50M/1B cost sheet.)
Outputs: it/s-equivalents + gradient cosine vs fp32 reference.
"""
import time, torch
import lt_ep_train as M
from pathlib import Path
import pickle
# Override lt_ep_train's module-level data directory and vocabulary size before
# any of its helpers run (presumably read by get_batch / EQBlock — confirm in
# lt_ep_train).
M.DD = Path('/tmp/lt_ep/data/tinystories')
# Context manager so the metadata file handle is closed deterministically
# instead of being left to the garbage collector.
with open(M.DD / 'meta.pkl', 'rb') as meta_file:
    M.vocab = pickle.load(meta_file)['vocab_size']
from lt_ep_train import EQBlock, get_batch, bptt_step, relax
from holo_ep import holo_a_track, holo_a_select2

# Probe setup: fixed seed, the block under test, its checkpointed parameters,
# and one training batch embedded once and detached for reuse below.
dev = 'cuda'
torch.manual_seed(0)
B, T, C, H = 8, 256, 256, 8
blk = EQBlock(C, H, 256, T, attn_mode='thick')
ck = torch.load('/tmp/lt_ep/ts_s1_ep_v4b.pt')
# Overwrite the freshly initialised parameters in place with the saved ones;
# no_grad keeps the in-place copies out of autograd.
with torch.no_grad():
    for param, saved in zip(blk.allp, ck['allp']):
        param.copy_(saved.to(dev))
idx, y = get_batch('train', B, T)
xin = blk.embed(idx).detach()

# --- (1) compile speedup on relax ---
# Short untimed eager pass so one-off CUDA start-up costs (library handles,
# lazy kernel loading, allocator growth) are not billed to the eager baseline;
# the compiled path below gets 10 warm-up steps too.
# NOTE(review): assumes relax() has no side effects beyond returning the
# relaxed state, and that its 4th argument is the step count — confirm.
relax(blk, xin.clone(), xin, 10, 0.1)
torch.cuda.synchronize()                          # drain pending work before starting the clock
t0 = time.perf_counter()                          # monotonic, high-resolution timer
zs = relax(blk, xin.clone(), xin, 300, 0.1)       # zs is reused by the precision probe below
torch.cuda.synchronize()
base = time.perf_counter() - t0
cfun = torch.compile(lambda z: z + 0.1 * blk.force(z, xin).detach(), mode='max-autotune-no-cudagraphs')
z = xin.clone()
for _ in range(10):
    z = cfun(z)                                   # warmup/compile
torch.cuda.synchronize()
t0 = time.perf_counter()
z = xin.clone()
for _ in range(300):
    z = cfun(z)
torch.cuda.synchronize()
comp = time.perf_counter() - t0
print(f"[compile] relax300: eager {base:.2f}s -> compiled {comp:.2f}s ({base/comp:.2f}x)", flush=True)

# --- (2) bf16 @ r=0.2 + tracking ---
# Pin full-precision FP32 before taking the reference. cudnn.allow_tf32 defaults
# to True, and the imported project modules may have enabled matmul TF32 at
# import time; without this the "fp32" reference could silently be TF32 and the
# cosines below would be meaningless.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
aref, _ = holo_a_track(blk, zs, xin, y, 0.2, 120, 0.1)
def cos(a, b):
    """Return the cosine similarity of two tensors as a Python float.

    Both tensors are flattened to 1-D first, so any matching shapes work.
    The 1e-12 added to the denominator keeps the result finite (0.0) when
    either tensor is all zeros.
    """
    dot = torch.dot(a.flatten(), b.flatten())
    scale = a.norm() * b.norm() + 1e-12
    return (dot / scale).item()
def _set_tf32(enabled):
    """Enable or disable TF32 for both CUDA matmuls and cuDNN kernels."""
    torch.backends.cuda.matmul.allow_tf32 = enabled
    torch.backends.cudnn.allow_tf32 = enabled


# TF32 + tracking estimator at r=0.2. The flags are process-global, so they are
# restored in `finally` even if the measurement raises (matters when this is
# run from an interactive session).
_set_tf32(True)
try:
    atf, _ = holo_a_track(blk, zs, xin, y, 0.2, 120, 0.1)
finally:
    _set_tf32(False)
print(f"[tf32 + track + r=0.2]  cos vs fp32 = {cos(atf, aref):.3f}", flush=True)
# bf16 autocast + tracking estimator at r=0.2 (TF32 is off here, so only the
# autocast precision differs from the reference).
with torch.autocast('cuda', dtype=torch.bfloat16):
    abf, _ = holo_a_track(blk, zs, xin, y, 0.2, 120, 0.1)
abf = abf.float()                                 # match aref's dtype for the dot product
print(f"[bf16  + track + r=0.2] cos vs fp32 = {cos(abf, aref):.3f}", flush=True)
# also the old failure case for reference
a_old_ref, _ = holo_a_select2(blk, zs, xin, y, 0.02, 120, 0.1)
# Toggle the same pair of flags as the TF32 measurement above, so that "tf32"
# means the same thing in both printed lines.
_set_tf32(True)
try:
    a_old_tf, _ = holo_a_select2(blk, zs, xin, y, 0.02, 120, 0.1)
finally:
    _set_tf32(False)
print(f"[tf32 + frozen + r=0.02 (known-dead control)] cos = {cos(a_old_tf, a_old_ref):.3f}", flush=True)