1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
|
"""
CET + AEP: does Asymmetric EP let us train a *non-conservative* attention?
CET's energy attention is conservative by construction (grad of a scalar LogSumExp
energy -> symmetric Jacobian -> vanilla EP exact). Real transformer attention
softmax(QK^T)V with an INDEPENDENT value V is NOT the gradient of any scalar ->
non-conservative Jacobian -> vanilla EP gives a BIASED gradient.
AEP (Scellier et al., "EP for Non-Conservative Systems", arXiv:2602.03670) adds a
nudged-phase correction -2 A_J(x*)(x - x*), A_J = (J - J^T)/2 at the free
equilibrium x*. Linearised, this turns the nudged Jacobian J into J^T -- exactly
the adjoint that vanilla EP fails to realise when J != J^T.
We compare two EP parameter-gradient estimators (naive EP and AEP) against ground-truth BPTT in two regimes:
cons : F = -x + b + tanh(xS) S^T (manifestly grad of a scalar -> J symmetric) [control]
noncons : F = -x + b + W_O softmax(QK^T/sqrt d)(x W_V) (real attention) [the test]
Vector-field param-gradient (valid for non-gradient F):
dL/dtheta = < a, dF/dtheta(x*) >, a = (x_{-b} - x_{+b}) / (2 beta).
"""
import torch, torch.nn.functional as F, math
# Global RNG seed; the explicit torch.Generator objects below do not use it,
# but torch.randn_like in run() does.
torch.manual_seed(0)
B = 8          # batch size (cost() averages over this leading dimension)
N = 8          # tokens per example (attention mixes over this axis)
D = 32         # feature width
H = 4          # number of attention heads
dh = D // H    # per-head width
dev = 'cpu'
if torch.cuda.is_available():
    dev = 'cuda'
def mk_params(regime):
    """Create the trainable parameter dict for one regime.

    'noncons' -> attention projections WQ, WK, WV, WO (each D x D) plus bias b.
    Any other value -> a single coupling matrix S (D x D) plus bias b.
    Matrices are standard normal scaled by 1/sqrt(D). They are drawn from a
    generator seeded with 1, so repeated calls give identical values. The bias
    starts at zero. Every tensor has requires_grad=True.
    """
    gen = torch.Generator(device=dev).manual_seed(1)
    scale = 1.0 / math.sqrt(D)

    def _mat():
        # Draws consume `gen` in call order, so the key order below matters.
        return torch.randn(D, D, generator=gen, device=dev) * scale

    if regime == 'noncons':
        params = {name: _mat() for name in ('WQ', 'WK', 'WV', 'WO')}
    else:
        params = {'S': _mat()}
    params['b'] = torch.zeros(D, device=dev)
    for tensor in params.values():
        tensor.requires_grad_(True)
    return params
def _heads(t):
    """Reshape (B, N, D) -> (B, H, N, dh) by splitting features into H heads."""
    split = t.view(B, N, H, dh)
    # permute(0, 2, 1, 3) yields the same view as transpose(1, 2).
    return split.permute(0, 2, 1, 3)
def F_noncons(x, P):
    """Real (non-conservative) multi-head attention force.

    F(x) = -x + b + concat_h(softmax(q_h k_h^T / sqrt(dh)) v_h) @ WO,
    where q, k, v = x @ WQ, x @ WK, x @ WV, each split into H heads.
    """
    q = _heads(x @ P['WQ'])
    k = _heads(x @ P['WK'])
    v = _heads(x @ P['WV'])
    logits = q @ k.transpose(-2, -1) / math.sqrt(dh)
    attn = logits.softmax(dim=-1)
    # Merge heads back to (B, N, D) before the output projection.
    mixed = (attn @ v).transpose(1, 2).reshape(B, N, D)
    return -x + P['b'] + mixed @ P['WO']
def F_cons(x, P):
    """Conservative force F = -grad E.

    The energy is E = .5|x|^2 - <b, x> - sum logcosh(x S).
    Its gradient is x - b - tanh(x S) S^T, so F = -x + b + tanh(x S) S^T.
    """
    S = P['S']
    activated = (x @ S).tanh()
    return -x + P['b'] + activated @ S.T
