"""Speed-package probe for the 50M demo.

Run on a free GPU (A6000 preferred).

(1) torch.compile speedup on the relax loop (exact math, free speed).
(2) bf16 force evals at r=0.2 with the TRACKING estimator: does the contrast
    survive low precision when the nudge is large and the common mode
    cancels? (tf32 died at r=0.02+frozen; this is the missing measurement
    that decides the 50M/1B cost sheet.)

Outputs: it/s-equivalents + gradient cosine vs fp32 reference.
"""
import pickle
import time
from pathlib import Path

import torch

import lt_ep_train as M

# Point the training module at the dataset before anything reads these
# module-level globals.
M.DD = Path('/tmp/lt_ep/data/tinystories')
# NOTE(review): pickle executes arbitrary code on load; meta.pkl is assumed to
# be a trusted, locally generated file.
with open(M.DD / 'meta.pkl', 'rb') as meta_file:
    M.vocab = pickle.load(meta_file)['vocab_size']

# Kept after the M.DD / M.vocab patching on purpose, matching the original
# order: holo_ep may read those globals at import time (not visible from here).
from lt_ep_train import EQBlock, get_batch, bptt_step, relax
from holo_ep import holo_a_track, holo_a_select2


def _set_tf32(matmul, cudnn):
    """Set the two global TF32 switches (cuBLAS matmul and cuDNN)."""
    torch.backends.cuda.matmul.allow_tf32 = matmul
    torch.backends.cudnn.allow_tf32 = cudnn


def cos(a, b):
    """Return the cosine similarity of two tensors, flattened, as a float.

    The 1e-12 in the denominator keeps the result finite when either tensor
    has zero norm.
    """
    return (a.flatten() @ b.flatten() / (a.norm() * b.norm() + 1e-12)).item()


dev = 'cuda'
torch.manual_seed(0)
B, T, C, H = 8, 256, 256, 8
blk = EQBlock(C, H, 256, T, attn_mode='thick')

# Load on CPU first so the checkpoint opens even if it was saved from a device
# index that is not visible to this process; each tensor is moved to `dev`
# right before it is copied into the block.
ck = torch.load('/tmp/lt_ep/ts_s1_ep_v4b.pt', map_location='cpu')
with torch.no_grad():
    for p, w in zip(blk.allp, ck['allp']):
        p.copy_(w.to(dev))

idx, y = get_batch('train', B, T)
xin = blk.embed(idx).detach()

# --- (1) compile speedup on relax ---
# Drain pending GPU work (e.g. the embed above) so it is not billed to the
# eager baseline; the compiled path below is synchronized the same way.
# NOTE(review): the eager run still has no warmup while the compiled run does;
# none was added because an extra relax() call could perturb state/RNG.
torch.cuda.synchronize()
t0 = time.perf_counter()
zs = relax(blk, xin.clone(), xin, 300, 0.1)
torch.cuda.synchronize()
base = time.perf_counter() - t0

cfun = torch.compile(lambda z: z + 0.1 * blk.force(z, xin).detach(),
                     mode='max-autotune-no-cudagraphs')
z = xin.clone()
for _ in range(10):
    z = cfun(z)  # warmup/compile
torch.cuda.synchronize()
t0 = time.perf_counter()
z = xin.clone()
for _ in range(300):
    z = cfun(z)
torch.cuda.synchronize()
comp = time.perf_counter() - t0
print(f"[compile] relax300: eager {base:.2f}s -> compiled {comp:.2f}s ({base/comp:.2f}x)", flush=True)

# --- (2) bf16 @ r=0.2 + tracking ---
# Force full fp32 before taking the reference: cuDNN TF32 is on by default and
# an imported module may have flipped the matmul switch, which would silently
# turn the "fp32 reference" into a TF32 one.
_set_tf32(matmul=False, cudnn=False)
aref, _ = holo_a_track(blk, zs, xin, y, 0.2, 120, 0.1)

_set_tf32(matmul=True, cudnn=True)
atf, _ = holo_a_track(blk, zs, xin, y, 0.2, 120, 0.1)
print(f"[tf32 + track + r=0.2] cos vs fp32 = {cos(atf, aref):.3f}", flush=True)
_set_tf32(matmul=False, cudnn=False)

with torch.autocast('cuda', dtype=torch.bfloat16):
    abf, _ = holo_a_track(blk, zs, xin, y, 0.2, 120, 0.1)
abf = abf.float()  # compare against the reference in fp32
print(f"[bf16 + track + r=0.2] cos vs fp32 = {cos(abf, aref):.3f}", flush=True)

# also the old failure case for reference
a_old_ref, _ = holo_a_select2(blk, zs, xin, y, 0.02, 120, 0.1)
# As in the original control, only the matmul switch is toggled here; cuDNN
# TF32 stays off.
_set_tf32(matmul=True, cudnn=False)
a_old_tf, _ = holo_a_select2(blk, zs, xin, y, 0.02, 120, 0.1)
_set_tf32(matmul=False, cudnn=False)
print(f"[tf32 + frozen + r=0.02 (known-dead control)] cos = {cos(a_old_tf, a_old_ref):.3f}", flush=True)