diff options
Diffstat (limited to 'scripts/aep_attention.py')
| -rw-r--r-- | scripts/aep_attention.py | 157 |
1 files changed, 157 insertions, 0 deletions
# File: scripts/aep_attention.py (new file, mode 100644, blob 868cb05)
"""
CET + AEP: does Asymmetric EP let us train a *non-conservative* attention?

CET's energy attention is conservative by construction (grad of a scalar
LogSumExp energy -> symmetric Jacobian -> vanilla EP exact). Real transformer
attention softmax(QK^T)V with an INDEPENDENT value V is NOT the gradient of any
scalar -> non-conservative Jacobian -> vanilla EP gives a BIASED gradient.

AEP (Scellier et al., "EP for Non-Conservative Systems", arXiv:2602.03670) adds
a nudged-phase correction -2 A_J(x*)(x - x*), A_J = (J - J^T)/2 at the free
equilibrium x*. Linearised, this turns the nudged Jacobian J into J^T -- exactly
the adjoint that vanilla EP fails to realise when J != J^T.

We compare three parameter-gradient estimators vs ground-truth BPTT in two
regimes:
  cons    : F = -x + b + tanh(xS) S^T
            (manifestly grad of a scalar -> J symmetric)            [control]
  noncons : F = -x + b + W_O softmax(QK^T/sqrt d)(x W_V)
            (real attention)                                        [the test]

Vector-field param-gradient (valid for non-gradient F):
  dL/dtheta = < a, dF/dtheta(x*) >,   a = (x_{-b} - x_{+b}) / (2 beta).
"""
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N, D, H = 8, 8, 32, 4  # batch, sequence length, model width, attention heads
dh = D // H               # per-head width
dev = 'cuda' if torch.cuda.is_available() else 'cpu'


def mk_params(regime):
    """Build the trainable parameter dict for ``regime``.

    ``'noncons'`` yields ``WQ, WK, WV, WO`` (each ``D x D``, scaled by
    ``1/sqrt(D)``) plus a zero bias ``b``; any other value yields the single
    matrix ``S`` plus ``b``. Every tensor has ``requires_grad`` enabled.
    A fixed generator seed (1) makes the draw reproducible.
    """
    g = torch.Generator(device=dev).manual_seed(1)
    s = 1.0 / math.sqrt(D)

    def draw():
        return torch.randn(D, D, generator=g, device=dev) * s

    # Draw order (WQ, WK, WV, WO) matters for reproducibility of the RNG stream.
    names = ('WQ', 'WK', 'WV', 'WO') if regime == 'noncons' else ('S',)
    P = {name: draw() for name in names}
    P['b'] = torch.zeros(D, device=dev)
    for v in P.values():
        v.requires_grad_(True)
    return P


def _heads(t):
    """Split the last axis of a (batch, seq, D) tensor into H heads.

    Returns a (batch, H, seq, dh) tensor. Batch and sequence sizes are read
    from ``t`` itself rather than the module-level ``B``/``N``.
    """
    return t.view(t.shape[0], t.shape[1], H, dh).transpose(1, 2)


def F_noncons(x, P):
    """Real (non-conservative) attention force.

    F = -x + b + softmax(q k^T / sqrt(dh)) v  W_O, with q, k, v independent
    linear projections of ``x``.
    """
    q, k, v = _heads(x @ P['WQ']), _heads(x @ P['WK']), _heads(x @ P['WV'])
    A = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(dh), dim=-1)
    o = (A @ v).transpose(1, 2).reshape(x.shape[0], x.shape[1], D) @ P['WO']
    return -x + P['b'] + o


def F_cons(x, P):
    """Conservative control force.

    F = -grad E, E = .5|x|^2 - <b,x> - sum logcosh(xS).
    """
    return -x + P['b'] + torch.tanh(x @ P['S']) @ P['S'].t()


def cost(x, R, tgt):
    """Mean (over the batch) of half the squared readout error.

    ``x`` is flattened per batch element, projected through ``R`` and compared
    against ``tgt``.
    """
    batch = x.shape[0]
    return 0.5 * ((x.reshape(batch, -1) @ R - tgt) ** 2).sum() / batch


def dcost(x, R, tgt):
    """Gradient of ``cost`` with respect to ``x``.

    Works on a detached copy of ``x`` and re-enables autograd locally, so it
    can be called from inside a ``torch.no_grad()`` region.
    """
    x = x.detach().requires_grad_(True)
    with torch.enable_grad():
        g, = torch.autograd.grad(cost(x, R, tgt), x)
    return g


def relax(Ffn, P, x0, steps, eps, extra=None):
    """Integrate ``x <- x + eps * (Ffn(x, P) [+ extra(x)])`` for ``steps`` steps.

    The force is evaluated without building an autograd graph. Returns the
    final state, detached.
    """
    x = x0.clone()
    for _ in range(steps):
        with torch.no_grad():
            f = Ffn(x, P)
            if extra is not None:
                f = f + extra(x)
        x = x + eps * f
    return x.detach()


def AJ_apply(Ffn, P, xstar, v):
    """Apply the antisymmetric Jacobian part: 0.5 (J v - J^T v) at ``xstar``."""
    with torch.enable_grad():
        fx = lambda z: Ffn(z, P)
        Jv = torch.autograd.functional.jvp(fx, xstar, v)[1]
        JTv = torch.autograd.functional.vjp(fx, xstar, v)[1]
    return 0.5 * (Jv - JTv)


def ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep):
    """Equilibrium-propagation estimate of the parameter gradient.

    Relaxes to the free equilibrium x* (``T1`` steps), runs two nudged phases
    (+beta and -beta, ``T2`` steps each, both started at x*), and contracts
    a = (x_{-beta} - x_{+beta}) / (2 beta) with dF/dtheta at x*.

    With ``aep=True`` each nudged phase also receives the correction
    -2 A_J(x*)(x - x*).

    Returns a tuple of gradients aligned with ``P.values()``; entries may be
    ``None`` for parameters the force does not use (``allow_unused=True``).
    """
    xstar = relax(Ffn, P, x0, T1, eps)

    if aep:
        ex = lambda x: -2.0 * AJ_apply(Ffn, P, xstar, x - xstar)
    else:
        ex = None

    xp = _nud(Ffn, P, xstar, R, tgt, T2, eps, +1.0, beta, ex)
    xm = _nud(Ffn, P, xstar, R, tgt, T2, eps, -1.0, beta, ex)

    a = ((xm - xp) / (2.0 * beta)).detach()
    xs = xstar.detach()
    with torch.enable_grad():
        s = (a * Ffn(xs, P)).sum()
        grads = torch.autograd.grad(s, list(P.values()), allow_unused=True)
    return grads


def _nud(Ffn, P, xstar, R, tgt, T2, eps, sign, beta, ex):
    """Run one nudged phase from ``xstar``.

    The dynamics are ``Ffn(x, P) - sign * beta * dcost(x)`` plus the optional
    extra force ``ex(x)``, integrated for ``T2`` steps via ``relax``.
    """
    def nudged_force(x, params):
        return Ffn(x, params) - sign * beta * dcost(x, R, tgt)

    return relax(nudged_force, P, xstar, T2, eps, extra=ex)


def bptt_grad(Ffn, P, x0, R, tgt, T1, eps):
    """Ground-truth gradient: backprop through ``T1`` unrolled relaxation steps."""
    x = x0.clone()
    for _ in range(T1):
        x = x + eps * Ffn(x, P)  # full graph
    return torch.autograd.grad(cost(x, R, tgt), list(P.values()), allow_unused=True)


def cosine(ga, gb):
    """Cosine similarity between two gradient lists, flattened and concatenated.

    ``None`` entries (as produced by ``autograd.grad(..., allow_unused=True)``)
    are treated as zero gradients; positions where both sides are ``None`` are
    skipped. Returns 0.0 if nothing is left to compare.

    Raises:
        ValueError: if the two lists have different lengths.
    """
    ga, gb = list(ga), list(gb)
    if len(ga) != len(gb):
        raise ValueError(
            f"gradient lists differ in length: {len(ga)} vs {len(gb)}")

    flat_a, flat_b = [], []
    for a, b in zip(ga, gb):
        if a is None and b is None:
            continue
        if a is None:
            a = torch.zeros_like(b)
        elif b is None:
            b = torch.zeros_like(a)
        flat_a.append(a.flatten())
        flat_b.append(b.flatten())

    if not flat_a:
        return 0.0
    return F.cosine_similarity(torch.cat(flat_a), torch.cat(flat_b), dim=0).item()


def run(regime, T1=120, T2=30, eps=0.2, beta=0.02):
    """Run one regime, print the comparison, and return the metrics.

    Returns a dict with keys ``residual`` (relative one-step change at x*),
    ``asym`` (||A_J v|| / ||J v|| for a random v), ``cos_naive`` and
    ``cos_aep`` (cosine of naive-EP / AEP gradients against BPTT).
    """
    P = mk_params(regime)
    Ffn = F_cons if regime == 'cons' else F_noncons
    g = torch.Generator(device=dev).manual_seed(7)
    x0 = torch.randn(B, N, D, generator=g, device=dev) * 0.1
    R = torch.randn(N * D, 16, generator=g, device=dev) / math.sqrt(N * D)
    tgt = torch.randn(B, 16, generator=g, device=dev)

    # Diagnostics at the free equilibrium: convergence residual and how
    # asymmetric the Jacobian is along a random direction.
    xs = relax(Ffn, P, x0, T1, eps)
    res = ((relax(Ffn, P, xs, 1, eps) - xs).norm() / (xs.norm() + 1e-8)).item()
    v = torch.randn_like(xs)
    aj = AJ_apply(Ffn, P, xs, v)
    jv = torch.autograd.functional.jvp(lambda z: Ffn(z, P), xs, v)[1]
    asym = (aj.norm() / (jv.norm() + 1e-8)).item()

    gb = bptt_grad(Ffn, P, x0, R, tgt, T1, eps)
    gn = ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep=False)
    ga = ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep=True)
    cos_naive = cosine(gn, gb)
    cos_aep = cosine(ga, gb)

    print(f"\n===== regime={regime} (residual@x*={res:.1e}) =====")
    print(f" Jacobian antisymmetry ||A_J v||/||J v|| = {asym:.3f} "
          f"({'~conservative' if asym < 0.05 else 'NON-conservative'})")
    print(f" cosine(naive_EP, BPTT) = {cos_naive:+.4f}")
    print(f" cosine( AEP , BPTT) = {cos_aep:+.4f}")

    return {'residual': res, 'asym': asym,
            'cos_naive': cos_naive, 'cos_aep': cos_aep}


if __name__ == '__main__':
    run('cons')
    run('noncons')
