summaryrefslogtreecommitdiff
path: root/ep_run/lt_ep_ffn.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/lt_ep_ffn.py')
-rw-r--r--ep_run/lt_ep_ffn.py119
1 files changed, 119 insertions, 0 deletions
diff --git a/ep_run/lt_ep_ffn.py b/ep_run/lt_ep_ffn.py
new file mode 100644
index 0000000..8c060c7
--- /dev/null
+++ b/ep_run/lt_ep_ffn.py
@@ -0,0 +1,119 @@
+"""option 2 / H1: replace the FFN-FA (the abandonment reason) with EP's Hopfield-memory E^mem.
+Compare attention-FFN gradient quality vs true BP:
+ FA-FFN : the project's LocalMLP(method='fa') -> expect FA's signature failure on the upstream layer
+ EP-FFN : Hopfield memory E_mem = 0.5a||h-x||^2 - sum relu(hW)^2 (CONSERVATIVE -> plain EP, no AEP)
+ trained by centered energy-EP from free/nudged per-token equilibria.
+"""
+import math, pickle, numpy as np, torch, torch.nn.functional as F
+from pathlib import Path
+from model_local import LocalMLP, LocalGPTConfig
+
+dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+torch.manual_seed(0)
+DD = Path('/tmp/lt_ep/data/shakespeare_char')
+vocab = pickle.load(open(DD / 'meta.pkl', 'rb'))['vocab_size']
+B, T, C, Mm = 16, 64, 128, 256
+
+
+def get_batch():
+ data = np.memmap(DD / 'train.bin', dtype=np.uint16, mode='r')
+ ix = torch.randint(len(data) - T - 1, (B,))
+ x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix])
+ y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix])
+ return x.to(dev), y.to(dev)
+
+
+idx, y = get_batch()
+tok = torch.randn(vocab, C, device=dev) * 0.02
+pos = torch.randn(T, C, device=dev) * 0.02
+XIN = (tok[idx] + pos[None]).detach()
+Whead = torch.randn(C, vocab, device=dev) / math.sqrt(C)
+
+
+# ---------- FA-FFN (project's own LocalMLP) ----------
+def make_mlp(method):
+ cfg = LocalGPTConfig(block_size=T, vocab_size=vocab, n_embd=C, method=method, dropout=0.0, bias=False)
+ return LocalMLP(cfg).to(dev)
+
+
+bp = make_mlp('bp')
+fa = make_mlp('fa')
+fa.fc.weight.data.copy_(bp.fc.weight.data); fa.proj.weight.data.copy_(bp.proj.weight.data)
+
+
+def mlp_grads(m):
+ for p in m.parameters():
+ p.grad = None
+ out = m(XIN)
+ F.cross_entropy(((XIN + out) @ Whead).reshape(-1, vocab), y.reshape(-1)).backward()
+ return [m.fc.weight.grad, m.proj.weight.grad]
+
+
+gb = mlp_grads(bp); gfa = mlp_grads(fa)
+cfa = [F.cosine_similarity(a.flatten(), b.flatten(), 0).item() for a, b in zip(gfa, gb)]
+
+
+# ---------- EP-FFN (Hopfield memory E_mem, conservative) ----------
+ALPHA = 2.0
+W = (torch.randn(C, Mm, device=dev) * 0.3 / math.sqrt(C)).requires_grad_(True)
+
+
+def E_mem(h, x):
+ return 0.5 * ALPHA * ((h - x) ** 2).sum() - (F.relu(h @ W) ** 2).sum()
+
+
+def force(h, x, cg=False):
+ with torch.enable_grad():
+ hr = h if h.requires_grad else h.detach().requires_grad_(True)
+ g, = torch.autograd.grad(E_mem(hr, x), hr, create_graph=cg)
+ return -g
+
+
+def relax(h, x, steps, eps):
+ for _ in range(steps):
+ with torch.enable_grad():
+ f = force(h, x).detach()
+ with torch.no_grad():
+ h = h + eps * f
+ return h.detach()
+
+
+def ce(h):
+ return F.cross_entropy((h @ Whead).reshape(-1, vocab), y.reshape(-1))
+
+
+def ep_grad(x, T1, T2, eps, beta):
+ hs = relax(x.clone(), x, T1, eps)
+ def nudge(sign):
+ h = hs.clone()
+ for _ in range(T2):
+ with torch.enable_grad():
+ hh = h.detach().requires_grad_(True)
+ g, = torch.autograd.grad(ce(hh), hh)
+ with torch.no_grad():
+ h = h + eps * (force(h, x).detach() - sign * beta * g)
+ return h.detach()
+ hp, hm = nudge(+1), nudge(-1)
+ with torch.enable_grad():
+ gp, = torch.autograd.grad(E_mem(hp, x), W)
+ gm, = torch.autograd.grad(E_mem(hm, x), W)
+ return ((gp - gm) / (2 * beta)).detach(), hs
+
+
+def bptt_grad(x, T1, eps):
+ h = x.clone().requires_grad_(True)
+ for _ in range(T1):
+ h = h + eps * force(h, x, cg=True)
+ return torch.autograd.grad(ce(h), W)[0]
+
+
+hs = relax(XIN.clone(), XIN, 200, 0.1)
+r = (relax(hs, XIN, 1, 0.1) - hs).norm().item() / (hs.norm().item() + 1e-9)
+gep, _ = ep_grad(XIN, 120, 20, 0.1, 0.02)
+gbp_ep = bptt_grad(XIN, 120, 0.1)
+cep = F.cosine_similarity(gep.flatten(), gbp_ep.flatten(), 0).item()
+
+print("H1: FFN gradient quality vs true BP, on Shakespeare LM block")
+print(f" FA-FFN (LocalMLP, method=fa) : fc={cfa[0]:+.3f} proj={cfa[1]:+.3f} mean={sum(cfa)/2:+.3f}")
+print(f" EP-FFN (Hopfield E_mem) : W_mem cosine = {cep:+.3f} (free-phase residual {r:.1e})")
+print(f"\n -> FA fails on the upstream FFN layer (fc); EP-memory gives a faithful local gradient.")