summaryrefslogtreecommitdiff
path: root/ep_run/test_aselect_deepdive.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
committerYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
commitb83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
treeb9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/test_aselect_deepdive.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/test_aselect_deepdive.py')
-rw-r--r--ep_run/test_aselect_deepdive.py323
1 files changed, 323 insertions, 0 deletions
diff --git a/ep_run/test_aselect_deepdive.py b/ep_run/test_aselect_deepdive.py
new file mode 100644
index 0000000..54d2776
--- /dev/null
+++ b/ep_run/test_aselect_deepdive.py
@@ -0,0 +1,323 @@
+#!/usr/bin/env python3
+"""Standalone EP a-select performance/correctness probes.
+
+Does not modify trainer files. It can run on CPU if CUDA is unavailable; CUDA timing is only
+attempted when torch.cuda.is_available().
+"""
+import argparse, math, os, time, traceback
+import torch
+import torch.nn.functional as F
+import torch.func as tf
+import lt_ep_train as LT
+import holo_ep as H
+
+
def cosine(a, b):
    """Return the cosine similarity of two tensors as a Python float.

    Both inputs are detached, flattened to 1-D and cast to float32 before the
    dot product. A tiny constant in the denominator keeps the ratio finite
    when either input has zero norm.
    """
    u = a.detach().reshape(-1).float()
    w = b.detach().reshape(-1).float()
    denom = u.norm() * w.norm() + 1e-30
    return float((u @ w) / denom)
+
+
def max_rel(a, b):
    """Return max|a - b| divided by max|b| (plus 1e-12) as a Python float.

    ``b`` is the reference whose largest magnitude sets the scale.
    """
    worst_diff = (a - b).abs().max()
    ref_scale = b.abs().max() + 1e-12
    return float(worst_diff / ref_scale)
+
+
def ln_jvp(x, dx, gamma, beta=None, eps=1e-5):
    """Forward-mode derivative of affine layer norm over the last dimension.

    Uses the biased variance, like ``F.layer_norm``. ``beta`` is accepted but
    never read: the additive bias does not affect the derivative.

    Args:
        x: point at which the layer norm is linearised.
        dx: tangent vector, same shape as ``x``.
        gamma: affine scale applied after normalisation.
        beta: unused.
        eps: variance stabiliser.

    Returns:
        The tangent of ``layer_norm(x) * gamma + beta`` in direction ``dx``.
    """
    centered = x - x.mean(dim=-1, keepdim=True)
    rstd = torch.rsqrt((centered * centered).mean(dim=-1, keepdim=True) + eps)
    xhat = centered * rstd
    dx_mean = dx.mean(dim=-1, keepdim=True)
    # mean(xhat) == 0, so projecting dx or (dx - dx_mean) onto xhat is the same.
    along_xhat = (xhat * dx).mean(dim=-1, keepdim=True)
    d_normed = rstd * (dx - dx_mean - xhat * along_xhat)
    return d_normed * gamma
+
+
def ln_vjp(x, gy, gamma, eps=1e-5):
    """Reverse-mode derivative of affine layer norm with respect to its input.

    Args:
        x: point at which the layer norm is linearised.
        gy: cotangent on the layer-norm output, same shape as ``x``.
        gamma: affine scale applied after normalisation.
        eps: variance stabiliser.

    Returns:
        The cotangent pulled back to ``x`` (last-dimension normalisation,
        biased variance).
    """
    centered = x - x.mean(dim=-1, keepdim=True)
    rstd = torch.rsqrt((centered * centered).mean(dim=-1, keepdim=True) + eps)
    xhat = centered * rstd
    # Undo the affine scale first, then apply the normalisation Jacobian.
    g = gy * gamma
    g_mean = g.mean(dim=-1, keepdim=True)
    g_along_xhat = (g * xhat).mean(dim=-1, keepdim=True)
    return rstd * (g - g_mean - xhat * g_along_xhat)
+
+
def rms_jvp(x, dx, eps=1e-6):
    """Forward-mode derivative of ``x * rsqrt(mean(x**2, -1) + eps)``.

    Args:
        x: point at which the (gain-free) RMS normalisation is linearised.
        dx: tangent vector, same shape as ``x``.
        eps: stabiliser added to the mean square.
    """
    rrms = torch.rsqrt((x * x).mean(dim=-1, keepdim=True) + eps)
    x_dot_dx = (x * dx).mean(dim=-1, keepdim=True)
    return rrms * dx - x * (rrms ** 3) * x_dot_dx
