"""
Characterize AEP (non-conservative EP) on CET's attention, before porting to the LM.

Controlled knob: attention scale s in force_z = -dE_rest/dz + s * RealAttn(z).
  s=0   -> pure conservative reconstruction (A_J=0; EP exact)
  s up  -> attention dominates the force -> more non-conservative -> naive EP biased.

Metric: cosine(EP-grad, BPTT-grad) on the attention params {WQ,WK,WV,WO} (the global
cosine is diluted by the dominant conservative params, so we look at attention itself).

The AEP correction added to the z-force during the nudged phase is
    -2*s*(J_A v) = -s*(J v - J^T v),
where J is the Jacobian of RealAttn at the free equilibrium z*, J_A = (J - J^T)/2 is
its antisymmetric part, and v = z - z* is the displacement from the free equilibrium.

Sweeps: (1) s [non-conservativeness], (2) beta [nudge size], (3) T2 [nudge steps],
        (4) T1 [free-phase convergence]. Plus: free-eq identical naive vs AEP, and cost.
"""
import argparse
import math
import time

import torch
import torch.nn.functional as F

from cet_mvp import make_patch_mask, masked_cost, get_loaders
from cet_aep import CETReal

dev = 'cuda' if torch.cuda.is_available() else 'cpu'
ATTN = ('WQ', 'WK', 'WV', 'WO')  # parameter names treated as "attention params"


def force(model, xbar, z, y, s):
    """Return the (z, y) force: (-dE_rest/dz + s * real_attn(z), -dE_rest/dy).

    The input tensors are flagged ``requires_grad`` in place, and the energy gradient is
    taken with ``create_graph=True`` so callers can differentiate the returned forces
    again (e.g. w.r.t. model parameters in ``vf_grad`` / ``bptt_grad``).
    """
    z = z.requires_grad_(True)
    y = y.requires_grad_(True)
    gz, gy = torch.autograd.grad(model.E_rest(xbar, z, y), [z, y], create_graph=True)
    return -gz + s * model.real_attn(z), -gy


def relax_free(model, xbar, z, y, s, T1, eps):
    """Run T1 explicit-Euler steps of size eps on the free dynamics; return detached state."""
    for _ in range(T1):
        with torch.enable_grad():
            fz, fy = force(model, xbar, z, y, s)
        fz, fy = fz.detach(), fy.detach()
        with torch.no_grad():
            z, y = z + eps * fz, y + eps * fy
    return z.detach(), y.detach()


def relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, sign, aep):
    """Run T2 Euler steps of the nudged dynamics starting from the free state (zs, ys).

    The y-force gets an extra ``-sign * beta * d masked_cost / dy`` term (cost on the
    module-level ``X``/``M``). With ``aep`` the z-force additionally gets
    ``-s * (J v - J^T v)`` where J is the Jacobian of ``model.real_attn`` at ``zs`` and
    ``v = z - zs``. Returns the detached nudged state.
    """
    z, y = zs.clone(), ys.clone()
    for _ in range(T2):
        with torch.enable_grad():
            fz, fy = force(model, xbar, z, y, s)
            fz, fy = fz.detach(), fy.detach()
            yy = y.detach().requires_grad_(True)
            gy, = torch.autograd.grad(masked_cost(yy, X, M), yy)
            fy = fy - sign * beta * gy
            if aep:
                v = (z - zs).detach()
                # J is always evaluated at the free equilibrium zs, not at the current z.
                Jv = torch.autograd.functional.jvp(model.real_attn, zs, v)[1]
                JTv = torch.autograd.functional.vjp(model.real_attn, zs, v)[1]
                fz = fz - s * (Jv - JTv)  # -2 * s * 0.5 (J v - J^T v)
        with torch.no_grad():
            z, y = z + eps * fz, y + eps * fy
    return z.detach(), y.detach()


def vf_grad(model, xbar, s, T1, T2, eps, beta, aep):
    """Estimate the parameter gradient with symmetric (+/- beta) vector-field EP.

    Returns ``(zs, g)``: the free-equilibrium z and a tuple of per-parameter gradients
    in ``model.parameters()`` order (``None`` for parameters the force does not use).
    """
    zs, ys = relax_free(model, xbar, *model.init_state(xbar), s, T1, eps)
    zp, yp = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, +1, aep)
    zm, ym = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, -1, aep)
    # Central finite difference of the equilibrium w.r.t. beta (sign folded in).
    az = ((zm - zp) / (2 * beta)).detach()
    ay = ((ym - yp) / (2 * beta)).detach()
    with torch.enable_grad():
        fz, fy = force(model, xbar, zs.detach(), ys.detach(), s)
        g = torch.autograd.grad((az * fz).sum() + (ay * fy).sum(),
                                list(model.parameters()), allow_unused=True)
    return zs, g


def bptt_grad(model, xbar, s, T1, eps):
    """Reference gradient: backprop through T1 unrolled free-phase Euler steps."""
    z, y = model.init_state(xbar)
    z, y = z.requires_grad_(True), y.requires_grad_(True)
    for _ in range(T1):
        fz, fy = force(model, xbar, z, y, s)
        z, y = z + eps * fz, y + eps * fy
    return torch.autograd.grad(masked_cost(y, X, M) / M.sum(),
                               list(model.parameters()), allow_unused=True)


def attn_cos(g, gb, names):
    """Mean per-tensor cosine similarity between g and gb over the ATTN-named params."""
    cs = [F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
          for n, a, b in zip(names, g, gb)
          if n in ATTN and a is not None and b is not None]
    if not cs:
        raise ValueError(f"no parameter named in {ATTN} has a gradient in both g and gb; "
                         f"parameter names: {names}")
    return sum(cs) / len(cs)


