1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
|
"""option 2 / MVP-A: AEP for the LM's CAUSAL attention (equilibrium reformulation), on real text.
The LM's feedforward causal attention is reformulated as a damped equilibrium block:
state z (B,T,C); input embedding x_in is clamped (boundary);
force F(z) = -(z - x_in) + S*(causal_attn(z) - DAMP*z) [DAMP>0: contraction -> stable fixed point]
causal_attn(z) = concat_heads(softmax(Q K^T / sqrt(d_h), causal) V) W_O, with Q = z W_Q, K = z W_K, V = z W_V [non-conservative: independent Q/K/V/O]
Relax to z*, read out logits = z* W_head, cost = next-token cross-entropy.
Train Q/K/V/O with AEP (free + +/-beta nudged, centered, correction clipped) vs naive-EP, and
compare each to the ground-truth BPTT gradient (cosine) on the attention params -- on Shakespeare.
"""
import math, pickle, numpy as np, torch, torch.nn.functional as F
from pathlib import Path
# Run on the GPU when one is visible to torch, otherwise on the CPU.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
# Fixed seed so parameter initialisation and batch sampling repeat across runs.
torch.manual_seed(0)
# Directory holding the tokenised dataset (meta.pkl here, train.bin in get_batch).
DD = Path('/tmp/lt_ep/data/shakespeare_char')
# NOTE(review): pickle.load can execute arbitrary code from the file; this
# assumes meta.pkl is a trusted, locally generated artefact -- confirm.
# The context manager closes the handle (the bare open() call leaked it).
with open(DD / 'meta.pkl', 'rb') as meta_file:
    vocab = pickle.load(meta_file)['vocab_size']
# B = batch size, T = sequence length, C = channel width, H = attention heads.
B, T, C, H = 16, 64, 128, 4
# Per-head width used when the channels are split across the H heads.
dh = C // H
# Names of the attention projections, in the order they are reported.
ATTN = ('WQ', 'WK', 'WV', 'WO')
def get_batch():
    """Sample one random batch of token windows from ``train.bin``.

    Returns:
        A pair ``(x, y)`` of int64 tensors on ``dev``. Each stacks ``B``
        windows of ``T`` tokens; every row of ``y`` is the matching row of
        ``x`` shifted one token later in the stream (next-token targets).
    """
    tokens = np.memmap(DD / 'train.bin', dtype=np.uint16, mode='r')
    # Upper bound leaves room for the T-token window plus the one-token shift.
    starts = torch.randint(len(tokens) - T - 1, (B,))

    def window(lo):
        # Copy T tokens beginning at `lo` out of the memmap as int64.
        return torch.from_numpy(tokens[lo:lo + T].astype(np.int64))

    inputs = torch.stack([window(start) for start in starts])
    targets = torch.stack([window(start + 1) for start in starts])
    return inputs.to(dev), targets.to(dev)
# Trainable attention projections: one C x C Gaussian matrix per name in ATTN,
# scaled by 1/sqrt(C) and marked as requiring gradients.
P = {
    name: (torch.randn(C, C, device=dev) / math.sqrt(C)).requires_grad_(True)
    for name in ATTN
}
# Read-out matrix from the state to vocabulary logits.
Whead = (torch.randn(C, vocab, device=dev) / math.sqrt(C)).requires_grad_(True)
# Token and position embeddings stay fixed (no requires_grad): only the
# Q/K/V/O gradients are under test.
tok_emb = 0.02 * torch.randn(vocab, C, device=dev)
pos_emb = 0.02 * torch.randn(T, C, device=dev)
# Lower-triangular boolean mask: True where a query may attend to a key.
CMASK = torch.ones(T, T, dtype=torch.bool, device=dev).tril()
# Damping coefficient and attention scale used by force().
DAMP = 1.0
S = 1.0
def embed(idx):
    """Return the input embedding for token ids ``idx``.

    The result is the token-embedding lookup plus the position embedding,
    with the position table broadcast over the leading (batch) axis.
    """
    positions = pos_emb.unsqueeze(0)
    return tok_emb[idx] + positions
