#!/usr/bin/env python3
"""Standalone EP a-select performance/correctness probes.

Does not modify trainer files. It can run on CPU if CUDA is unavailable;
CUDA timing is only attempted when torch.cuda.is_available().

The probes compare a hand-written Jacobian-vector / vector-Jacobian product
for ``blk.nc_force`` against ``torch.func``, compare one relaxation step and a
short a-select loop built on each, and optionally try ``torch.compile``
(``--compile``) and CUDA-graph capture (``--cuda-graph``) of the step body.
"""
import argparse
import math
import os
import time
import traceback

import torch
import torch.nn.functional as F
import torch.func as tf

import lt_ep_train as LT
import holo_ep as H


def cosine(a, b):
    """Cosine similarity of two tensors, flattened and compared in fp32.

    Returns 0.0 (not NaN) when either tensor is all zeros.
    """
    a = a.detach().reshape(-1).float()
    b = b.detach().reshape(-1).float()
    return float((a @ b) / (a.norm() * b.norm() + 1e-30))


def max_rel(a, b):
    """max|a - b| / max|b|, with ``b`` taken as the reference."""
    return float((a - b).abs().max() / (b.abs().max() + 1e-12))


def ln_jvp(x, dx, gamma, beta=None, eps=1e-5):
    """JVP of ``F.layer_norm`` over the last dim (biased variance, affine gamma).

    ``beta`` is accepted only for signature symmetry: an additive bias is
    constant in ``x`` and drops out of the derivative.
    """
    mu = x.mean(dim=-1, keepdim=True)
    xc = x - mu
    inv = torch.rsqrt((xc * xc).mean(dim=-1, keepdim=True) + eps)
    xhat = xc * inv
    dmu = dx.mean(dim=-1, keepdim=True)
    # mean(xhat * dx), not mean(xhat * (dx - dmu)); the two agree because
    # mean(xhat) == 0.
    proj = (xhat * dx).mean(dim=-1, keepdim=True)
    dy = inv * (dx - dmu - xhat * proj)
    return dy * gamma


def ln_vjp(x, gy, gamma, eps=1e-5):
    """VJP of ``F.layer_norm`` over the last dim w.r.t. its input ``x``."""
    mu = x.mean(dim=-1, keepdim=True)
    xc = x - mu
    inv = torch.rsqrt((xc * xc).mean(dim=-1, keepdim=True) + eps)
    xhat = xc * inv
    g = gy * gamma
    return inv * (g
                  - g.mean(dim=-1, keepdim=True)
                  - xhat * (g * xhat).mean(dim=-1, keepdim=True))


def rms_jvp(x, dx, eps=1e-6):
    """JVP of ``x * rsqrt(mean(x**2, -1) + eps)`` (un-weighted RMS norm)."""
    inv = torch.rsqrt((x * x).mean(dim=-1, keepdim=True) + eps)
    return inv * dx - x * (inv ** 3) * (x * dx).mean(dim=-1, keepdim=True)


def rms_vjp(x, gy, eps=1e-6):
    """VJP of the same RMS norm as ``rms_jvp``."""
    # RMSNorm Jacobian is symmetric, so the VJP has the same form as the JVP.
    inv = torch.rsqrt((x * x).mean(dim=-1, keepdim=True) + eps)
    return inv * gy - x * (inv ** 3) * (x * gy).mean(dim=-1, keepdim=True)


def gelu_tanh_deriv(x):
    """Elementwise derivative of ``F.gelu(x, approximate='tanh')``."""
    k = 0.7978845608028654  # sqrt(2 / pi)
    a = 0.044715
    u = k * (x + a * x * x * x)
    t = torch.tanh(u)
    return 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * k * (1.0 + 3.0 * a * x * x)