+
+
def rms_vjp(x, gy, eps=1e-6):
    """Reverse-mode derivative of ``x * rsqrt(mean(x**2, -1) + eps)``.

    The Jacobian of this map is symmetric, so the pullback has the same
    closed form as the pushforward with ``gy`` in place of the tangent.

    Args:
        x: point at which the (gain-free) RMS normalisation is linearised.
        gy: cotangent on the normalised output, same shape as ``x``.
        eps: stabiliser added to the mean square.
    """
    rrms = torch.rsqrt((x * x).mean(dim=-1, keepdim=True) + eps)
    x_dot_gy = (x * gy).mean(dim=-1, keepdim=True)
    return rrms * gy - x * (rrms ** 3) * x_dot_gy
+
+
def gelu_tanh_deriv(x):
    """Elementwise derivative of ``F.gelu(x, approximate='tanh')``."""
    sqrt_2_over_pi = 0.7978845608028654
    cubic_coeff = 0.044715
    inner = sqrt_2_over_pi * (x + cubic_coeff * x * x * x)
    th = torch.tanh(inner)
    sech_sq = 1.0 - th * th
    # Polynomial factor of d(inner)/dx; the sqrt(2/pi) factor is applied below.
    poly = 1.0 + 3.0 * cubic_coeff * x * x
    return 0.5 * (1.0 + th) + 0.5 * x * sech_sq * sqrt_2_over_pi * poly
+
+
def manual_nc_jvp_vjp_thick(blk, z, vec):
    """Hand-derived J @ vec and J^T @ vec for blk.nc_force(z), thick + qknorm path.

    Written as plain tensor ops: no torch.func and no autograd. Requires
    blk.attn_mode == 'thick' and blk.fnoise == 0 (both asserted).

    The differentiated map is the sum of an attention branch
    (layer norm 1 -> Q/K/V projections -> optional RMS norm on Q and K ->
    masked softmax attention -> WO) and an FFN branch
    (layer norm 2 -> fc + fcb -> tanh-GELU -> pj).
    NOTE(review): this is presumed to mirror blk.nc_force, which is defined
    elsewhere; confirm against that implementation.

    Args:
        blk: block exposing H, dh, ln1g/ln1b, ln2g/ln2b, WQ/WK/WV/WO, fc/fcb,
            pj, cmask, attn_mode and optionally qknorm / fnoise.
        z: state tensor unpacked as (B, T, C).
        vec: vector used both as tangent (for Jv) and cotangent (for JTv).

    Returns:
        Tuple (Jv, JTv), both computed for the same ``vec``.
    """
    assert blk.attn_mode == 'thick'
    assert getattr(blk, 'fnoise', 0.0) == 0.0
    B, T, C = z.shape
    n_heads, head_dim = blk.H, blk.dh
    scale = 1.0 / math.sqrt(head_dim)
    use_qknorm = getattr(blk, 'qknorm', False)

    def split_heads(t):
        # (B, T, n_heads * head_dim) -> (B, n_heads, T, head_dim)
        return t.view(B, T, n_heads, head_dim).transpose(1, 2)

    def merge_heads(t):
        # (B, n_heads, T, head_dim) -> (B, T, C)
        return t.transpose(1, 2).reshape(B, T, C)

    # ---- Forward intermediates shared by both products. ----
    h_attn = F.layer_norm(z, (C,), blk.ln1g, blk.ln1b)
    h_ffn = F.layer_norm(z, (C,), blk.ln2g, blk.ln2b)
    q_raw = split_heads(h_attn @ blk.WQ)
    k_raw = split_heads(h_attn @ blk.WK)
    val = split_heads(h_attn @ blk.WV)
    if use_qknorm:
        q = q_raw * torch.rsqrt(q_raw.pow(2).mean(-1, keepdim=True) + 1e-6)
        k = k_raw * torch.rsqrt(k_raw.pow(2).mean(-1, keepdim=True) + 1e-6)
    else:
        q, k = q_raw, k_raw
    logits = (q @ k.transpose(-2, -1)) * scale
    # Positions where blk.cmask is False get -inf and hence zero probability.
    attn = torch.softmax(logits.masked_fill(~blk.cmask, float('-inf')), -1)

    pre_act = h_ffn @ blk.fc + blk.fcb
    act_slope = gelu_tanh_deriv(pre_act)

    # ---- Jv, attention branch. ----
    d_h_attn = ln_jvp(z, vec, blk.ln1g)
    dq_raw = split_heads(d_h_attn @ blk.WQ)
    dk_raw = split_heads(d_h_attn @ blk.WK)
    d_val = split_heads(d_h_attn @ blk.WV)
    if use_qknorm:
        dq = rms_jvp(q_raw, dq_raw)
        dk = rms_jvp(k_raw, dk_raw)
    else:
        dq, dk = dq_raw, dk_raw
    d_logits = (dq @ k.transpose(-2, -1) + q @ dk.transpose(-2, -1)) * scale
    # Softmax pushforward: p * (dl - <p, dl>).
    d_attn = attn * (d_logits - (attn * d_logits).sum(-1, keepdim=True))
    d_heads = d_attn @ val + attn @ d_val
    jv_attn = merge_heads(d_heads) @ blk.WO

    # ---- Jv, FFN branch. ----
    d_h_ffn = ln_jvp(z, vec, blk.ln2g)
    d_pre_act = d_h_ffn @ blk.fc
    jv_ffn = (d_pre_act * act_slope) @ blk.pj
    jv = jv_attn + jv_ffn

    # ---- J^T v, attention branch. ----
    cot = vec
    g_heads = split_heads(cot @ blk.WO.t())
    g_attn = g_heads @ val.transpose(-2, -1)
    g_val = attn.transpose(-2, -1) @ g_heads
    # Softmax pullback has the same form as the pushforward.
    g_logits = attn * (g_attn - (g_attn * attn).sum(-1, keepdim=True))
    g_q = (g_logits @ k) * scale
    g_k = (g_logits.transpose(-2, -1) @ q) * scale
    if use_qknorm:
        g_q_raw = rms_vjp(q_raw, g_q)
        g_k_raw = rms_vjp(k_raw, g_k)
    else:
        g_q_raw, g_k_raw = g_q, g_k
    g_h_attn = (merge_heads(g_q_raw) @ blk.WQ.t()
                + merge_heads(g_k_raw) @ blk.WK.t()
                + merge_heads(g_val) @ blk.WV.t())
    jtv_attn = ln_vjp(z, g_h_attn, blk.ln1g)

    # ---- J^T v, FFN branch. ----
    g_act = cot @ blk.pj.t()
    g_pre_act = g_act * act_slope
    g_h_ffn = g_pre_act @ blk.fc.t()
    jtv_ffn = ln_vjp(z, g_h_ffn, blk.ln2g)
    jtv = jtv_attn + jtv_ffn
    return jv, jtv
