summaryrefslogtreecommitdiff
path: root/scripts/aep_attention.py
diff options
context:
space:
mode:
author: Yuren Hao <yurenh2@illinois.edu>, 2026-07-03 05:56:50 -0500
committer: Yuren Hao <yurenh2@illinois.edu>, 2026-07-03 05:56:50 -0500
commit: b83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
tree: b9cc01d7adda691d9156d9d04f4fb2f644674e96 /scripts/aep_attention.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'scripts/aep_attention.py')
-rw-r--r-- scripts/aep_attention.py | 157
1 files changed, 157 insertions, 0 deletions
diff --git a/scripts/aep_attention.py b/scripts/aep_attention.py
new file mode 100644
index 0000000..868cb05
--- /dev/null
+++ b/scripts/aep_attention.py
@@ -0,0 +1,157 @@
+"""
+CET + AEP: does Asymmetric EP let us train a *non-conservative* attention?
+
+CET's energy attention is conservative by construction (grad of a scalar LogSumExp
+energy -> symmetric Jacobian -> vanilla EP exact). Real transformer attention
+softmax(QK^T)V with an INDEPENDENT value V is NOT the gradient of any scalar ->
+non-conservative Jacobian -> vanilla EP gives a BIASED gradient.
+
+AEP (Scellier et al., "EP for Non-Conservative Systems", arXiv:2602.03670) adds a
+nudged-phase correction -2 A_J(x*)(x - x*), A_J = (J - J^T)/2 at the free
+equilibrium x*. Linearised, this turns the nudged Jacobian J into J^T -- exactly
+the adjoint that vanilla EP fails to realise when J != J^T.
+
+We compare three parameter-gradient estimators vs ground-truth BPTT in two regimes:
+ cons : F = -x + b + tanh(xS) S^T (manifestly grad of a scalar -> J symmetric) [control]
+ noncons : F = -x + b + W_O softmax(QK^T/sqrt d)(x W_V) (real attention) [the test]
+
+Vector-field param-gradient (valid for non-gradient F):
+ dL/dtheta = < a, dF/dtheta(x*) >, a = (x_{-b} - x_{+b}) / (2 beta).
+"""
+import torch, torch.nn.functional as F, math
# Seed the global RNG so draws that do not use an explicit generator repeat
# from one invocation of the script to the next.
torch.manual_seed(0)

# Problem sizes shared by every function in this file.
B = 8         # batch size (leading axis of the state x)
N = 8         # number of tokens attended over
D = 32        # model width
H = 4         # number of attention heads
dh = D // H   # width of a single head

# Run on the GPU when one is visible, otherwise on the CPU.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
def mk_params(regime):
    """Create the trainable parameter dict for one regime.

    Args:
        regime: ``'noncons'`` for the attention force (``WQ``, ``WK``, ``WV``,
            ``WO``, ``b``) or ``'cons'`` for the tanh force (``S``, ``b``).

    Returns:
        A dict of tensors on ``dev``, each with ``requires_grad`` enabled.
        Weight matrices are ``D x D`` Gaussian draws scaled by ``1/sqrt(D)``;
        the bias ``b`` starts at zero.

    Raises:
        ValueError: if ``regime`` is not one of the two known names.  Without
            this check an unknown name silently produced the ``'cons'``
            parameters, which then fail with a ``KeyError`` as soon as they are
            paired with the attention force.
    """
    if regime not in ('cons', 'noncons'):
        raise ValueError(
            f"unknown regime {regime!r}; expected 'cons' or 'noncons'")

    # Private generator: parameter values do not depend on the global RNG state.
    g = torch.Generator(device=dev).manual_seed(1)
    s = 1.0 / math.sqrt(D)

    # The tuple order fixes the order of the random draws, so each matrix gets
    # the same values on every call.
    names = ('WQ', 'WK', 'WV', 'WO') if regime == 'noncons' else ('S',)
    P = {name: torch.randn(D, D, generator=g, device=dev) * s for name in names}
    P['b'] = torch.zeros(D, device=dev)

    for v in P.values():
        v.requires_grad_(True)
    return P
+
+
def _heads(t):
    """Split the last axis of ``t`` into ``H`` heads of width ``dh``.

    The two leading sizes are read from ``t`` itself instead of the module
    constants ``B`` and ``N``, so any ``(batch, tokens, H * dh)`` tensor is
    accepted.

    Returns:
        A view of shape ``(batch, H, tokens, dh)``.
    """
    batch, tokens = t.shape[0], t.shape[1]
    return t.view(batch, tokens, H, dh).transpose(1, 2)
