"""CET + AEP: does Asymmetric EP let us train a *non-conservative* attention?

CET's energy attention is conservative by construction (grad of a scalar
LogSumExp energy -> symmetric Jacobian -> vanilla EP exact).  Real transformer
attention softmax(QK^T)V with an INDEPENDENT value V is NOT the gradient of any
scalar -> non-conservative Jacobian -> vanilla EP gives a BIASED gradient.

AEP (Scellier et al., "EP for Non-Conservative Systems", arXiv:2602.03670) adds
a nudged-phase correction

    -2 A_J(x*)(x - x*),    A_J = (J - J^T)/2 at the free equilibrium x*.

Linearised, this turns the nudged Jacobian J into J^T -- exactly the adjoint
that vanilla EP fails to realise when J != J^T.

We compare three parameter-gradient estimators vs ground-truth BPTT in two
regimes:

    cons    : F = -x + b + tanh(xS) S^T
              (manifestly grad of a scalar -> J symmetric)        [control]
    noncons : F = -x + b + W_O softmax(QK^T/sqrt d)(x W_V)
              (real attention)                                    [the test]

Vector-field param-gradient (valid for non-gradient F):

    dL/dtheta = < a, dF/dtheta(x*) >,   a = (x_{-b} - x_{+b}) / (2 beta).
"""
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N, D, H = 8, 8, 32, 4
dh = D // H
dev = 'cuda' if torch.cuda.is_available() else 'cpu'


def mk_params(regime):
    """Build the trainable parameter dict for a regime.

    Args:
        regime: ``'noncons'`` creates the attention weights ``WQ``, ``WK``,
            ``WV``, ``WO`` plus bias ``b``; any other value creates the
            conservative parameters ``S`` and ``b``.

    Returns:
        Dict of tensors on ``dev``, every one with ``requires_grad=True``.
        Weights are D x D Gaussians scaled by 1/sqrt(D), drawn from a
        generator seeded with 1; the bias is zero.
    """
    g = torch.Generator(device=dev).manual_seed(1)
    s = 1.0 / math.sqrt(D)

    def weight():
        return torch.randn(D, D, generator=g, device=dev) * s

    if regime == 'noncons':
        # Dict literals evaluate in order, so the draw order is WQ, WK, WV, WO.
        P = dict(WQ=weight(), WK=weight(), WV=weight(), WO=weight(),
                 b=torch.zeros(D, device=dev))
    else:
        P = dict(S=weight(), b=torch.zeros(D, device=dev))
    for v in P.values():
        v.requires_grad_(True)
    return P


def _heads(t):
    """Split the last axis into ``H`` heads: (..., n, H*dh) -> (..., H, n, dh)."""
    return t.reshape(*t.shape[:-1], H, dh).transpose(-3, -2)


def F_noncons(x, P):
    """Real (non-conservative) multi-head attention force.

    Computes ``-x + b + softmax(q k^T / sqrt(dh)) v @ WO`` with q, k, v the
    per-head projections of ``x``.  The output has the same shape as ``x``;
    the last axis of ``x`` must have size ``D`` to match the weights.
    """
    q, k, v = _heads(x @ P['WQ']), _heads(x @ P['WK']), _heads(x @ P['WV'])
    A = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(dh), dim=-1)
    # Merge heads back into the feature axis before the output projection.
    o = (A @ v).transpose(-3, -2).reshape(x.shape) @ P['WO']
    return -x + P['b'] + o


def F_cons(x, P):
    """Conservative force F = -grad E, E = .5|x|^2 - <b, x> - sum logcosh(xS)."""
    return -x + P['b'] + torch.tanh(x @ P['S']) @ P['S'].t()


def cost(x, R, tgt):
    """Mean (over the leading batch axis) of half the squared readout error.

    ``x`` is flattened per sample, projected through ``R`` and compared
    against ``tgt``.
    """
    batch = x.shape[0]
    return 0.5 * ((x.reshape(batch, -1) @ R - tgt) ** 2).sum() / batch