def cost(x, R, tgt):
    """Half squared error of a linear readout, averaged over the batch.

    Each example x[i] is flattened and projected by R, then compared with
    tgt[i]:  C = 0.5 * sum_i ||flat(x[i]) @ R - tgt[i]||^2 / batch.

    The batch size is taken from x.shape[0] rather than a global constant.
    This gives the same result for the script's (B, N, D) states and also
    works for any batch size.
    """
    batch = x.shape[0]
    resid = x.reshape(batch, -1) @ R - tgt
    return 0.5 * (resid ** 2).sum() / batch
def dcost(x, R, tgt):
    """Return dC/dx at x, where C = cost(x, R, tgt).

    x is detached into a fresh leaf first, so any graph x already belongs to
    is neither used nor extended. Grad mode is forced on so this also works
    when called from inside torch.no_grad(), as relax() and _nud() do.
    """
    leaf = x.detach().requires_grad_(True)
    with torch.enable_grad():
        loss = cost(leaf, R, tgt)
        (grad,) = torch.autograd.grad(loss, leaf)
    return grad
def relax(Ffn, P, x0, steps, eps, extra=None):
    """Integrate dx/dt = Ffn(x, P) [+ extra(x)] with `steps` explicit Euler steps.

    No autograd graph is recorded. x0 is not modified (the loop works on a
    clone), and the returned tensor is detached.
    """
    state = x0.clone()
    with torch.no_grad():
        for _ in range(steps):
            force = Ffn(state, P) if extra is None else Ffn(state, P) + extra(state)
            state = state + eps * force
    return state.detach()
def AJ_apply(Ffn, P, xstar, v): # 0.5 (J v - J^T v) at xstar
with torch.enable_grad():
fx = lambda z: Ffn(z, P)
Jv = torch.autograd.functional.jvp(fx, xstar, v)[1]
JTv = torch.autograd.functional.vjp(fx, xstar, v)[1]
return 0.5 * (Jv - JTv)
def ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep):
    """Estimate dL/dtheta with equilibrium propagation (optionally AEP).

    Free phase: relax from x0 for T1 Euler steps of size eps to reach x*.
    Nudged phases: starting from x*, take T2 steps under F - sign*beta*dC,
    once with sign = +1 and once with sign = -1. When `aep` is true, the
    correction -2 A_J(x*)(x - x*) is added to the nudged force.
    The adjoint estimate a = (x_- - x_+) / (2 beta) is then contracted with
    dF/dtheta at x*: grad = d/dtheta <a, F(x*, theta)>.

    Returns:
        Tuple of gradients aligned with list(P.values()). An entry is None
        for any parameter that F does not use (allow_unused=True).
    """
    xstar = relax(Ffn, P, x0, T1, eps)
    # The AEP correction does not depend on the nudge sign, so build it once.
    # Linearised, it turns the nudged Jacobian J into J^T.
    ex = (lambda x: -2.0 * AJ_apply(Ffn, P, xstar, x - xstar)) if aep else None

    def nudged(sign):
        return _nud(Ffn, P, xstar, R, tgt, T2, eps, sign, beta, ex)

    xp, xm = nudged(+1.0), nudged(-1.0)
    a = ((xm - xp) / (2.0 * beta)).detach()
    xs = xstar.detach()
    with torch.enable_grad():
        s = (a * Ffn(xs, P)).sum()
        grads = torch.autograd.grad(s, list(P.values()), allow_unused=True)
    return grads
def _nud(Ffn, P, xstar, R, tgt, T2, eps, sign, beta, ex):
    """Nudged phase: T2 Euler steps from xstar under F - sign*beta*dC [+ ex(x)].

    ex, if not None, is an extra force evaluated at the current state
    (ep_grad passes the AEP correction here). No autograd graph is recorded.
    Returns a detached tensor.
    """
    state = xstar.clone()
    pull = sign * beta
    with torch.no_grad():
        for _ in range(T2):
            force = Ffn(state, P) - pull * dcost(state, R, tgt)
            if ex is not None:
                force = force + ex(state)
            state = state + eps * force
    return state.detach()
