diff options
Diffstat (limited to 'ep_run/test_aselect_deepdive.py')
| -rw-r--r-- | ep_run/test_aselect_deepdive.py | 323 |
1 files changed, 323 insertions, 0 deletions
#!/usr/bin/env python3
"""Standalone EP a-select performance/correctness probes.

Does not modify trainer files. It can run on CPU if CUDA is unavailable; CUDA timing is only
attempted when torch.cuda.is_available().
"""
import argparse
import math
import os
import time
import traceback

import torch
import torch.nn.functional as F
import torch.func as tf

import lt_ep_train as LT
import holo_ep as H


def cosine(a, b):
    """Return the cosine similarity of ``a`` and ``b`` as a Python float.

    Both tensors are detached, flattened and cast to fp32 first. The tiny
    additive constant keeps the division finite when either norm is zero.
    """
    a = a.detach().reshape(-1).float()
    b = b.detach().reshape(-1).float()
    return float((a @ b) / (a.norm() * b.norm() + 1e-30))


def max_rel(a, b):
    """Return ``max|a - b| / max|b|`` as a Python float (``b`` is the reference)."""
    return float((a - b).abs().max() / (b.abs().max() + 1e-12))


def ln_jvp(x, dx, gamma, beta=None, eps=1e-5):
    """Return the directional derivative of layer norm at ``x`` along ``dx``.

    Matches PyTorch layer_norm over the last dim (biased variance, affine gamma).
    ``beta`` is accepted but never read by this function.
    """
    mu = x.mean(dim=-1, keepdim=True)
    xc = x - mu
    inv = torch.rsqrt((xc * xc).mean(dim=-1, keepdim=True) + eps)
    xhat = xc * inv
    dmu = dx.mean(dim=-1, keepdim=True)
    # mean(xhat * dx), not mean(xhat * (dx-dmu)); mean(xhat)==0.
    proj = (xhat * dx).mean(dim=-1, keepdim=True)
    dy = inv * (dx - dmu - xhat * proj)
    return dy * gamma


def ln_vjp(x, gy, gamma, eps=1e-5):
    """Return the layer-norm input cotangent for output cotangent ``gy``.

    Same normalisation as ``ln_jvp`` (last dim, biased variance); ``gamma`` is
    applied to ``gy`` before the normalisation Jacobian is transposed.
    """
    mu = x.mean(dim=-1, keepdim=True)
    xc = x - mu
    inv = torch.rsqrt((xc * xc).mean(dim=-1, keepdim=True) + eps)
    xhat = xc * inv
    g = gy * gamma
    return inv * (g - g.mean(dim=-1, keepdim=True) - xhat * (g * xhat).mean(dim=-1, keepdim=True))


def rms_jvp(x, dx, eps=1e-6):
    """Return the directional derivative of ``x * rsqrt(mean(x^2) + eps)`` along ``dx``."""
    inv = torch.rsqrt((x * x).mean(dim=-1, keepdim=True) + eps)
    return inv * dx - x * (inv ** 3) * (x * dx).mean(dim=-1, keepdim=True)


def rms_vjp(x, gy, eps=1e-6):
    """Return the RMS-norm input cotangent for output cotangent ``gy``."""
    # RMSNorm Jacobian is symmetric, so J^T gy is the same expression as J gy.
    return rms_jvp(x, gy, eps)


def gelu_tanh_deriv(x):
    """Return the elementwise derivative of ``F.gelu(x, approximate='tanh')``."""
    k = 0.7978845608028654  # sqrt(2 / pi)
    a = 0.044715
    u = k * (x + a * x * x * x)
    t = torch.tanh(u)
    return 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * k * (1.0 + 3.0 * a * x * x)


