diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/fast_probe.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/fast_probe.py')
| -rw-r--r-- | ep_run/fast_probe.py | 41 |
1 files changed, 41 insertions, 0 deletions
# ep_run/fast_probe.py -- added as a new file by this commit.
"""Validate the speed knobs: corr_every (stale AEP corr) x tf32, vs fp32 exact reference.

Watch BOTH gradient cosine AND the achievable free-phase residual (tf32 may raise the res floor
above res_est=1e-4 -> validity issue).

For each of two training batches the script:
  1. computes a reference gradient with ``bptt_step`` while TF32 is disabled,
  2. relaxes the block with TF32 off and then on, recording the relative
     change produced by one extra ``relax`` call (the ``res@400`` column),
  3. sweeps ``corr_every`` over 1, 2, 3 in ``holo_a_select`` and prints the
     cosine between the resulting gradient and the reference, plus wall time.

Requires a CUDA device and the checkpoint ``/tmp/lt_ep/probe_w.pt``.
"""
import time

import torch

from lt_ep_train import EQBlock, get_batch, bptt_step, relax
from holo_ep import holo_a_select

torch.manual_seed(0)
B, T = 16, 64
blk = EQBlock(128, 4, 256, T, attn_mode='thick')

# NOTE(review): torch.load unpickles the file, so only point this at a trusted
# checkpoint. The file is consumed as an iterable of tensors zipped against
# blk.allp; if that is all it contains, weights_only=True would be a safer
# call -- confirm against the code that writes probe_w.pt.
saved_weights = torch.load('/tmp/lt_ep/probe_w.pt')
with torch.no_grad():
    for p, w in zip(blk.allp, saved_weights):
        p.copy_(w.to('cuda'))


def _set_tf32(enabled):
    """Enable or disable TF32 for both CUDA matmuls and cuDNN kernels."""
    torch.backends.cuda.matmul.allow_tf32 = enabled
    torch.backends.cudnn.allow_tf32 = enabled


def _synced_clock():
    """Return a monotonic timestamp taken after all queued CUDA work finished.

    CUDA kernels are launched asynchronously, so reading the clock without a
    synchronize can attribute pending work to the wrong interval.
    """
    torch.cuda.synchronize()
    return time.perf_counter()


def cos(ga, gb, ps):
    """Return the cosine similarity between two gradient dictionaries.

    Args:
        ga, gb: mappings from ``id(param)`` to a gradient tensor or ``None``.
        ps: parameters to compare; only those with a non-``None`` gradient in
            both mappings contribute.

    Returns:
        The cosine of the concatenated, flattened gradients as a Python float,
        or NaN when no parameter has a gradient on both sides.
    """
    keep = [p for p in ps if ga.get(id(p)) is not None and gb.get(id(p)) is not None]
    if not keep:
        # torch.cat([]) would raise; report "undefined" instead of crashing.
        return float('nan')
    va = torch.cat([ga[id(p)].reshape(-1) for p in keep])
    vb = torch.cat([gb[id(p)].reshape(-1) for p in keep])
    # The epsilon keeps the ratio finite when either vector is all zeros.
    return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item()


def gfrom(idx, zs, a):
    """Return gradients of ``(a * force).sum()`` w.r.t. ``blk.block``.

    Args:
        idx: batch passed to ``blk.embed``.
        zs: state at which ``blk.force`` is evaluated; it is detached, so no
            gradient flows back through it.
        a: weights multiplied elementwise with the force before summing.

    Returns:
        Dict mapping ``id(param)`` to its gradient, or to ``None`` for
        parameters unused in the graph (``allow_unused=True``).
    """
    with torch.enable_grad():
        xin = blk.embed(idx)
        f = blk.force(zs.detach(), xin, cg=True)
        g = torch.autograd.grad((a * f).sum(), blk.block, allow_unused=True)
    return {id(p): gv for p, gv in zip(blk.block, g)}


print(f"{'tf32':>5} {'corr_ev':>8} {'res@400':>9} {'t_best':>7} {'cos':>6} {'sec':>6}")
for _ in range(2):
    idx, y = get_batch('train', B, T)

    # The reference must be TF32-free. Clear the cuDNN flag as well as the
    # matmul flag: otherwise it keeps its default on the first batch and the
    # value left behind by the previous tf32=True sweep on the second.
    _set_tf32(False)
    ref = bptt_step(blk, idx, y, 400, 0.1)

    for tf32 in (False, True):
        _set_tf32(tf32)
        xin = blk.embed(idx).detach()
        zs = relax(blk, xin.clone(), xin, 400, 0.1)
        # Relative change from one further relax call applied to zs.
        res = (relax(blk, zs, xin, 1, 0.1) - zs).norm().item() / zs.norm().item()
        for ck in (1, 2, 3):
            t0 = _synced_clock()
            a, tb = holo_a_select(blk, zs, xin, y, 2, 0.02, 120, 0.1, corr_every=ck)
            dt = _synced_clock() - t0
            # Computed after the timer stops so it does not inflate `sec`.
            similarity = cos(gfrom(idx, zs, a), ref, blk.block)
            print(f"{str(tf32):>5} {ck:>8} {res:>9.1e} {tb:>7} {similarity:>6.3f} {dt:>6.1f}",
                  flush=True)