+
+
def make_block(args, device):
    """Build the EQBlock under test together with an (idx, y) batch.

    With ``args.tiny`` a small randomly initialised block and random token
    batches are created (for compiler/frontend feasibility tests). Otherwise
    a full-size block is created, its parameters are overwritten from the
    checkpoint at ``args.ckpt``, and a training batch is fetched through
    ``LT.get_batch``.

    Returns:
        Tuple (blk, idx, y).
    """
    def _set_probe_flags(block):
        # Same flag set for both the tiny and the full-size block.
        block.qknorm = True
        block.track = True
        block.navg = 1
        block.li_avg = 0

    if not args.tiny:
        blk = LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
        _set_probe_flags(blk)
        ck = torch.load(args.ckpt, map_location=device)
        with torch.no_grad():
            # Copy checkpoint tensors into the block's parameters in order.
            for param, saved in zip(blk.allp, ck['allp']):
                param.copy_(saved.to(device))
        batch = args.B or 1
        idx, y = LT.get_batch('train', batch, 256)
        return blk, idx, y

    # Tiny block for compiler/frontend feasibility tests.
    blk = LT.EQBlock(32, 4, 64, 32, c=1.0, attn_mode='thick')
    batch = args.B or 2
    seq_len = 32
    n_vocab = LT.vocab
    idx = torch.randint(0, n_vocab, (batch, seq_len), device=device)
    y = torch.randint(0, n_vocab, (batch, seq_len), device=device)
    _set_probe_flags(blk)
    return blk, idx, y
+
+
def sync(device):
    """Wait for queued CUDA work on *device*; do nothing for non-CUDA devices."""
    if device.type != 'cuda':
        return
    torch.cuda.synchronize(device)