def causal_attn(z):
    """Apply multi-head causal self-attention to the state ``z``.

    Args:
        z: tensor of shape (batch, tokens, C). The batch and token counts are
            read from ``z`` itself rather than from the module constants, so
            any batch size and any sequence length up to the size of
            ``CMASK`` is accepted.

    Returns:
        Tensor of the same shape as ``z``: the per-head attention outputs,
        concatenated and multiplied by ``P['WO']``.
    """
    n_batch, n_tok, _ = z.shape

    def split_heads(name):
        # Project, then reshape to (batch, heads, tokens, dh).
        return (z @ P[name]).view(n_batch, n_tok, H, dh).transpose(1, 2)

    q, k, v = split_heads('WQ'), split_heads('WK'), split_heads('WV')
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(dh)
    # Positions outside the causal mask get -inf, i.e. zero weight after softmax.
    scores = scores.masked_fill(~CMASK[:n_tok, :n_tok], float('-inf'))
    weights = torch.softmax(scores, dim=-1)
    # Merge the heads back into one channel axis before the output projection.
    merged = (weights @ v).transpose(1, 2).reshape(n_batch, n_tok, C)
    return merged @ P['WO']
def force(z, x_in):
    """Return the force driving the state ``z`` for clamped input ``x_in``.

    Two terms are summed: a pull of ``z`` back toward ``x_in``, and ``S``
    times the attention output minus a ``DAMP``-scaled copy of ``z``.
    """
    pull_to_input = -(z - x_in)
    attn_drive = causal_attn(z) - DAMP * z
    return pull_to_input + S * attn_drive
def relax(z, x_in, steps, eps):
    """Integrate the state forward with ``steps`` explicit Euler updates.

    Each update is ``z <- z + eps * force(z, x_in)``. No autograd graph is
    recorded, and the returned state is detached.
    """
    with torch.no_grad():
        for _ in range(steps):
            z = z + eps * force(z, x_in)
    return z.detach()
def ce(z, y):
    """Return the cross-entropy of the read-out logits ``z @ Whead`` against ``y``.

    Logits and targets are flattened over every leading axis before the loss
    is computed.
    """
    logits = z @ Whead
    return F.cross_entropy(logits.reshape(-1, vocab), y.reshape(-1))
def aep_grad(x_in, y, T1, T2, eps, beta, aep):
    """Estimate attention-parameter gradients from +/-beta nudged relaxations.

    Args:
        x_in: clamped input embedding; also used as the initial state.
        y: target token ids passed to ``ce``.
        T1: number of free-phase relaxation steps.
        T2: number of steps in each nudged phase.
        eps: Euler step size for both phases.
        beta: nudge strength; the loss gradient is scaled by ``+/-beta``.
        aep: when true, subtract the antisymmetric Jacobian correction
            ``S * (J v - J^T v)`` (clipped to the norm of the force) from
            each nudged update; when false this is the naive-EP estimate.

    Returns:
        Tuple of gradients, one per entry of ``P`` in ``P.values()`` order.
    """
    # The initial state comes from the argument, not from module-level
    # state, so the estimate always matches the input it was called with.
    zs = relax(x_in.detach().clone(), x_in, T1, eps)

    def nudged(sign):
        # Relax from the free equilibrium under the loss-nudged dynamics.
        z = zs.clone()
        for _ in range(T2):
            with torch.enable_grad():
                zz = z.detach().requires_grad_(True)
                g, = torch.autograd.grad(ce(zz, y), zz)
            with torch.no_grad():
                f = force(z, x_in) - sign * beta * g
                if aep:
                    # Displacement from the free equilibrium.
                    v = (z - zs).detach()
                    Jv = torch.autograd.functional.jvp(causal_attn, zs, v)[1]
                    JTv = torch.autograd.functional.vjp(causal_attn, zs, v)[1]
                    corr = S * (Jv - JTv)
                    # Clip: the correction may not exceed the force in norm.
                    cn, fn = corr.norm(), f.norm() + 1e-8
                    corr = corr * (fn / cn) if cn > fn else corr
                    f = f - corr
                z = z + eps * f
        return z.detach()

    zp, zm = nudged(+1.0), nudged(-1.0)
    # Centered finite difference of the two nudged states.
    a = ((zm - zp) / (2 * beta)).detach()
    with torch.enable_grad():
        # Contract with the force at the free equilibrium; differentiating
        # this scalar with respect to the parameters gives the estimate.
        s = (a * force(zs.detach(), x_in)).sum()
    return torch.autograd.grad(s, list(P.values()), allow_unused=True)