+
+
def F_noncons(x, P):
    """Real (non-conservative) multi-head attention force.

    Computes ``-x + b + merge_heads(softmax(q k^T / sqrt(dh)) v) @ WO`` where
    ``q``, ``k`` and ``v`` are ``x`` projected by ``WQ``, ``WK`` and ``WV`` and
    split into heads by ``_heads``.

    Args:
        x: state tensor; its own shape is used to merge the heads back, so the
            output always matches ``x``.
        P: dict holding ``'WQ'``, ``'WK'``, ``'WV'``, ``'WO'`` and ``'b'``.

    Returns:
        The force, with the same shape as ``x``.
    """
    q = _heads(x @ P['WQ'])
    k = _heads(x @ P['WK'])
    v = _heads(x @ P['WV'])

    scores = (q @ k.transpose(-2, -1)) / math.sqrt(dh)
    A = torch.softmax(scores, dim=-1)

    # Undo the head split: swap the head and token axes back, then flatten the
    # heads into the last axis.  Using x.shape avoids relying on globals B, N, D.
    o = (A @ v).transpose(1, 2).reshape(x.shape) @ P['WO']
    return -x + P['b'] + o
+
+
def F_cons(x, P):
    """Conservative control force ``-x + b + tanh(x S) S^T``.

    This is ``-grad E`` for ``E = .5|x|^2 - <b, x> - sum logcosh(x S)``.

    Args:
        x: state tensor whose last axis matches the rows of ``P['S']``.
        P: dict holding the matrix ``'S'`` and the bias ``'b'``.

    Returns:
        The force, with the same shape as ``x``.
    """
    S = P['S']
    hidden = torch.tanh(x @ S)
    return -x + P['b'] + hidden @ S.t()
+
+
def cost(x, R, tgt):
    """Mean (over the batch) half squared error of a linear read-out of ``x``.

    Each sample of ``x`` is flattened to a row, projected by ``R`` and compared
    with ``tgt``.  The batch size is taken from ``x`` itself rather than from a
    module-level constant, so the function works for any leading dimension.

    Args:
        x: state tensor; axis 0 is the batch axis.
        R: read-out matrix with one row per flattened feature of a sample.
        tgt: target with the same shape as the read-out ``x_flat @ R``.

    Returns:
        A scalar tensor: ``0.5 * sum((x_flat @ R - tgt)**2) / batch``.
    """
    batch = x.shape[0]
    err = x.reshape(batch, -1) @ R - tgt
    return 0.5 * (err ** 2).sum() / batch
+
+
def dcost(x, R, tgt):
    """Gradient of ``cost`` with respect to the state ``x``.

    ``x`` is detached and re-marked as a leaf, and grad mode is switched on
    locally, so this works even when the caller is inside ``torch.no_grad()``.

    Returns:
        A tensor with the same shape as ``x``.
    """
    leaf = x.detach().requires_grad_(True)
    with torch.enable_grad():
        loss = cost(leaf, R, tgt)
        (grad,) = torch.autograd.grad(loss, leaf)
    return grad
+
+
def relax(Ffn, P, x0, steps, eps, extra=None):
    """Integrate ``x <- x + eps * F(x)`` for a fixed number of Euler steps.

    Args:
        Ffn: force, called as ``Ffn(x, P)``.
        P: parameters handed through to ``Ffn`` unchanged.
        x0: initial state; it is never modified.
        steps: number of Euler steps (``0`` returns a copy of ``x0``).
        eps: step size.
        extra: optional callable ``extra(x)`` whose result is added to the
            force at every step.

    Returns:
        The final state as a tensor with no autograd history.
    """
    # Drop x0's history up front and keep the whole integration under no_grad.
    # Otherwise an x0 that requires grad makes every update record a graph node
    # that the final detach would only throw away.
    x = x0.detach().clone()
    with torch.no_grad():
        for _ in range(steps):
            f = Ffn(x, P)
            if extra is not None:
                f = f + extra(x)
            x = x + eps * f
    return x
