summaryrefslogtreecommitdiff
path: root/ep_run/fast_probe.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/fast_probe.py')
-rw-r--r--ep_run/fast_probe.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/ep_run/fast_probe.py b/ep_run/fast_probe.py
new file mode 100644
index 0000000..6fa72a2
--- /dev/null
+++ b/ep_run/fast_probe.py
@@ -0,0 +1,41 @@
+"""Validate the speed knobs: corr_every (stale AEP corr) x tf32, vs fp32 exact reference.
+Watch BOTH gradient cosine AND the achievable free-phase residual (tf32 may raise the res floor
+above res_est=1e-4 -> validity issue)."""
+import time, torch
+from lt_ep_train import EQBlock, get_batch, bptt_step, relax
+from holo_ep import holo_a_select
+# Seed the global RNG before the model is constructed.
+torch.manual_seed(0)
+# B and T are forwarded to get_batch below; T is also handed to EQBlock.
+B, T = 16, 64
+blk = EQBlock(128, 4, 256, T, attn_mode='thick')
+# Overwrite the model parameters, in order, with the saved probe weights.
+# NOTE(review): zip stops at the shorter sequence, so a length mismatch
+# between blk.allp and the saved list would go unnoticed -- confirm they agree.
+probe_weights = torch.load('/tmp/lt_ep/probe_w.pt')
+with torch.no_grad():
+    for param, saved in zip(blk.allp, probe_weights):
+        param.copy_(saved.to('cuda'))
+
+def cos(ga, gb, ps):
+    """Return the cosine similarity between two gradient sets as a float.
+
+    ga, gb: dicts mapping id(param) -> gradient tensor (or None).
+    ps: iterable of parameters whose ids are looked up in both dicts.
+
+    Only parameters with a non-None gradient in BOTH dicts contribute; their
+    gradients are flattened and concatenated into one vector per side.
+    Returns NaN when no parameter qualifies (the cosine is undefined there)
+    instead of letting torch.cat fail on an empty list.
+    """
+    pairs = []
+    for p in ps:
+        a, b = ga.get(id(p)), gb.get(id(p))
+        if a is not None and b is not None:
+            pairs.append((a, b))
+    if not pairs:
+        return float('nan')
+    va = torch.cat([a.reshape(-1) for a, _ in pairs])
+    vb = torch.cat([b.reshape(-1) for _, b in pairs])
+    # The 1e-12 keeps the division finite when either vector has zero norm.
+    return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item()
+
+def gfrom(idx, zs, a):
+    """Return {id(param): grad} of sum(a * force) w.r.t. the entries of blk.block.
+
+    The force comes from blk.force, called on the detached state zs and the
+    embedding of idx with cg=True. Because allow_unused=True is passed to
+    autograd, a parameter the force does not depend on maps to None.
+    """
+    with torch.enable_grad():
+        embedded = blk.embed(idx)
+        force = blk.force(zs.detach(), embedded, cg=True)
+        objective = (a * force).sum()
+        grads = torch.autograd.grad(objective, blk.block, allow_unused=True)
+    return dict(zip(map(id, blk.block), grads))
+
+# One row per (tf32, corr_every) setting. Columns: tf32 flag, corr_every,
+# relative change of the relaxed state under one more relax step, the second
+# value returned by holo_a_select, gradient cosine vs the reference, seconds.
+print(f"{'tf32':>5} {'corr_ev':>8} {'res@400':>9} {'t_best':>7} {'cos':>6} {'sec':>6}")
+for _ in range(2):
+    idx, y = get_batch('train', B, T)
+    # The reference has to be plain fp32, so switch off BOTH TF32 knobs: the
+    # cuDNN flag defaults to True and is also left True by the previous
+    # batch's sweep, which would otherwise leak into this reference.
+    torch.backends.cuda.matmul.allow_tf32 = False
+    torch.backends.cudnn.allow_tf32 = False
+    ref = bptt_step(blk, idx, y, 400, 0.1)
+    for tf32 in (False, True):
+        torch.backends.cuda.matmul.allow_tf32 = tf32
+        torch.backends.cudnn.allow_tf32 = tf32
+        xin = blk.embed(idx).detach()
+        zs = relax(blk, xin.clone(), xin, 400, 0.1)
+        # Residual: how far one extra relax step moves zs, relative to |zs|.
+        res = (relax(blk, zs, xin, 1, 0.1) - zs).norm().item() / zs.norm().item()
+        for ck in (1, 2, 3):
+            # CUDA kernels launch asynchronously; synchronize on both sides so
+            # the interval covers exactly the work queued by holo_a_select.
+            torch.cuda.synchronize()
+            t0 = time.perf_counter()
+            a, tb = holo_a_select(blk, zs, xin, y, 2, 0.02, 120, 0.1, corr_every=ck)
+            torch.cuda.synchronize()
+            dt = time.perf_counter() - t0
+            similarity = cos(gfrom(idx, zs, a), ref, blk.block)
+            print(f"{str(tf32):>5} {ck:>8} {res:>9.1e} {tb:>7} {similarity:>6.3f} {dt:>6.1f}", flush=True)