+
+
def time_call(fn, device, repeat=3):
    """Time ``fn()`` and return the fastest run along with its last result.

    One untimed warm-up call is made first. Every call is followed by
    ``sync(device)`` so that asynchronous device work is counted in the
    measurement.

    Args:
        fn: zero-argument callable to benchmark.
        device: device object passed to ``sync``.
        repeat: number of timed runs; must be at least 1.

    Returns:
        Tuple (best_seconds, out) where ``out`` is the value returned by the
        final timed call.

    Raises:
        ValueError: if ``repeat`` is smaller than 1.
    """
    if repeat < 1:
        # Fail before running the warm-up, rather than from min() on an empty list.
        raise ValueError(f'repeat must be >= 1, got {repeat}')
    # One warm-up call, excluded from the timings.
    out = fn()
    sync(device)
    best = float('inf')
    for _ in range(repeat):
        # perf_counter is monotonic and high resolution; time.time() can jump
        # with wall-clock adjustments and is coarse on some platforms.
        start = time.perf_counter()
        out = fn()
        sync(device)
        best = min(best, time.perf_counter() - start)
    return best, out
+
+
def make_tf_step(blk, zs, xin, y, r, eps):
    """Build one update step that gets Jv and J^T v from torch.func.

    The returned ``step(Z)`` works on a stacked state whose first ``B`` rows
    are paired with ``+r`` and whose last ``B`` rows are paired with ``-r``
    (``B = zs.size(0)``). Each call:
      * linearises ``blk.nc_force`` at the mean of the two halves,
      * evaluates ``H.rforce`` minus the signed ``H.rgrad_ce`` term,
      * subtracts ``(J - J^T)`` applied to the offset from the mean,
      * takes a step of size ``eps``.

    Args:
        blk: block providing ``nc_force``; also passed to the ``H`` helpers.
        zs: tensor used for its batch size and device.
        xin: input tensor, duplicated along dim 0 for the two halves.
        y: targets, duplicated along dim 0; ``y.numel()`` is the CE denominator.
        r: magnitude of the signed coefficient on the CE-gradient term.
        eps: step size.

    Returns:
        The ``step`` closure.
    """
    B = zs.size(0)
    dev = zs.device
    x_pair = torch.cat([xin, xin], 0)
    y_pair = torch.cat([y, y], 0)
    plus_r = torch.full((B, 1, 1), r, device=dev)
    minus_r = torch.full((B, 1, 1), -r, device=dev)
    signed_r = torch.cat([plus_r, minus_r], 0)

    def nc_force(state):
        return blk.nc_force(state)

    def step(Z):
        mean_state = 0.5 * (Z[:B] + Z[B:])
        center = torch.cat([mean_state, mean_state], 0)
        force = H.rforce(blk, Z, x_pair) - signed_r * H.rgrad_ce(blk, Z, y_pair, denom=y.numel())
        offset = (Z - center).contiguous()
        _, jv = tf.jvp(nc_force, (center,), (offset,))
        _, pullback = tf.vjp(nc_force, center)
        jtv = pullback(offset)[0]
        return Z + eps * (force - (jv - jtv))

    return step
+
+
def make_manual_step(blk, zs, xin, y, r, eps):
    """Build one update step that gets Jv and J^T v from the manual formulas.

    Same update as the torch.func-based step, except that the two Jacobian
    products come from ``manual_nc_jvp_vjp_thick``.

    Args:
        blk: block passed to the ``H`` helpers and to the manual Jacobian code.
        zs: tensor used for its batch size and device.
        xin: input tensor, duplicated along dim 0 for the two halves.
        y: targets, duplicated along dim 0; ``y.numel()`` is the CE denominator.
        r: magnitude of the signed coefficient on the CE-gradient term.
        eps: step size.

    Returns:
        The ``step`` closure mapping a stacked state ``Z`` to its update.
    """
    B = zs.size(0)
    dev = zs.device
    x_pair = torch.cat([xin, xin], 0)
    y_pair = torch.cat([y, y], 0)
    plus_r = torch.full((B, 1, 1), r, device=dev)
    minus_r = torch.full((B, 1, 1), -r, device=dev)
    signed_r = torch.cat([plus_r, minus_r], 0)

    def step(Z):
        mean_state = 0.5 * (Z[:B] + Z[B:])
        center = torch.cat([mean_state, mean_state], 0)
        force = H.rforce(blk, Z, x_pair) - signed_r * H.rgrad_ce(blk, Z, y_pair, denom=y.numel())
        offset = (Z - center).contiguous()
        jv, jtv = manual_nc_jvp_vjp_thick(blk, center, offset)
        return Z + eps * (force - (jv - jtv))

    return step