+
+
def AJ_apply(Ffn, P, xstar, v):
    """Apply the antisymmetric Jacobian part: ``0.5 * (J v - J^T v)`` at ``xstar``.

    ``J`` is the Jacobian of ``z -> Ffn(z, P)``.  The two products come from a
    Jacobian-vector product and a vector-Jacobian product, so ``J`` itself is
    never materialised.

    Args:
        Ffn: force, called as ``Ffn(z, P)``.
        P: parameters handed through to ``Ffn``.
        xstar: point at which the Jacobian is taken.
        v: direction; used both as tangent and as cotangent, so it must have
            the shape of ``xstar`` and of the force output.

    Returns:
        A tensor with the same shape as ``v``.
    """
    def field(z):
        return Ffn(z, P)

    # Grad mode is forced on because callers invoke this from no_grad blocks.
    with torch.enable_grad():
        _, forward_prod = torch.autograd.functional.jvp(field, xstar, v)
        _, adjoint_prod = torch.autograd.functional.vjp(field, xstar, v)
    return 0.5 * (forward_prod - adjoint_prod)
+
+
def ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep):
    """Estimate the parameter gradient with (asymmetric) equilibrium propagation.

    Steps:
      1. Relax from ``x0`` for ``T1`` steps to the free state ``xstar``.
      2. Starting from ``xstar``, run two nudged phases of ``T2`` steps with
         nudge ``+beta`` and ``-beta``.  When ``aep`` is true each phase also
         receives the correction ``-2 A_J(xstar)(x - xstar)``.
      3. Form ``a = (x_minus - x_plus) / (2 beta)`` and return the gradient of
         ``<a, Ffn(xstar, P)>`` with respect to every entry of ``P``.

    Args:
        Ffn: force, called as ``Ffn(x, P)``.
        P: dict of parameter tensors to differentiate with respect to.
        x0: initial state for the free phase.
        R, tgt: read-out matrix and target forwarded to the cost gradient.
        T1, T2: number of free-phase and nudged-phase steps.
        eps: Euler step size.
        beta: nudging strength.
        aep: if true, add the antisymmetric-Jacobian correction.

    Returns:
        A tuple of gradients in the order of ``P.values()``; an entry is
        ``None`` for a parameter the force does not depend on.
    """
    xstar = relax(Ffn, P, x0, T1, eps)

    if aep:
        def correction(x):
            return -2.0 * AJ_apply(Ffn, P, xstar, x - xstar)
    else:
        correction = None

    xp = _nud(Ffn, P, xstar, R, tgt, T2, eps, +1.0, beta, correction)
    xm = _nud(Ffn, P, xstar, R, tgt, T2, eps, -1.0, beta, correction)
    a = ((xm - xp) / (2.0 * beta)).detach()

    # xstar comes back from relax without autograd history, so the only
    # differentiable path in s runs through the parameters in P.
    with torch.enable_grad():
        s = (a * Ffn(xstar, P)).sum()
        grads = torch.autograd.grad(s, list(P.values()), allow_unused=True)
    return grads
+
+
def _nud(Ffn, P, xstar, R, tgt, T2, eps, sign, beta, ex):
    """Run one nudged phase starting from the free state ``xstar``.

    The dynamics are ``x <- x + eps * (Ffn(x, P) - sign*beta*dcost(x) + ex(x))``
    for ``T2`` steps.  The integration loop is shared with ``relax`` rather than
    duplicated here; the nudge is folded into the force so the terms are added
    in the same order as before.

    Args:
        Ffn: force, called as ``Ffn(x, P)``.
        P: parameters handed through to ``Ffn``.
        xstar: starting state; it is not modified.
        R, tgt: read-out matrix and target forwarded to ``dcost``.
        T2: number of Euler steps.
        eps: Euler step size.
        sign: ``+1.0`` or ``-1.0``, the direction of the nudge.
        beta: nudging strength.
        ex: optional callable ``ex(x)`` added to the force, or ``None``.

    Returns:
        The nudged state as a tensor with no autograd history.
    """
    def nudged_force(x, params):
        return Ffn(x, params) - sign * beta * dcost(x, R, tgt)

    return relax(nudged_force, P, xstar, T2, eps, extra=ex)