def global_cos(g, gb):
    """Cosine similarity between g and gb, each flattened over all params both define."""
    # Filter pairwise so both concatenated vectors cover exactly the same parameters.
    pairs = [(a, b) for a, b in zip(g, gb) if a is not None and b is not None]
    a = torch.cat([x.flatten() for x, _ in pairs])
    b = torch.cat([y.flatten() for _, y in pairs])
    return F.cosine_similarity(a, b, dim=0).item()


def measure(model, names, s, T1, T2, eps, beta):
    """Compare naive-EP and AEP gradients against BPTT on the module-level batch XBAR."""
    gb = bptt_grad(model, XBAR, s, T1, eps)
    zsn, gn = vf_grad(model, XBAR, s, T1, T2, eps, beta, aep=False)
    zsa, ga = vf_grad(model, XBAR, s, T1, T2, eps, beta, aep=True)
    # Relative difference of the two free equilibria (AEP only touches the nudged phase).
    eq_id = (zsn - zsa).norm().item() / (zsn.norm().item() + 1e-9)  # free eq identical?
    return dict(naive=attn_cos(gn, gb, names), aep=attn_cos(ga, gb, names),
                gnaive=global_cos(gn, gb), gaep=global_cos(ga, gb), eq_id=eq_id)


def _sync():
    """Wait for queued CUDA work so wall-clock timings are meaningful."""
    if dev == 'cuda':
        torch.cuda.synchronize()


def _time_vf_grad(model, aep, reps=3):
    """Average wall-clock seconds per vf_grad call at the section-[5] settings."""
    _sync()
    t = time.perf_counter()
    for _ in range(reps):
        vf_grad(model, XBAR, 2.0, 120, 20, 0.2, 0.02, aep=aep)
    _sync()
    return (time.perf_counter() - t) / reps


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument('--dataset', default='fashionmnist')
    ap.add_argument('--img', type=int, default=28)
    ap.add_argument('--ch', type=int, default=1)
    ap.add_argument('--patch', type=int, default=7)
    ap.add_argument('--stride', type=int, default=7)
    ap.add_argument('--batch', type=int, default=32)
    cfg = ap.parse_args()

    torch.manual_seed(0)
    model = CETReal(cfg.img, cfg.ch, cfg.patch, cfg.stride,
                    D=64, heads=4, dh=16, mem=128).to(dev)
    names = [n for n, _ in model.named_parameters()]
    trl, _ = get_loaders(cfg.batch, dataset=cfg.dataset)

    # One fixed batch shared (as globals) by the cost / relaxation helpers above.
    global X, M, XBAR
    X, _ = next(iter(trl))
    X = X.to(dev)
    M = make_patch_mask(X.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev)
    XBAR = X * (1 - M)

    # intrinsic non-conservativeness of the attention map itself
    zs, _ = relax_free(model, XBAR, *model.init_state(XBAR), 1.0, 120, 0.2)
    v = torch.randn_like(zs)
    Jv = torch.autograd.functional.jvp(model.real_attn, zs, v)[1]
    JTv = torch.autograd.functional.vjp(model.real_attn, zs, v)[1]
    print(f"intrinsic attention-map antisymmetry ||A_J v||/||J v|| = "
          f"{(0.5*(Jv-JTv)).norm().item()/(Jv.norm().item()+1e-9):.3f}")

    base = dict(T1=120, T2=20, eps=0.2, beta=0.02)

    print("\n[1] ATTENTION SCALE s (s=0 conservative -> larger = more non-conservative)")
    print(f"{'s':>6} | {'naive(attn)':>11} {'AEP(attn)':>10} | {'naive(glob)':>11} {'AEP(glob)':>10} | free-eq id")
    for s in [0.25, 0.5, 1.0, 2.0, 4.0, 8.0]:
        r = measure(model, names, s, base['T1'], base['T2'], base['eps'], base['beta'])
        print(f"{s:6.2f} | {r['naive']:>11.3f} {r['aep']:>10.3f} | {r['gnaive']:>11.4f} {r['gaep']:>10.4f} | {r['eq_id']:.1e}")

    print("\n[2] NUDGE STRENGTH beta (s=2, T2=20)")
    print(f"{'beta':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}")
    for beta in [0.005, 0.01, 0.02, 0.05, 0.1, 0.2]:
        r = measure(model, names, 2.0, 120, 20, 0.2, beta)
        print(f"{beta:6.3f} | {r['naive']:>11.3f} {r['aep']:>10.3f}")

    print("\n[3] NUDGE STEPS T2 (s=2, beta=0.02)")
    print(f"{'T2':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}")
    for T2 in [3, 5, 10, 20, 40]:
        r = measure(model, names, 2.0, 120, T2, 0.2, 0.02)
        print(f"{T2:6d} | {r['naive']:>11.3f} {r['aep']:>10.3f}")

    print("\n[4] FREE-PHASE STEPS T1 (s=2; AEP uses A_J at the free eq)")
    print(f"{'T1':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}")
    for T1 in [20, 40, 80, 120, 200]:
        r = measure(model, names, 2.0, T1, 20, 0.2, 0.02)
        print(f"{T1:6d} | {r['naive']:>11.3f} {r['aep']:>10.3f}")

    print("\n[5] COST (s=2, T1=120, T2=20)")
    tn = _time_vf_grad(model, aep=False)
    ta = _time_vf_grad(model, aep=True)
    print(f" naive {tn*1000:.0f} ms/grad AEP {ta*1000:.0f} ms/grad overhead {ta/tn:.2f}x")


if __name__ == '__main__':
    main()