def manual_nc_jvp_vjp_thick(blk, z, vec):
    """Explicit fp32 Jv and J^T vec for blk.nc_force(z), thick + qknorm path.

    This is deliberately written as plain ATen ops: no torch.func, no autograd. It assumes
    blk.fnoise == 0 and attn_mode == 'thick'. It returns (Jv, JTv) for the same input vector.

    Args:
        blk: block object providing the weights read below (ln1g/ln1b/ln2g/ln2b, WQ/WK/WV/WO,
            fc/fcb/pj, cmask) plus H, dh, attn_mode and the optional fnoise/qknorm flags.
        z: linearisation point; unpacked as (B, T, C).
        vec: tangent for the JVP and cotangent for the VJP.

    Returns:
        Tuple ``(Jv, JTv)``.
    """
    assert blk.attn_mode == 'thick'
    assert getattr(blk, 'fnoise', 0.0) == 0.0
    B, T, C = z.shape
    Hh, dh = blk.H, blk.dh
    scale = 1.0 / math.sqrt(dh)
    qknorm = getattr(blk, 'qknorm', False)

    def split_heads(t):
        # (B, T, Hh*dh) -> (B, Hh, T, dh)
        return t.view(B, T, Hh, dh).transpose(1, 2)

    def merge_heads(t):
        # (B, Hh, T, dh) -> (B, T, C)
        return t.transpose(1, 2).reshape(B, T, C)

    # Base forward intermediates.
    h1 = F.layer_norm(z, (C,), blk.ln1g, blk.ln1b)
    h2 = F.layer_norm(z, (C,), blk.ln2g, blk.ln2b)
    q0 = split_heads(h1 @ blk.WQ)
    k0 = split_heads(h1 @ blk.WK)
    vv = split_heads(h1 @ blk.WV)
    if qknorm:
        q = q0 * torch.rsqrt(q0.pow(2).mean(-1, keepdim=True) + 1e-6)
        k = k0 * torch.rsqrt(k0.pow(2).mean(-1, keepdim=True) + 1e-6)
    else:
        q, k = q0, k0
    logits = (q @ k.transpose(-2, -1)) * scale
    p = torch.softmax(logits.masked_fill(~blk.cmask, float('-inf')), -1)

    u = h2 @ blk.fc + blk.fcb
    gp = gelu_tanh_deriv(u)

    # JVP: attention branch.
    dh1 = ln_jvp(z, vec, blk.ln1g)
    dq0 = split_heads(dh1 @ blk.WQ)
    dk0 = split_heads(dh1 @ blk.WK)
    dvv = split_heads(dh1 @ blk.WV)
    if qknorm:
        dq = rms_jvp(q0, dq0)
        dk = rms_jvp(k0, dk0)
    else:
        dq, dk = dq0, dk0
    dlogits = (dq @ k.transpose(-2, -1) + q @ dk.transpose(-2, -1)) * scale
    # Softmax JVP; masked positions have p == 0 and therefore contribute nothing.
    dp = p * (dlogits - (p * dlogits).sum(-1, keepdim=True))
    datt_heads = dp @ vv + p @ dvv
    Jv_att = merge_heads(datt_heads) @ blk.WO

    # JVP: FFN branch.
    dh2 = ln_jvp(z, vec, blk.ln2g)
    du = dh2 @ blk.fc
    Jv_ff = (du * gp) @ blk.pj
    Jv = Jv_att + Jv_ff

    # VJP: attention branch.
    gout = vec
    gh_att_heads = split_heads(gout @ blk.WO.t())
    gp_soft = gh_att_heads @ vv.transpose(-2, -1)
    gv_heads = p.transpose(-2, -1) @ gh_att_heads
    glogits = p * (gp_soft - (gp_soft * p).sum(-1, keepdim=True))
    gq = (glogits @ k) * scale
    gk = (glogits.transpose(-2, -1) @ q) * scale
    if qknorm:
        gq0 = rms_vjp(q0, gq)
        gk0 = rms_vjp(k0, gk)
    else:
        gq0, gk0 = gq, gk
    gh1 = (merge_heads(gq0) @ blk.WQ.t()
           + merge_heads(gk0) @ blk.WK.t()
           + merge_heads(gv_heads) @ blk.WV.t())
    JTv_att = ln_vjp(z, gh1, blk.ln1g)

    # VJP: FFN branch.
    gg = gout @ blk.pj.t()
    gu = gg * gp
    gh2 = gu @ blk.fc.t()
    JTv_ff = ln_vjp(z, gh2, blk.ln2g)
    JTv = JTv_att + JTv_ff
    return Jv, JTv