+
+
def run_loop_from_step(step, zs, r, T2, K=10):
    """Iterate ``step`` for ``T2`` steps and pick the most settled estimate.

    The state starts as two stacked copies of ``zs``. Every ``K`` steps, and
    on the final step, the estimate ``a_t = (Z[B:] - Z[:B]) / (2 * r)`` is
    formed. The estimate whose norm-difference from the previous checkpoint
    is smallest is kept. If fewer than two checkpoints occur, or no difference
    compares below infinity, the last checkpoint is returned with
    ``t_best = T2``.

    Args:
        step: callable mapping the stacked state to the next stacked state.
        zs: initial state; ``zs.size(0)`` is the half-batch size ``B``.
        r: magnitude used to scale the half-difference.
        T2: number of steps to run; must be at least 1.
        K: checkpoint interval; must be at least 1.

    Returns:
        Tuple (a_best, t_best): the detached selected estimate and the step
        index at which it was taken.

    Raises:
        ValueError: if ``T2 < 1`` or ``K < 1``.
    """
    # Without this guard T2 < 1 ends in None.detach() and K == 0 in a modulo
    # by zero; report the real problem instead.
    if T2 < 1:
        raise ValueError(f'T2 must be >= 1, got {T2}')
    if K < 1:
        raise ValueError(f'K must be >= 1, got {K}')

    B = zs.size(0)
    Z = torch.cat([zs, zs], 0)
    a_prev = a_best = None
    inc_min = float('inf')
    t_best = T2
    for t in range(1, T2 + 1):
        Z = step(Z)
        if t % K != 0 and t != T2:
            continue
        a_t = (Z[B:] - Z[:B]) / (2 * r)
        if a_prev is not None:
            inc = (a_t - a_prev).norm().item()
            # Strict '<' keeps the earliest checkpoint on ties.
            if inc < inc_min:
                inc_min, a_best, t_best = inc, a_t, t
        a_prev = a_t
    if a_best is None:
        # Only one checkpoint, or no finite increment: fall back to the last one.
        a_best, t_best = a_prev, T2
    return a_best.detach(), t_best