def manual_nc_jvp_vjp_thick(blk, z, vec):
    """Return ``(J @ vec, J^T @ vec)`` for ``blk.nc_force`` at ``z``.

    Deliberately written as plain ATen ops: no torch.func, no autograd.
    Values are computed in z's dtype. Only the 'thick' attention path is
    handled, with RMS qk-norm applied when ``blk.qknorm`` is truthy.
    ``blk.fnoise`` must be 0.

    NOTE(review): this assumes ``blk.nc_force(z)`` is exactly the attention
    branch plus the FFN branch reconstructed below. ``main()`` checks that
    assumption against torch.func.jvp / torch.func.vjp.

    Args:
        blk: block whose weights (ln1g/ln1b, ln2g/ln2b, WQ/WK/WV/WO, fc/fcb,
            pj, cmask) are read below.
        z: linearization point of shape (B, T, C).
        vec: tensor shaped like ``z``. It is used both as the JVP tangent and
            as the VJP cotangent.

    Raises:
        ValueError: if ``blk.attn_mode != 'thick'`` or ``blk.fnoise != 0``.
    """
    if blk.attn_mode != 'thick':
        raise ValueError(
            f"manual_nc_jvp_vjp_thick requires attn_mode='thick', got {blk.attn_mode!r}")
    if getattr(blk, 'fnoise', 0.0) != 0.0:
        raise ValueError('manual_nc_jvp_vjp_thick requires blk.fnoise == 0')

    B, T, C = z.shape
    n_head, d_head = blk.H, blk.dh
    scale = 1.0 / math.sqrt(d_head)
    use_qknorm = getattr(blk, 'qknorm', False)

    def split_heads(t):
        # (B, T, C) -> (B, n_head, T, d_head)
        return t.view(B, T, n_head, d_head).transpose(1, 2)

    def merge_heads(t):
        # (B, n_head, T, d_head) -> (B, T, C)
        return t.transpose(1, 2).reshape(B, T, C)

    # ---- Base forward intermediates at z. ----
    h1 = F.layer_norm(z, (C,), blk.ln1g, blk.ln1b)
    h2 = F.layer_norm(z, (C,), blk.ln2g, blk.ln2b)
    q0 = split_heads(h1 @ blk.WQ)
    k0 = split_heads(h1 @ blk.WK)
    vv = split_heads(h1 @ blk.WV)
    if use_qknorm:
        q = q0 * torch.rsqrt(q0.pow(2).mean(-1, keepdim=True) + 1e-6)
        k = k0 * torch.rsqrt(k0.pow(2).mean(-1, keepdim=True) + 1e-6)
    else:
        q, k = q0, k0
    logits = (q @ k.transpose(-2, -1)) * scale
    # Positions where blk.cmask is False get probability exactly 0.
    p = torch.softmax(logits.masked_fill(~blk.cmask, float('-inf')), -1)
    u = h2 @ blk.fc + blk.fcb
    gp = gelu_tanh_deriv(u)

    # ---- JVP: attention branch. ----
    dh1 = ln_jvp(z, vec, blk.ln1g)
    dq0 = split_heads(dh1 @ blk.WQ)
    dk0 = split_heads(dh1 @ blk.WK)
    dvv = split_heads(dh1 @ blk.WV)
    if use_qknorm:
        dq = rms_jvp(q0, dq0)
        dk = rms_jvp(k0, dk0)
    else:
        dq, dk = dq0, dk0
    dlogits = (dq @ k.transpose(-2, -1) + q @ dk.transpose(-2, -1)) * scale
    # Softmax JVP: dp = p * (dl - sum(p * dl)). Masked entries stay 0
    # because p == 0 there.
    dp = p * (dlogits - (p * dlogits).sum(-1, keepdim=True))
    Jv_att = merge_heads(dp @ vv + p @ dvv) @ blk.WO

    # ---- JVP: FFN branch. ----
    du = ln_jvp(z, vec, blk.ln2g) @ blk.fc
    Jv_ff = (du * gp) @ blk.pj
    Jv = Jv_att + Jv_ff

    # ---- VJP: attention branch (cotangent = vec). ----
    gout = vec
    gh_att_heads = split_heads(gout @ blk.WO.t())
    gp_soft = gh_att_heads @ vv.transpose(-2, -1)
    gv_heads = p.transpose(-2, -1) @ gh_att_heads
    glogits = p * (gp_soft - (gp_soft * p).sum(-1, keepdim=True))
    gq = (glogits @ k) * scale
    gk = (glogits.transpose(-2, -1) @ q) * scale
    if use_qknorm:
        gq0 = rms_vjp(q0, gq)
        gk0 = rms_vjp(k0, gk)
    else:
        gq0, gk0 = gq, gk
    gh1 = (merge_heads(gq0) @ blk.WQ.t()
           + merge_heads(gk0) @ blk.WK.t()
           + merge_heads(gv_heads) @ blk.WV.t())
    JTv_att = ln_vjp(z, gh1, blk.ln1g)

    # ---- VJP: FFN branch. ----
    gu = (gout @ blk.pj.t()) * gp
    JTv_ff = ln_vjp(z, gu @ blk.fc.t(), blk.ln2g)
    JTv = JTv_att + JTv_ff
    return Jv, JTv