def _configure_probe_block(blk):
    """Set the attribute flags every probe block in this script uses."""
    blk.qknorm = True
    blk.track = True
    blk.navg = 1
    blk.li_avg = 0


def make_block(args, device):
    """Build the block under test plus an ``(idx, y)`` batch.

    With ``args.tiny`` a small block with random integer inputs is created;
    otherwise the full-size block is created and its parameters are overwritten
    from the checkpoint at ``args.ckpt``.

    Returns:
        Tuple ``(blk, idx, y)``.
    """
    if args.tiny:
        # Tiny block for compiler/frontend feasibility tests.
        blk = LT.EQBlock(32, 4, 64, 32, c=1.0, attn_mode='thick')
        B = args.B or 2
        T = 32
        y_vocab = LT.vocab
        idx = torch.randint(0, y_vocab, (B, T), device=device)
        y = torch.randint(0, y_vocab, (B, T), device=device)
        _configure_probe_block(blk)
        return blk, idx, y

    blk = LT.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
    _configure_probe_block(blk)
    # NOTE(security): torch.load unpickles the file; only point --ckpt at trusted checkpoints.
    ck = torch.load(args.ckpt, map_location=device)
    with torch.no_grad():
        for p, s in zip(blk.allp, ck['allp']):
            p.copy_(s.to(device))
    B = args.B or 1
    idx, y = LT.get_batch('train', B, 256)
    return blk, idx, y


def sync(device):
    """Block until queued CUDA work on ``device`` finishes; no-op for other device types."""
    if device.type == 'cuda':
        torch.cuda.synchronize(device)


def time_call(fn, device, repeat=3):
    """Time ``fn()`` and return ``(best_seconds, last_output)``.

    One untimed warmup call is made first. Each timed call is followed by a
    device sync so asynchronous CUDA work is included. ``time.perf_counter``
    is used because it is monotonic and high-resolution, unlike ``time.time``.
    """
    # one warmup
    out = fn()
    sync(device)
    ts = []
    for _ in range(repeat):
        t0 = time.perf_counter()
        out = fn()
        sync(device)
        ts.append(time.perf_counter() - t0)
    return min(ts), out


def _make_pair_step(blk, zs, xin, y, r, eps, nc_jvp_vjp):
    """Build the shared two-copy update step used by both step factories.

    ``nc_jvp_vjp(zb2, v)`` must return ``(Jv, JTv)`` for ``blk.nc_force`` at
    ``zb2``. The returned ``step(Z)`` expects ``Z`` to hold two stacked copies
    of the state (rows ``[:B]`` paired with ``+r``, rows ``[B:]`` with ``-r``).
    """
    B = zs.size(0)
    X2 = torch.cat([xin, xin], 0)
    y2 = torch.cat([y, y], 0)
    sg = torch.cat([
        torch.full((B, 1, 1), r, device=zs.device),
        torch.full((B, 1, 1), -r, device=zs.device),
    ], 0)
    denom = y.numel()

    def step(Z):
        # Linearise around the mean of the two copies.
        zbar = 0.5 * (Z[:B] + Z[B:])
        zb2 = torch.cat([zbar, zbar], 0)
        f = H.rforce(blk, Z, X2) - sg * H.rgrad_ce(blk, Z, y2, denom=denom)
        v = (Z - zb2).contiguous()
        Jv, JTv = nc_jvp_vjp(zb2, v)
        return Z + eps * (f - (Jv - JTv))

    return step


def make_tf_step(blk, zs, xin, y, r, eps):
    """Return ``step(Z)`` whose Jv / J^T v terms come from ``torch.func`` jvp/vjp."""
    def fnc(zz):
        return blk.nc_force(zz)

    def nc_jvp_vjp(zb2, v):
        _, Jv = tf.jvp(fnc, (zb2,), (v,))
        JTv = tf.vjp(fnc, zb2)[1](v)[0]
        return Jv, JTv

    return _make_pair_step(blk, zs, xin, y, r, eps, nc_jvp_vjp)


