"""Option 2 / H1: replace the feedback-alignment FFN (the reason the previous
approach was abandoned) with an EP-trained Hopfield-memory energy E_mem.

Compares the FFN gradient quality of two local learning rules against true
backprop on one Shakespeare character-LM block:

  FA-FFN : the project's LocalMLP(method='fa'); expected to show FA's signature
           failure on the upstream layer (fc).
  EP-FFN : Hopfield memory E_mem(h, x) = 0.5*alpha*||h - x||^2 - sum(relu(h W)^2).
           The energy is conservative, so plain EP (no AEP) applies. The gradient
           for W is the centered (+beta / -beta) energy-EP estimate computed from
           free and nudged per-token equilibria.
"""
import math
import pickle
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

from model_local import LocalMLP, LocalGPTConfig

dev = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)

DD = Path('/tmp/lt_ep/data/shakespeare_char')
# Close the metadata file deterministically instead of leaking the handle.
# NOTE(review): pickle.load executes arbitrary code; only point DD at trusted,
# self-generated data.
with open(DD / 'meta.pkl', 'rb') as fh:
    vocab = pickle.load(fh)['vocab_size']

# B: batch size, T: context length, C: embedding width,
# Mm: number of Hopfield memory units (columns of W).
B, T, C, Mm = 16, 64, 128, 256


def get_batch():
    """Sample B random length-T windows from train.bin.

    Returns (x, y): int64 tensors of shape (B, T) on `dev`, where y[:, t] is the
    token that follows x[:, t] (next-token targets).
    """
    # The memmap is re-created on every call rather than held open globally.
    data = np.memmap(DD / 'train.bin', dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - T - 1, (B,))
    x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix])
    return x.to(dev), y.to(dev)


# One fixed batch with a frozen random embedding and readout: every method below
# sees the same input XIN (shape (B, T, C)) and the same loss.
idx, y = get_batch()
tok = torch.randn(vocab, C, device=dev) * 0.02
pos = torch.randn(T, C, device=dev) * 0.02
XIN = (tok[idx] + pos[None]).detach()
Whead = torch.randn(C, vocab, device=dev) / math.sqrt(C)


# ---------- FA-FFN (project's own LocalMLP) ----------
def make_mlp(method):
    """Build a LocalMLP for embedding size C trained with `method` ('bp' or 'fa' here)."""
    cfg = LocalGPTConfig(block_size=T, vocab_size=vocab, n_embd=C, method=method,
                         dropout=0.0, bias=False)
    return LocalMLP(cfg).to(dev)


bp = make_mlp('bp')
fa = make_mlp('fa')
# Give FA the same forward weights as BP so only the backward rule differs.
with torch.no_grad():
    fa.fc.weight.copy_(bp.fc.weight)
    fa.proj.weight.copy_(bp.proj.weight)


def mlp_grads(m):
    """Return [fc.weight.grad, proj.weight.grad] of `m` for the CE loss of XIN + m(XIN)."""
    m.zero_grad(set_to_none=True)
    out = m(XIN)
    # Residual block followed by the frozen readout Whead.
    F.cross_entropy(((XIN + out) @ Whead).reshape(-1, vocab), y.reshape(-1)).backward()
    return [m.fc.weight.grad, m.proj.weight.grad]


gb = mlp_grads(bp)
gfa = mlp_grads(fa)
cfa = [F.cosine_similarity(a.flatten(), b.flatten(), 0).item() for a, b in zip(gfa, gb)]

# ---------- EP-FFN (Hopfield memory E_mem, conservative) ----------
ALPHA = 2.0  # weight of the quadratic term tethering h to the input x
W = (torch.randn(C, Mm, device=dev) * 0.3 / math.sqrt(C)).requires_grad_(True)
# Relaxation hyperparameters, shared by the EP estimate and its BPTT reference so
# both differentiate the same T_FREE-step free phase.
T_FREE, T_NUDGE, EPS, BETA = 120, 20, 0.1, 0.02


