1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
|
"""option 2 / H1: replace the FFN-FA (the abandonment reason) with EP's Hopfield-memory E^mem.
Compare attention-FFN gradient quality vs true BP:
FA-FFN : the project's LocalMLP(method='fa') -> expect FA's signature failure on the upstream layer
EP-FFN : Hopfield memory E_mem = 0.5a||h-x||^2 - sum relu(hW)^2 (CONSERVATIVE -> plain EP, no AEP)
trained by centered energy-EP from free/nudged per-token equilibria.
"""
import math, pickle, numpy as np, torch, torch.nn.functional as F
from pathlib import Path
from model_local import LocalMLP, LocalGPTConfig
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)
DD = Path('/tmp/lt_ep/data/shakespeare_char')
vocab = pickle.load(open(DD / 'meta.pkl', 'rb'))['vocab_size']
B, T, C, Mm = 16, 64, 128, 256
def get_batch():
data = np.memmap(DD / 'train.bin', dtype=np.uint16, mode='r')
ix = torch.randint(len(data) - T - 1, (B,))
x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix])
y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix])
return x.to(dev), y.to(dev)
idx, y = get_batch()
tok = torch.randn(vocab, C, device=dev) * 0.02
pos = torch.randn(T, C, device=dev) * 0.02
XIN = (tok[idx] + pos[None]).detach()
Whead = torch.randn(C, vocab, device=dev) / math.sqrt(C)
# ---------- FA-FFN (project's own LocalMLP) ----------
def make_mlp(method):
cfg = LocalGPTConfig(block_size=T, vocab_size=vocab, n_embd=C, method=method, dropout=0.0, bias=False)
return LocalMLP(cfg).to(dev)
bp = make_mlp('bp')
fa = make_mlp('fa')
fa.fc.weight.data.copy_(bp.fc.weight.data); fa.proj.weight.data.copy_(bp.proj.weight.data)
def mlp_grads(m):
for p in m.parameters():
p.grad = None
out = m(XIN)
F.cross_entropy(((XIN + out) @ Whead).reshape(-1, vocab), y.reshape(-1)).backward()
return [m.fc.weight.grad, m.proj.weight.grad]
gb = mlp_grads(bp); gfa = mlp_grads(fa)
cfa = [F.cosine_similarity(a.flatten(), b.flatten(), 0).item() for a, b in zip(gfa, gb)]
# ---------- EP-FFN (Hopfield memory E_mem, conservative) ----------
ALPHA = 2.0
W = (torch.randn(C, Mm, device=dev) * 0.3 / math.sqrt(C)).requires_grad_(True)
def E_mem(h, x):
return 0.5 * ALPHA * ((h - x) ** 2).sum() - (F.relu(h @ W) ** 2).sum()
def force(h, x, cg=False):
with torch.enable_grad():
hr = h if h.requires_grad else h.detach().requires_grad_(True)
g, = torch.autograd.grad(E_mem(hr, x), hr, create_graph=cg)
return -g
def relax(h, x, steps, eps):
for _ in range(steps):
with torch.enable_grad():
f = force(h, x).detach()
with torch.no_grad():
h = h + eps * f
return h.detach()
def ce(h):
return F.cross_entropy((h @ Whead).reshape(-1, vocab), y.reshape(-1))
def ep_grad(x, T1, T2, eps, beta):
hs = relax(x.clone(), x, T1, eps)
def nudge(sign):
h = hs.clone()
for _ in range(T2):
with torch.enable_grad():
hh = h.detach().requires_grad_(True)
g, = torch.autograd.grad(ce(hh), hh)
with torch.no_grad():
h = h + eps * (force(h, x).detach() - sign * beta * g)
return h.detach()
hp, hm = nudge(+1), nudge(-1)
with torch.enable_grad():
gp, = torch.autograd.grad(E_mem(hp, x), W)
gm, = torch.autograd.grad(E_mem(hm, x), W)
return ((gp - gm) / (2 * beta)).detach(), hs
def bptt_grad(x, T1, eps):
h = x.clone().requires_grad_(True)
for _ in range(T1):
h = h + eps * force(h, x, cg=True)
return torch.autograd.grad(ce(h), W)[0]
hs = relax(XIN.clone(), XIN, 200, 0.1)
r = (relax(hs, XIN, 1, 0.1) - hs).norm().item() / (hs.norm().item() + 1e-9)
gep, _ = ep_grad(XIN, 120, 20, 0.1, 0.02)
gbp_ep = bptt_grad(XIN, 120, 0.1)
cep = F.cosine_similarity(gep.flatten(), gbp_ep.flatten(), 0).item()
print("H1: FFN gradient quality vs true BP, on Shakespeare LM block")
print(f" FA-FFN (LocalMLP, method=fa) : fc={cfa[0]:+.3f} proj={cfa[1]:+.3f} mean={sum(cfa)/2:+.3f}")
print(f" EP-FFN (Hopfield E_mem) : W_mem cosine = {cep:+.3f} (free-phase residual {r:.1e})")
print(f"\n -> FA fails on the upstream FFN layer (fc); EP-memory gives a faithful local gradient.")
|