+
+
def bptt_grad(Ffn, P, x0, R, tgt, T1, eps):
    """Reference gradient: backpropagate the cost through the unrolled relaxation.

    The ``T1`` Euler steps are executed with autograd recording, and the cost
    of the final state is differentiated with respect to every entry of ``P``.

    Returns:
        A tuple of gradients in the order of ``P.values()``; an entry is
        ``None`` for a parameter that did not influence the cost.
    """
    state = x0.clone()
    for _ in range(T1):
        # No no_grad here: every step must stay in the graph.
        state = state + eps * Ffn(state, P)
    loss = cost(state, R, tgt)
    params = list(P.values())
    return torch.autograd.grad(loss, params, allow_unused=True)
+
+
def cosine(ga, gb):
    """Cosine similarity between two gradient collections, flattened end to end.

    Args:
        ga, gb: iterables of tensors.  An entry may be ``None``, as returned by
            ``torch.autograd.grad(..., allow_unused=True)``.

    Returns:
        The cosine similarity as a Python float.

    Raises:
        ValueError: if a ``None`` entry is present and the two collections have
            different lengths, so entries cannot be paired up.
    """
    ga, gb = list(ga), list(gb)

    # A None entry means "no dependence", i.e. a zero gradient.  Its shape is
    # borrowed from the partner entry; pairs that are None on both sides add
    # nothing to either vector and are dropped.
    if any(g is None for g in ga) or any(g is None for g in gb):
        if len(ga) != len(gb):
            raise ValueError(
                "cannot pair None gradients: collections differ in length "
                f"({len(ga)} vs {len(gb)})")
        pairs = [(a, b) for a, b in zip(ga, gb)
                 if not (a is None and b is None)]
        ga = [torch.zeros_like(b) if a is None else a for a, b in pairs]
        gb = [torch.zeros_like(a) if b is None else b for a, b in pairs]

    fa = torch.cat([g.flatten() for g in ga])
    fb = torch.cat([g.flatten() for g in gb])
    return F.cosine_similarity(fa, fb, dim=0).item()
+
+
def run(regime, T1=120, T2=30, eps=0.2, beta=0.02):
    """Run the gradient comparison for one regime and print a short report.

    The report contains the relative one-step residual at the relaxed state,
    the Jacobian antisymmetry ratio ``||A_J v|| / ||J v||`` along a random
    direction, and the cosine similarity of the naive-EP and AEP gradients
    against the BPTT gradient.

    Args:
        regime: ``'cons'`` selects ``F_cons``; any other value selects
            ``F_noncons``.
        T1: free-phase (and BPTT unroll) steps.
        T2: nudged-phase steps.
        eps: Euler step size.
        beta: nudging strength.
    """
    P = mk_params(regime)
    Ffn = F_cons if regime == 'cons' else F_noncons

    # Dedicated generator; x0, R and tgt are drawn in this fixed order.
    data_rng = torch.Generator(device=dev).manual_seed(7)
    x0 = torch.randn(B, N, D, generator=data_rng, device=dev) * 0.1
    R = torch.randn(N * D, 16, generator=data_rng, device=dev) / math.sqrt(N * D)
    tgt = torch.randn(B, 16, generator=data_rng, device=dev)

    # How far the relaxed state still moves in one more step, relative to its norm.
    xs = relax(Ffn, P, x0, T1, eps)
    one_more = relax(Ffn, P, xs, 1, eps)
    res = ((one_more - xs).norm() / (xs.norm() + 1e-8)).item()

    # Antisymmetry probe along a random direction (drawn from the global RNG).
    v = torch.randn_like(xs)
    aj = AJ_apply(Ffn, P, xs, v)
    jv = torch.autograd.functional.jvp(lambda z: Ffn(z, P), xs, v)[1]
    asym = (aj.norm() / (jv.norm() + 1e-8)).item()

    gb = bptt_grad(Ffn, P, x0, R, tgt, T1, eps)
    gn = ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep=False)
    ga = ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep=True)

    verdict = '~conservative' if asym < 0.05 else 'NON-conservative'
    naive_cos = cosine(gn, gb)
    aep_cos = cosine(ga, gb)

    print(f"\n===== regime={regime} (residual@x*={res:.1e}) =====")
    print(f" Jacobian antisymmetry ||A_J v||/||J v|| = {asym:.3f} "
          f"({verdict})")
    print(f" cosine(naive_EP, BPTT) = {naive_cos:+.4f}")
    print(f" cosine( AEP , BPTT) = {aep_cos:+.4f}")
+
+
if __name__ == '__main__':
    # Control regime first, then the real-attention test.
    for regime in ('cons', 'noncons'):
        run(regime)