def make_manual_step(blk, zs, xin, y, r, eps):
    """Return ``step(Z)`` whose Jv / J^T v terms come from ``manual_nc_jvp_vjp_thick``."""
    def nc_jvp_vjp(zb2, v):
        return manual_nc_jvp_vjp_thick(blk, zb2, v)

    return _make_pair_step(blk, zs, xin, y, r, eps, nc_jvp_vjp)


def run_loop_from_step(step, zs, r, T2, K=10):
    """Iterate ``step`` for ``T2`` steps and pick an estimate ``a``.

    Every ``K`` steps (and at the final step) ``a_t = (Z[B:] - Z[:B]) / (2*r)``
    is sampled. The sample with the smallest norm change from the previous
    sample is returned; if only one sample was taken, that sample is returned.

    Returns:
        Tuple ``(a_best.detach(), t_best)``.

    Raises:
        ValueError: if ``T2 < 1`` (no step would run, so there is nothing to return).
    """
    if T2 < 1:
        raise ValueError('run_loop_from_step requires T2 >= 1, got %r' % (T2,))
    B = zs.size(0)
    Z = torch.cat([zs, zs], 0)
    a_prev = a_best = None
    inc_min = float('inf')
    t_best = 0
    for t in range(1, T2 + 1):
        Z = step(Z)
        if t % K == 0 or t == T2:
            a_t = (Z[B:] - Z[:B]) / (2 * r)
            if a_prev is not None:
                inc = (a_t - a_prev).norm().item()
                if inc < inc_min:
                    inc_min, a_best, t_best = inc, a_t, t
            a_prev = a_t
    if a_best is None:
        # Fewer than two samples: fall back to the only (last) one.
        a_best = a_prev
        t_best = T2
    return a_best.detach(), t_best