+
+
def main():
    """Run the a-select probes selected on the command line.

    Steps, in order:
      1. Parse flags, pick CUDA when available, seed torch, build the block
         and a batch, and relax an initial state with ``LT.relax``.
      2. Compare the manual Jv / J^T v with the torch.func reference on a
         small random vector.
      3. Compare one torch.func step with one manual step.
      4. Compare ``H.holo_a_track`` with the manual step loop and time both.
      5. With ``--compile``: try ``torch.compile`` on both step bodies and,
         on CUDA, time eager against compiled.
      6. With ``--cuda-graph``: try CUDA graph capture of the torch.func step.

    Failures in steps 5 and 6 are reported and do not abort the script.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--ckpt', default='/home/yurenh2/ept/ep_run/runs/ep_resreg_warm.pt')
    ap.add_argument('--tiny', action='store_true')
    ap.add_argument('--B', type=int, default=None)
    ap.add_argument('--T1', type=int, default=2)
    ap.add_argument('--T2', type=int, default=2)
    ap.add_argument('--r', type=float, default=0.02)
    ap.add_argument('--eps', type=float, default=0.1)
    ap.add_argument('--compile', action='store_true')
    ap.add_argument('--cuda-graph', action='store_true')
    args = ap.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('torch', torch.__version__, 'cuda_runtime', torch.version.cuda, 'cuda_available', torch.cuda.is_available(), 'device', device, flush=True)
    if device.type == 'cuda':
        print(torch.cuda.get_device_name(device), flush=True)

    torch.manual_seed(0)
    blk, idx, y = make_block(args, device)
    print('block', blk.C, blk.H, blk.T, 'B', idx.size(0), 'qknorm', getattr(blk, 'qknorm', False), flush=True)
    xin = blk.embed(idx).detach()
    zs = LT.relax(blk, xin.clone(), xin, args.T1, args.eps)
    print('zs norm', float(zs.norm()), flush=True)

    B = zs.size(0)
    Z0 = torch.cat([zs, zs], 0)
    zbar = 0.5 * (Z0[:B] + Z0[B:])
    zb2 = torch.cat([zbar, zbar], 0)
    # Z0 - zb2 is exactly zero here, so the one-off J test needs its own nonzero v.
    vtest = torch.randn_like(zb2) * 1e-3
    _, Jv_ref = tf.jvp(lambda zz: blk.nc_force(zz), (zb2,), (vtest,))
    JTv_ref = tf.vjp(lambda zz: blk.nc_force(zz), zb2)[1](vtest)[0]
    Jv_m, JTv_m = manual_nc_jvp_vjp_thick(blk, zb2, vtest)
    print('manual_Jv cos', cosine(Jv_ref, Jv_m), 'maxrel', max_rel(Jv_m, Jv_ref), flush=True)
    print('manual_JTv cos', cosine(JTv_ref, JTv_m), 'maxrel', max_rel(JTv_m, JTv_ref), flush=True)

    tf_step = make_tf_step(blk, zs, xin, y, args.r, args.eps)
    man_step = make_manual_step(blk, zs, xin, y, args.r, args.eps)
    with torch.no_grad():
        Z_tf = tf_step(Z0)
        Z_man = man_step(Z0)
    print('one_step manual vs tf cos', cosine(Z_tf, Z_man), 'max_abs', float((Z_tf - Z_man).abs().max()), 'maxrel', max_rel(Z_man, Z_tf), flush=True)

    # Compare a-select outputs for small T2. Full checkpoint on CPU is intentionally small T2.
    k_check = max(1, args.T2)
    with torch.no_grad():
        # perf_counter: monotonic, high-resolution clock meant for interval timing.
        t0 = time.perf_counter()
        a_base, tb = H.holo_a_track(blk, zs, xin, y, args.r, args.T2, args.eps, K=k_check)
        sync(device)
        dt = time.perf_counter() - t0

        t0 = time.perf_counter()
        a_man, tm = run_loop_from_step(man_step, zs, args.r, args.T2, K=k_check)
        sync(device)
        dtm = time.perf_counter() - t0
    print('baseline_holo_a_track T2', args.T2, 't_best', tb, 'sec', dt, flush=True)
    print('manual_loop T2', args.T2, 't_best', tm, 'sec', dtm, 'cos(a)', cosine(a_base, a_man), 'maxrel', max_rel(a_man, a_base), flush=True)

    if args.compile:
        for name, step in [('tf_step_body', tf_step), ('manual_step_body', man_step)]:
            Z_ref = Z_tf if name == 'tf_step_body' else Z_man
            try:
                print('compile start', name, flush=True)
                cstep = torch.compile(step, fullgraph=True, mode='reduce-overhead')
                # Keep everything under no_grad: the first call compiles for
                # this grad mode, and timing in another mode would measure a
                # different (recompiled, autograd-tracking) graph.
                with torch.no_grad():
                    Zc = cstep(Z0)
                    sync(device)
                    print('compile ok', name, 'cos one_step', cosine(Z_ref, Zc), 'max_abs', float((Z_ref - Zc).abs().max()), flush=True)
                    if device.type == 'cuda':
                        t_e, _ = time_call(lambda: step(Z0), device)
                        t_c, _ = time_call(lambda: cstep(Z0), device)
                        print('timing', name, 'eager_ms', t_e * 1000, 'compiled_ms', t_c * 1000, 'speedup', t_e / t_c, flush=True)
            except Exception as e:
                # Feasibility probe: report the failure and move on.
                print('compile FAIL', name, type(e).__name__, str(e)[:1000], flush=True)
                traceback.print_exc(limit=8)

    if args.cuda_graph:
        if device.type != 'cuda':
            print('cuda graph skipped: torch.cuda.is_available() is False', flush=True)
        else:
            try:
                # Capture in the same no_grad mode as the one-step check above,
                # so no autograd state is recorded into the graph.
                with torch.no_grad():
                    static_Z = Z0.clone()
                    # warmup on a side stream to settle allocations
                    s = torch.cuda.Stream()
                    s.wait_stream(torch.cuda.current_stream())
                    with torch.cuda.stream(s):
                        for _ in range(3):
                            static_out = tf_step(static_Z)
                    torch.cuda.current_stream().wait_stream(s)
                    g = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(g):
                        static_out = tf_step(static_Z)
                    g.replay()
                    sync(device)
                    print('cuda graph capture ok tf_step_body', cosine(tf_step(Z0), static_out), flush=True)
            except Exception as e:
                # Feasibility probe: report the failure and move on.
                print('cuda graph FAIL tf_step_body', type(e).__name__, str(e)[:1000], flush=True)
                traceback.print_exc(limit=8)


if __name__ == '__main__':
    main()
+
+# NOTE: extra decomposed-layernorm helper kept below for reference; not used by main() above.