def dcost(x, R, tgt):
    """Gradient of :func:`cost` with respect to ``x``.

    Works on a detached copy of ``x`` and re-enables autograd locally, so it
    can be called from inside a ``torch.no_grad()`` region.
    """
    x = x.detach().requires_grad_(True)
    with torch.enable_grad():
        g, = torch.autograd.grad(cost(x, R, tgt), x)
    return g


def relax(Ffn, P, x0, steps, eps, extra=None):
    """Run ``steps`` explicit Euler updates ``x <- x + eps * f(x)``.

    Args:
        Ffn: force callable ``Ffn(x, P)``.
        P: parameters forwarded to ``Ffn``.
        x0: initial state (not modified; a clone is iterated).
        steps: number of Euler steps.
        eps: step size.
        extra: optional callable ``extra(x)`` added to the force each step.

    Returns:
        The final state, detached from any autograd graph.
    """
    x = x0.clone()
    with torch.no_grad():
        for _ in range(steps):
            f = Ffn(x, P)
            if extra is not None:
                f = f + extra(x)
            x = x + eps * f
    return x.detach()


def AJ_apply(Ffn, P, xstar, v):
    """Return ``0.5 * (J v - J^T v)`` with J the Jacobian of ``Ffn`` at ``xstar``.

    Uses ``torch.autograd.functional.jvp`` for ``J v`` and ``vjp`` for
    ``J^T v``; autograd is re-enabled locally so this is callable under
    ``torch.no_grad()``.
    """
    def fx(z):
        return Ffn(z, P)

    with torch.enable_grad():
        Jv = torch.autograd.functional.jvp(fx, xstar, v)[1]
        JTv = torch.autograd.functional.vjp(fx, xstar, v)[1]
    return 0.5 * (Jv - JTv)


def ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep):
    """Equilibrium-propagation estimate of the parameter gradient.

    Relaxes to the free state ``x*`` (``T1`` steps), runs the +beta and -beta
    nudged phases from ``x*`` (``T2`` steps each), forms
    ``a = (x_{-b} - x_{+b}) / (2 beta)`` and differentiates ``<a, F(x*)>``
    with respect to the parameters.

    Args:
        Ffn: force callable ``Ffn(x, P)``.
        P: dict of parameter tensors that require grad.
        x0: initial state for the free phase.
        R, tgt: readout matrix and targets passed to the cost.
        T1, T2: free / nudged phase step counts.
        eps: Euler step size.
        beta: nudging strength.
        aep: if true, add the AEP correction ``-2 A_J(x*)(x - x*)`` to the
            nudged dynamics; otherwise run vanilla EP.

    Returns:
        Tuple of gradients in the order of ``P.values()``; entries may be
        ``None`` for parameters the force does not use.
    """
    xstar = relax(Ffn, P, x0, T1, eps)

    def correction(x):
        return -2.0 * AJ_apply(Ffn, P, xstar, x - xstar)

    ex = correction if aep else None
    xp = _nud(Ffn, P, xstar, R, tgt, T2, eps, +1.0, beta, ex)
    xm = _nud(Ffn, P, xstar, R, tgt, T2, eps, -1.0, beta, ex)
    a = ((xm - xp) / (2.0 * beta)).detach()
    xs = xstar.detach()
    with torch.enable_grad():
        s = (a * Ffn(xs, P)).sum()
        grads = torch.autograd.grad(s, list(P.values()), allow_unused=True)
    return grads


def _nud(Ffn, P, xstar, R, tgt, T2, eps, sign, beta, ex):
    """Run one nudged phase from ``xstar`` and return the final state.

    The force is ``Ffn(x, P) - sign * beta * dcost(x)`` plus the optional
    extra term ``ex(x)``; integration is delegated to :func:`relax`.
    """
    def nudged_force(x, params):
        return Ffn(x, params) - sign * beta * dcost(x, R, tgt)

    return relax(nudged_force, P, xstar, T2, eps, extra=ex)