def bptt_grad(Ffn, P, x0, R, tgt, T1, eps):
    """Reference gradient: backpropagate the cost through T1 unrolled Euler steps.

    Unlike relax(), every step stays on the autograd graph. The returned
    gradients are aligned with list(P.values()), with None for any parameter
    that is never used.
    """
    state = x0.clone()
    for _ in range(T1):
        state = state + eps * Ffn(state, P)  # graph kept on purpose
    loss = cost(state, R, tgt)
    params = [*P.values()]
    return torch.autograd.grad(loss, params, allow_unused=True)
def cosine(ga, gb):
    """Cosine similarity between two gradient tuples, flattened and concatenated.

    ga and gb are equal-length sequences of tensors aligned parameter by
    parameter, as returned by torch.autograd.grad. A None entry is what
    autograd.grad returns for an unused input under allow_unused=True, and it
    means a zero gradient. Such an entry is replaced by zeros shaped like its
    partner. Pairs where both entries are None contribute nothing and are
    skipped.

    Raises:
        ValueError: if ga and gb have different lengths.
    """
    if len(ga) != len(gb):
        raise ValueError(f"gradient tuples differ in length: {len(ga)} vs {len(gb)}")
    fa, fb = [], []
    for a, b in zip(ga, gb):
        if a is None and b is None:
            continue
        fa.append((torch.zeros_like(b) if a is None else a).flatten())
        fb.append((torch.zeros_like(a) if b is None else b).flatten())
    return F.cosine_similarity(torch.cat(fa), torch.cat(fb), dim=0).item()
def run(regime, T1=120, T2=30, eps=0.2, beta=0.02):
    """Run one regime, print diagnostics, and compare EP gradients with BPTT.

    Args:
        regime: 'cons' (tanh control force) or 'noncons' (softmax attention).
        T1: free-phase Euler steps (also the BPTT unroll length).
        T2: nudged-phase Euler steps.
        eps: Euler step size.
        beta: nudging strength.

    Returns:
        dict with float entries 'residual', 'asym', 'cos_naive', 'cos_aep'.
        The printed report is the same as before; the return value is new.

    Raises:
        ValueError: if regime is not 'cons' or 'noncons'.
    """
    # mk_params maps any value other than 'noncons' to the 'cons' parameters,
    # while the force choice below maps any value other than 'cons' to
    # F_noncons. Reject other values up front instead of failing later with a
    # KeyError on the parameter dict.
    if regime not in ('cons', 'noncons'):
        raise ValueError(f"unknown regime {regime!r}; expected 'cons' or 'noncons'")
    P = mk_params(regime)
    Ffn = F_cons if regime == 'cons' else F_noncons
    g = torch.Generator(device=dev).manual_seed(7)
    x0 = torch.randn(B, N, D, generator=g, device=dev) * 0.1
    R = torch.randn(N * D, 16, generator=g, device=dev) / math.sqrt(N * D)
    tgt = torch.randn(B, 16, generator=g, device=dev)
    xs = relax(Ffn, P, x0, T1, eps)
    # Relative size of one more Euler step at x* (eps*|F(x*)|/|x*|); ~0 at a fixed point.
    res = ((relax(Ffn, P, xs, 1, eps) - xs).norm() / (xs.norm() + 1e-8)).item()
    # Probe how far J is from symmetric along a single random direction.
    v = torch.randn_like(xs)
    aj = AJ_apply(Ffn, P, xs, v)
    jv = torch.autograd.functional.jvp(lambda z: Ffn(z, P), xs, v)[1]
    asym = (aj.norm() / (jv.norm() + 1e-8)).item()
    gb = bptt_grad(Ffn, P, x0, R, tgt, T1, eps)
    gn = ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep=False)
    ga = ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep=True)
    cos_naive = cosine(gn, gb)
    cos_aep = cosine(ga, gb)
    print(f"\n===== regime={regime} (residual@x*={res:.1e}) =====")
    print(f" Jacobian antisymmetry ||A_J v||/||J v|| = {asym:.3f} "
          f"({'~conservative' if asym < 0.05 else 'NON-conservative'})")
    print(f" cosine(naive_EP, BPTT) = {cos_naive:+.4f}")
    print(f" cosine( AEP , BPTT) = {cos_aep:+.4f}")
    return {'residual': res, 'asym': asym, 'cos_naive': cos_naive, 'cos_aep': cos_aep}
if __name__ == '__main__':
    # Conservative control first, then the real-attention test case.
    for regime in ('cons', 'noncons'):
        run(regime)
|