def bptt_grad(x_in, y, T1, eps):
    """Return reference gradients by backpropagating through the relaxation.

    Args:
        x_in: clamped input embedding; also used as the initial state.
        y: target token ids passed to ``ce``.
        T1: number of unrolled Euler steps.
        eps: Euler step size.

    Returns:
        Tuple of gradients of the final-state loss, one per entry of ``P``
        in ``P.values()`` order.
    """
    # The initial state comes from the argument, not from module-level
    # state, so the reference always matches the input it was called with.
    z = x_in.detach().clone().requires_grad_(True)
    for _ in range(T1):
        # Autograd records every step so the loss can be differentiated
        # through the whole trajectory.
        z = z + eps * force(z, x_in)
    return torch.autograd.grad(ce(z, y), list(P.values()), allow_unused=True)
def cos(g, gb):
    """Compare two gradient lists tensor-by-tensor with cosine similarity.

    Args:
        g: sequence of estimated gradient tensors.
        gb: sequence of reference gradient tensors, paired with ``g`` by position.

    Returns:
        ``(mean, per_param)``: the list of per-tensor cosine similarities as
        Python floats, and their arithmetic mean.
    """
    per_param = []
    for estimate, reference in zip(g, gb):
        similarity = F.cosine_similarity(estimate.flatten(), reference.flatten(), dim=0)
        per_param.append(similarity.item())
    return sum(per_param) / len(per_param), per_param
# Draw a single batch and clamp its embedding as the boundary input.
idx, y = get_batch()
x_in = embed(idx).detach()
embed_in = x_in.clone()  # initial relaxation state = the embedding

# Free-phase check: after 200 steps, how far does one more step move the
# state, relative to the state's norm?
zs = relax(embed_in.clone(), x_in, 200, 0.1)
one_more_step = relax(zs, x_in, 1, 0.1)
r = (one_more_step - zs).norm().item() / (zs.norm().item() + 1e-9)
verdict = 'stable fixed point' if r < 1e-2 else 'NOT converged'
print(f"LM causal damped-equilibrium attention on Shakespeare (B={B} T={T} C={C} H={H}, damp={DAMP})")
print(f" free-phase residual = {r:.2e} ({verdict})")

# Reference BPTT gradient, then the naive-EP and AEP estimates on the same
# relaxation schedule.
gb = bptt_grad(x_in, y, 120, 0.1)
gn = aep_grad(x_in, y, 120, 20, 0.1, 0.02, aep=False)
ga = aep_grad(x_in, y, 120, 20, 0.1, 0.02, aep=True)
mn, csn = cos(gn, gb)
ma, csa = cos(ga, gb)

# One report line per estimator: mean cosine, then the per-parameter values.
print(f"\n attention-param gradient cosine vs BPTT:")
for tag, mean_cos, per_cos in ((" naive-EP : ", mn, csn), (" AEP : ", ma, csa)):
    breakdown = " ".join(f"{n}={c:+.2f}" for n, c in zip(ATTN, per_cos))
    print(f"{tag}mean {mean_cos:+.3f} per-param " + breakdown)
|