def _configure_probe_block(blk):
    # Probe-wide block settings shared by the tiny and checkpoint blocks.
    blk.qknorm = True
    blk.track = True
    blk.navg = 1
    blk.li_avg = 0


def make_block(args, device):
    """Build the probe block and an (idx, y) batch.

    With ``--tiny``, a small randomly initialised ``LT.EQBlock(32, 4, 64, 32)``
    is built with random tokens of length 32 (batch ``args.B`` or 2).
    Otherwise an ``LT.EQBlock(512, 16, 256, 256)`` is loaded from
    ``args.ckpt['allp']``, and the batch comes from
    ``LT.get_batch('train', B, 256)`` (batch ``args.B`` or 1).

    Raises:
        ValueError: if the checkpoint's parameter list does not match the
            block's in length or in per-tensor shape. ``zip`` would otherwise
            truncate silently, and ``copy_`` would broadcast silently.
    """
    if args.tiny:
        # Tiny block for compiler/frontend feasibility tests.
        blk = LT.EQBlock(32, 4, 64, 32, c=1.0, attn_mode='thick')
        B = args.B or 2
        T = 32
        y_vocab = LT.vocab
        idx = torch.randint(0, y_vocab, (B, T), device=device)
        y = torch.randint(0, y_vocab, (B, T), device=device)
        _configure_probe_block(blk)
        return blk, idx, y

    blk = LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
    _configure_probe_block(blk)
    ck = torch.load(args.ckpt, map_location=device)
    params = list(blk.allp)
    saved = list(ck['allp'])
    if len(params) != len(saved):
        raise ValueError(
            f"checkpoint {args.ckpt!r} has {len(saved)} tensors in 'allp', "
            f"block expects {len(params)}")
    with torch.no_grad():
        for i, (p, s) in enumerate(zip(params, saved)):
            if p.shape != s.shape:
                raise ValueError(
                    f"checkpoint {args.ckpt!r} allp[{i}] has shape {tuple(s.shape)}, "
                    f"block expects {tuple(p.shape)}")
            p.copy_(s.to(device))
    B = args.B or 1
    idx, y = LT.get_batch('train', B, 256)
    return blk, idx, y


def sync(device):
    """Block until queued CUDA work on ``device`` finishes (no-op on CPU)."""
    if device.type == 'cuda':
        torch.cuda.synchronize(device)


def time_call(fn, device, repeat=3):
    """Return ``(best seconds over `repeat` timed calls, last output)``.

    One untimed warmup call is made first. ``time.perf_counter`` is used
    because it is monotonic and high-resolution, unlike ``time.time``.

    Raises:
        ValueError: if ``repeat < 1``.
    """
    if repeat < 1:
        raise ValueError(f'repeat must be >= 1, got {repeat}')
    out = fn()  # warmup
    sync(device)
    best = float('inf')
    for _ in range(repeat):
        t0 = time.perf_counter()
        out = fn()
        sync(device)
        best = min(best, time.perf_counter() - t0)
    return best, out


def _make_step(blk, zs, xin, y, r, eps, nc_jvp_vjp):
    """Shared builder for the paired-state step used by both probe paths.

    ``Z`` stacks two copies of a (B, ...) state. ``sg`` is +r for the first
    half and -r for the second, and multiplies ``H.rgrad_ce`` before it is
    subtracted from ``H.rforce``. ``nc_jvp_vjp(zb2, v)`` must return
    ``(Jv, JTv)`` of ``blk.nc_force`` at the replicated mean ``zb2``, where
    ``v = Z - zb2``.
    """
    B = zs.size(0)
    X2 = torch.cat([xin, xin], 0)
    y2 = torch.cat([y, y], 0)
    sg = torch.cat([torch.full((B, 1, 1), r, device=zs.device),
                    torch.full((B, 1, 1), -r, device=zs.device)], 0)

    def step(Z):
        zbar = 0.5 * (Z[:B] + Z[B:])
        zb2 = torch.cat([zbar, zbar], 0)
        f = H.rforce(blk, Z, X2) - sg * H.rgrad_ce(blk, Z, y2, denom=y.numel())
        v = (Z - zb2).contiguous()
        Jv, JTv = nc_jvp_vjp(zb2, v)
        return Z + eps * (f - (Jv - JTv))

    return step