def main():
    """Run the probes: manual-vs-torch.func J checks, one-step and a-select comparisons,
    and the optional torch.compile / CUDA-graph feasibility checks."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--ckpt', default='/home/yurenh2/ept/ep_run/runs/ep_resreg_warm.pt')
    ap.add_argument('--tiny', action='store_true')
    ap.add_argument('--B', type=int, default=None)
    ap.add_argument('--T1', type=int, default=2)
    ap.add_argument('--T2', type=int, default=2)
    ap.add_argument('--r', type=float, default=0.02)
    ap.add_argument('--eps', type=float, default=0.1)
    ap.add_argument('--compile', action='store_true')
    ap.add_argument('--cuda-graph', action='store_true')
    args = ap.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('torch', torch.__version__, 'cuda_runtime', torch.version.cuda,
          'cuda_available', torch.cuda.is_available(), 'device', device, flush=True)
    if device.type == 'cuda':
        print(torch.cuda.get_device_name(device), flush=True)

    torch.manual_seed(0)
    blk, idx, y = make_block(args, device)
    print('block', blk.C, blk.H, blk.T, 'B', idx.size(0),
          'qknorm', getattr(blk, 'qknorm', False), flush=True)
    xin = blk.embed(idx).detach()
    zs = LT.relax(blk, xin.clone(), xin, args.T1, args.eps)
    print('zs norm', float(zs.norm()), flush=True)

    B = zs.size(0)
    Z0 = torch.cat([zs, zs], 0)
    zbar = 0.5 * (Z0[:B] + Z0[B:])
    zb2 = torch.cat([zbar, zbar], 0)

    def nc_force(zz):
        return blk.nc_force(zz)

    # Need a nonzero v for the one-off J test.
    vtest = torch.randn_like(zb2) * 1e-3
    _, Jv_ref = tf.jvp(nc_force, (zb2,), (vtest,))
    JTv_ref = tf.vjp(nc_force, zb2)[1](vtest)[0]
    Jv_m, JTv_m = manual_nc_jvp_vjp_thick(blk, zb2, vtest)
    print('manual_Jv cos', cosine(Jv_ref, Jv_m), 'maxrel', max_rel(Jv_m, Jv_ref), flush=True)
    print('manual_JTv cos', cosine(JTv_ref, JTv_m), 'maxrel', max_rel(JTv_m, JTv_ref), flush=True)

    tf_step = make_tf_step(blk, zs, xin, y, args.r, args.eps)
    man_step = make_manual_step(blk, zs, xin, y, args.r, args.eps)
    with torch.no_grad():
        Z_tf = tf_step(Z0)
        Z_man = man_step(Z0)
    print('one_step manual vs tf cos', cosine(Z_tf, Z_man),
          'max_abs', float((Z_tf - Z_man).abs().max()),
          'maxrel', max_rel(Z_man, Z_tf), flush=True)

    # Compare a-select outputs for small T2. Full checkpoint on CPU is intentionally small T2.
    K = max(1, args.T2)
    with torch.no_grad():
        t0 = time.perf_counter()
        a_base, tb = H.holo_a_track(blk, zs, xin, y, args.r, args.T2, args.eps, K=K)
        sync(device)
        dt = time.perf_counter() - t0

        t0 = time.perf_counter()
        a_man, tm = run_loop_from_step(man_step, zs, args.r, args.T2, K=K)
        sync(device)
        dtm = time.perf_counter() - t0
    print('baseline_holo_a_track T2', args.T2, 't_best', tb, 'sec', dt, flush=True)
    print('manual_loop T2', args.T2, 't_best', tm, 'sec', dtm,
          'cos(a)', cosine(a_base, a_man), 'maxrel', max_rel(a_man, a_base), flush=True)

    if args.compile:
        probes = [('tf_step_body', tf_step, Z_tf), ('manual_step_body', man_step, Z_man)]
        for name, step, Z_ref in probes:
            try:
                print('compile start', name, flush=True)
                cstep = torch.compile(step, fullgraph=True, mode='reduce-overhead')
                # compile on first invocation
                with torch.no_grad():
                    Zc = cstep(Z0)
                sync(device)
                print('compile ok', name, 'cos one_step', cosine(Z_ref, Zc),
                      'max_abs', float((Z_ref - Zc).abs().max()), flush=True)
                if device.type == 'cuda':
                    # Time under the same grad mode the step was compiled and checked in;
                    # otherwise the compiled callable may recompile and the eager/compiled
                    # numbers would not describe the configuration verified above.
                    with torch.no_grad():
                        t_e, _ = time_call(lambda: step(Z0), device)
                        t_c, _ = time_call(lambda: cstep(Z0), device)
                    print('timing', name, 'eager_ms', t_e * 1000, 'compiled_ms', t_c * 1000,
                          'speedup', t_e / t_c, flush=True)
            except Exception as e:
                # Probe boundary: report the failure and continue with the next probe.
                print('compile FAIL', name, type(e).__name__, str(e)[:1000], flush=True)
                traceback.print_exc(limit=8)

    if args.cuda_graph:
        if device.type != 'cuda':
            print('cuda graph skipped: torch.cuda.is_available() is False', flush=True)
        else:
            try:
                static_Z = Z0.clone()
                # warmup on a side stream to settle allocations
                s = torch.cuda.Stream()
                s.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(s):
                    for _ in range(3):
                        static_out = tf_step(static_Z)
                torch.cuda.current_stream().wait_stream(s)
                g = torch.cuda.CUDAGraph()
                with torch.cuda.graph(g):
                    static_out = tf_step(static_Z)
                g.replay()
                sync(device)
                print('cuda graph capture ok tf_step_body',
                      cosine(tf_step(Z0), static_out), flush=True)
            except Exception as e:
                # Probe boundary: capture failures are an expected outcome to report.
                print('cuda graph FAIL tf_step_body', type(e).__name__, str(e)[:1000], flush=True)
                traceback.print_exc(limit=8)


if __name__ == '__main__':
    main()

# NOTE: extra decomposed-layernorm helper kept below for reference; not used by main() above.