def bptt_grad(Ffn, P, x0, R, tgt, T1, eps):
    """Reference gradient: backprop through ``T1`` unrolled Euler steps."""
    x = x0.clone()
    for _ in range(T1):
        x = x + eps * Ffn(x, P)  # full graph
    return torch.autograd.grad(cost(x, R, tgt), list(P.values()),
                               allow_unused=True)


def cosine(ga, gb):
    """Cosine similarity between two aligned sequences of gradient tensors.

    Entries may be ``None`` (as produced by ``autograd.grad`` with
    ``allow_unused=True``): a ``None`` is treated as zeros shaped like its
    counterpart, and positions where both are ``None`` are skipped.

    Returns:
        The similarity as a float, or NaN when nothing is left to compare.

    Raises:
        ValueError: if the two sequences have different lengths.
    """
    ga, gb = list(ga), list(gb)
    if len(ga) != len(gb):
        raise ValueError(
            f"gradient lists differ in length: {len(ga)} vs {len(gb)}")
    flat_a, flat_b = [], []
    for a, b in zip(ga, gb):
        if a is None and b is None:
            continue
        if a is None:
            a = torch.zeros_like(b)
        elif b is None:
            b = torch.zeros_like(a)
        flat_a.append(a.flatten())
        flat_b.append(b.flatten())
    if not flat_a:
        return float('nan')
    fa, fb = torch.cat(flat_a), torch.cat(flat_b)
    return F.cosine_similarity(fa, fb, dim=0).item()


def run(regime, T1=120, T2=30, eps=0.2, beta=0.02):
    """Compare naive EP and AEP against BPTT for one regime and print a report.

    Args:
        regime: ``'cons'`` selects :func:`F_cons`; anything else selects
            :func:`F_noncons`.
        T1, T2: free / nudged phase step counts.
        eps: Euler step size.
        beta: nudging strength.

    Returns:
        Dict with the printed metrics: ``residual`` (relative one-step change
        at x*), ``asym`` (||A_J v|| / ||J v|| for a random v), ``cos_naive``
        and ``cos_aep`` (cosine similarity of each estimator with BPTT).
    """
    P = mk_params(regime)
    Ffn = F_cons if regime == 'cons' else F_noncons
    g = torch.Generator(device=dev).manual_seed(7)
    x0 = torch.randn(B, N, D, generator=g, device=dev) * 0.1
    R = torch.randn(N * D, 16, generator=g, device=dev) / math.sqrt(N * D)
    tgt = torch.randn(B, 16, generator=g, device=dev)

    # How far one more Euler step moves the relaxed state (0 at a fixed point).
    xs = relax(Ffn, P, x0, T1, eps)
    res = ((relax(Ffn, P, xs, 1, eps) - xs).norm() / (xs.norm() + 1e-8)).item()

    # Relative size of the antisymmetric Jacobian part along a random direction.
    v = torch.randn_like(xs)
    aj = AJ_apply(Ffn, P, xs, v)
    jv = torch.autograd.functional.jvp(lambda z: Ffn(z, P), xs, v)[1]
    asym = (aj.norm() / (jv.norm() + 1e-8)).item()

    gb = bptt_grad(Ffn, P, x0, R, tgt, T1, eps)
    gn = ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep=False)
    ga = ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep=True)
    cos_naive = cosine(gn, gb)
    cos_aep = cosine(ga, gb)

    print(f"\n===== regime={regime} (residual@x*={res:.1e}) =====")
    print(f" Jacobian antisymmetry ||A_J v||/||J v|| = {asym:.3f} "
          f"({'~conservative' if asym < 0.05 else 'NON-conservative'})")
    print(f" cosine(naive_EP, BPTT) = {cos_naive:+.4f}")
    print(f" cosine( AEP , BPTT) = {cos_aep:+.4f}")
    return dict(residual=res, asym=asym, cos_naive=cos_naive, cos_aep=cos_aep)


if __name__ == '__main__':
    run('cons')
    run('noncons')