def make_tf_step(blk, zs, xin, y, r, eps):
    """Paired-state step whose nc_force Jacobian products come from torch.func."""
    def tf_jvp_vjp(zb2, v):
        _, Jv = tf.jvp(blk.nc_force, (zb2,), (v,))
        JTv = tf.vjp(blk.nc_force, zb2)[1](v)[0]
        return Jv, JTv

    return _make_step(blk, zs, xin, y, r, eps, tf_jvp_vjp)


def make_manual_step(blk, zs, xin, y, r, eps):
    """Paired-state step using ``manual_nc_jvp_vjp_thick`` (no torch.func)."""
    return _make_step(blk, zs, xin, y, r, eps,
                      lambda zb2, v: manual_nc_jvp_vjp_thick(blk, zb2, v))


def run_loop_from_step(step, zs, r, T2, K=10):
    """Run ``step`` T2 times from ``[zs, zs]`` and pick the most settled a.

    Every K steps, and at step T2, the finite difference
    ``a_t = (Z[B:] - Z[:B]) / (2r)`` is formed. The a_t whose norm-change
    from the previous checkpoint is smallest is returned, together with its
    step index. With a single checkpoint, that checkpoint is returned with
    ``t_best = T2``.

    Raises:
        ValueError: if ``T2 < 1`` or ``K < 1``. Previously these crashed with
            AttributeError / ZeroDivisionError.
    """
    if T2 < 1:
        raise ValueError(f'T2 must be >= 1, got {T2}')
    if K < 1:
        raise ValueError(f'K must be >= 1, got {K}')
    B = zs.size(0)
    Z = torch.cat([zs, zs], 0)
    a_prev = a_best = None
    inc_min = float('inf')
    t_best = 0
    for t in range(1, T2 + 1):
        Z = step(Z)
        if t % K == 0 or t == T2:
            a_t = (Z[B:] - Z[:B]) / (2 * r)
            if a_prev is not None:
                inc = (a_t - a_prev).norm().item()
                if inc < inc_min:
                    inc_min, a_best, t_best = inc, a_t, t
            a_prev = a_t
    if a_best is None:
        a_best = a_prev
        t_best = T2
    return a_best.detach(), t_best


def _probe_compile(device, Z0, cases):
    # Compile each (name, step, reference one-step output), compare, and time on CUDA.
    for name, step, ref in cases:
        try:
            print('compile start', name, flush=True)
            cstep = torch.compile(step, fullgraph=True, mode='reduce-overhead')
            # compile on first invocation
            with torch.no_grad():
                Zc = cstep(Z0)
            sync(device)
            print('compile ok', name, 'cos one_step', cosine(ref, Zc),
                  'max_abs', float((ref - Zc).abs().max()), flush=True)
            if device.type == 'cuda':
                # Time under no_grad, the same grad mode cstep was compiled in.
                # Otherwise the grad-mode guard forces a recompile inside the
                # timed region, and eager records an autograd graph.
                with torch.no_grad():
                    t_e, _ = time_call(lambda: step(Z0), device)
                    t_c, _ = time_call(lambda: cstep(Z0), device)
                print('timing', name, 'eager_ms', t_e * 1000, 'compiled_ms', t_c * 1000,
                      'speedup', t_e / t_c, flush=True)
        except Exception as e:  # probe boundary: report and try the next case
            print('compile FAIL', name, type(e).__name__, str(e)[:1000], flush=True)
            traceback.print_exc(limit=8)


def _probe_cuda_graph(device, step, Z0):
    # Capture `step` in a CUDA graph and compare one replay with an eager call.
    if device.type != 'cuda':
        print('cuda graph skipped: torch.cuda.is_available() is False', flush=True)
        return
    try:
        # no_grad matches how the eager reference steps are computed, and
        # keeps autograd bookkeeping out of the captured graph.
        with torch.no_grad():
            static_Z = Z0.clone()
            # warmup on a side stream to settle allocations
            s = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                for _ in range(3):
                    static_out = step(static_Z)
            torch.cuda.current_stream().wait_stream(s)
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                static_out = step(static_Z)
            g.replay()
            sync(device)
            print('cuda graph capture ok tf_step_body', cosine(step(Z0), static_out),
                  flush=True)
    except Exception as e:  # probe boundary: report instead of aborting
        print('cuda graph FAIL tf_step_body', type(e).__name__, str(e)[:1000], flush=True)
        traceback.print_exc(limit=8)