def E_mem(h, x):
    """Scalar Hopfield-memory energy, summed over all tokens (no cross-token terms)."""
    return 0.5 * ALPHA * ((h - x) ** 2).sum() - (F.relu(h @ W) ** 2).sum()


def force(h, x, cg=False):
    """Return -dE_mem/dh evaluated at h.

    If h already requires grad, the result stays connected to h's history;
    otherwise a detached copy of h is differentiated. With cg=True the gradient
    graph is kept (create_graph) so callers can backprop through the step.
    """
    with torch.enable_grad():
        hr = h if h.requires_grad else h.detach().requires_grad_(True)
        g, = torch.autograd.grad(E_mem(hr, x), hr, create_graph=cg)
    return -g


def relax(h, x, steps, eps):
    """Run `steps` Euler steps h <- h + eps * force(h, x) without an autograd graph."""
    for _ in range(steps):
        f = force(h, x).detach()  # force() enables grad internally
        with torch.no_grad():
            h = h + eps * f
    return h.detach()


def ce(h):
    """Next-token cross-entropy of hidden state h through the frozen readout Whead."""
    return F.cross_entropy((h @ Whead).reshape(-1, vocab), y.reshape(-1))


def ep_grad(x, T1, T2, eps, beta):
    """Centered equilibrium-propagation estimate of dCE/dW.

    Free phase: T1 relaxation steps from h = x. Two nudged phases of T2 steps
    start from the free state and descend E_mem + beta*CE and E_mem - beta*CE.
    Returns ((dE_mem/dW at h+  -  dE_mem/dW at h-) / (2*beta), free_state).
    """
    hs = relax(x.clone(), x, T1, eps)

    def nudge(sign):
        # Descend E_mem + sign*beta*CE starting from the free state.
        h = hs.clone()
        for _ in range(T2):
            with torch.enable_grad():
                hh = h.detach().requires_grad_(True)
                g, = torch.autograd.grad(ce(hh), hh)
            with torch.no_grad():
                h = h + eps * (force(h, x).detach() - sign * beta * g)
        return h.detach()

    hp, hm = nudge(+1), nudge(-1)
    with torch.enable_grad():
        gp, = torch.autograd.grad(E_mem(hp, x), W)
        gm, = torch.autograd.grad(E_mem(hm, x), W)
    return ((gp - gm) / (2 * beta)).detach(), hs


def bptt_grad(x, T1, eps):
    """Reference dCE/dW by backprop through T1 unrolled relaxation steps from h = x."""
    h = x.clone().requires_grad_(True)
    for _ in range(T1):
        h = h + eps * force(h, x, cg=True)
    return torch.autograd.grad(ce(h), W)[0]


gep, hs = ep_grad(XIN, T_FREE, T_NUDGE, EPS, BETA)
# Relative change from one further free step, measured on the free state EP actually
# used: a small value means the T_FREE-step free phase is close to a fixed point.
r = (relax(hs, XIN, 1, EPS) - hs).norm().item() / (hs.norm().item() + 1e-9)
gbp_ep = bptt_grad(XIN, T_FREE, EPS)
cep = F.cosine_similarity(gep.flatten(), gbp_ep.flatten(), 0).item()

print("H1: FFN gradient quality vs true BP, on Shakespeare LM block")
print(f"  FA-FFN (LocalMLP, method=fa) : fc={cfa[0]:+.3f}  proj={cfa[1]:+.3f}  mean={sum(cfa)/2:+.3f}")
print(f"  EP-FFN (Hopfield E_mem)      : W_mem cosine = {cep:+.3f}  (free-phase residual {r:.1e})")
# NOTE: this conclusion is printed unconditionally; check it against the numbers above.
print("\n  -> FA fails on the upstream FFN layer (fc); EP-memory gives a faithful local gradient.")