def main():
    """CLI entry point: run the Jacobian, one-step, a-select and optional compile probes."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--ckpt', default='/home/yurenh2/ept/ep_run/runs/ep_resreg_warm.pt')
    ap.add_argument('--tiny', action='store_true')
    ap.add_argument('--B', type=int, default=None)
    ap.add_argument('--T1', type=int, default=2)
    ap.add_argument('--T2', type=int, default=2)
    ap.add_argument('--r', type=float, default=0.02)
    ap.add_argument('--eps', type=float, default=0.1)
    ap.add_argument('--compile', action='store_true')
    ap.add_argument('--cuda-graph', action='store_true')
    args = ap.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('torch', torch.__version__, 'cuda_runtime', torch.version.cuda,
          'cuda_available', torch.cuda.is_available(), 'device', device, flush=True)
    if device.type == 'cuda':
        print(torch.cuda.get_device_name(device), flush=True)
    torch.manual_seed(0)

    blk, idx, y = make_block(args, device)
    print('block', blk.C, blk.H, blk.T, 'B', idx.size(0),
          'qknorm', getattr(blk, 'qknorm', False), flush=True)
    xin = blk.embed(idx).detach()
    zs = LT.relax(blk, xin.clone(), xin, args.T1, args.eps)
    print('zs norm', float(zs.norm()), flush=True)

    B = zs.size(0)
    Z0 = torch.cat([zs, zs], 0)
    zbar = 0.5 * (Z0[:B] + Z0[B:])
    zb2 = torch.cat([zbar, zbar], 0)

    # One-off Jacobian check: manual Jv / J^T v against torch.func.
    # Need a nonzero v for the one-off J test.
    vtest = torch.randn_like(zb2) * 1e-3
    _, Jv_ref = tf.jvp(blk.nc_force, (zb2,), (vtest,))
    JTv_ref = tf.vjp(blk.nc_force, zb2)[1](vtest)[0]
    Jv_m, JTv_m = manual_nc_jvp_vjp_thick(blk, zb2, vtest)
    print('manual_Jv cos', cosine(Jv_ref, Jv_m), 'maxrel', max_rel(Jv_m, Jv_ref), flush=True)
    print('manual_JTv cos', cosine(JTv_ref, JTv_m), 'maxrel', max_rel(JTv_m, JTv_ref),
          flush=True)

    # One full step built on each Jacobian path.
    tf_step = make_tf_step(blk, zs, xin, y, args.r, args.eps)
    man_step = make_manual_step(blk, zs, xin, y, args.r, args.eps)
    with torch.no_grad():
        Z_tf = tf_step(Z0)
        Z_man = man_step(Z0)
    print('one_step manual vs tf cos', cosine(Z_tf, Z_man),
          'max_abs', float((Z_tf - Z_man).abs().max()),
          'maxrel', max_rel(Z_man, Z_tf), flush=True)

    # Compare a-select outputs for small T2. Full checkpoint on CPU is intentionally small T2.
    K = max(1, args.T2)
    with torch.no_grad():
        t0 = time.perf_counter()
        a_base, tb = H.holo_a_track(blk, zs, xin, y, args.r, args.T2, args.eps, K=K)
        sync(device)
        dt = time.perf_counter() - t0
        t0 = time.perf_counter()
        a_man, tm = run_loop_from_step(man_step, zs, args.r, args.T2, K=K)
        sync(device)
        dtm = time.perf_counter() - t0
    print('baseline_holo_a_track T2', args.T2, 't_best', tb, 'sec', dt, flush=True)
    print('manual_loop T2', args.T2, 't_best', tm, 'sec', dtm,
          'cos(a)', cosine(a_base, a_man), 'maxrel', max_rel(a_man, a_base), flush=True)

    if args.compile:
        _probe_compile(device, Z0, [('tf_step_body', tf_step, Z_tf),
                                    ('manual_step_body', man_step, Z_man)])

    if args.cuda_graph:
        _probe_cuda_graph(device, tf_step, Z0)


if __name__ == '__main__':
    main()

# NOTE: extra decomposed-layernorm helper kept below for reference; not used by main() above.