diff options
Diffstat (limited to 'ep_run/lt_ep_ffn.py')
| -rw-r--r-- | ep_run/lt_ep_ffn.py | 119 |
1 files changed, 119 insertions, 0 deletions
# Reconstructed from: diff --git a/ep_run/lt_ep_ffn.py b/ep_run/lt_ep_ffn.py (new file, index 0000000..8c060c7)
"""option 2 / H1: replace the FFN-FA (the abandonment reason) with EP's Hopfield-memory E^mem.

Compare attention-FFN gradient quality vs true BP:
  FA-FFN : the project's LocalMLP(method='fa') -> expect FA's signature failure on the upstream layer
  EP-FFN : Hopfield memory E_mem = 0.5a||h-x||^2 - sum relu(hW)^2 (CONSERVATIVE -> plain EP, no AEP)
           trained by centered energy-EP from free/nudged per-token equilibria.

The script runs top to bottom on import: it draws one batch, computes the
FA-vs-BP cosine similarities for the two LocalMLP weight matrices, computes
the EP-vs-BPTT cosine similarity for the memory matrix W, and prints both.
"""
import math
import pickle
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

from model_local import LocalMLP, LocalGPTConfig

dev = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)
DD = Path('/tmp/lt_ep/data/shakespeare_char')
# Close the handle deterministically instead of leaking it until GC.
# NOTE: pickle executes arbitrary code on load; only point DD at trusted data.
with open(DD / 'meta.pkl', 'rb') as _meta_file:
    vocab = pickle.load(_meta_file)['vocab_size']
# B sequences per batch, T tokens per sequence, C embedding width,
# Mm columns of the memory matrix W.
B, T, C, Mm = 16, 64, 128, 256


def get_batch():
    """Sample B random windows of length T from train.bin.

    Returns:
        (x, y): int64 tensors on ``dev``, each of shape (B, T); ``y`` is ``x``
        shifted one position to the right in the token stream.
    """
    data = np.memmap(DD / 'train.bin', dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - T - 1, (B,))
    x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix])
    return x.to(dev), y.to(dev)


idx, y = get_batch()
# Fixed random embeddings and readout head; none of them is trained here.
tok = torch.randn(vocab, C, device=dev) * 0.02
pos = torch.randn(T, C, device=dev) * 0.02
XIN = (tok[idx] + pos[None]).detach()  # (B, T, C) block input
Whead = torch.randn(C, vocab, device=dev) / math.sqrt(C)


# ---------- FA-FFN (project's own LocalMLP) ----------
def make_mlp(method):
    """Build a LocalMLP on ``dev`` for the given ``method`` string ('bp' or 'fa' below)."""
    cfg = LocalGPTConfig(block_size=T, vocab_size=vocab, n_embd=C, method=method,
                         dropout=0.0, bias=False)
    return LocalMLP(cfg).to(dev)


bp = make_mlp('bp')
fa = make_mlp('fa')
# Give both models identical forward weights so only the gradient rule differs.
fa.fc.weight.data.copy_(bp.fc.weight.data)
fa.proj.weight.data.copy_(bp.proj.weight.data)


def mlp_grads(m):
    """Return ``[fc.weight.grad, proj.weight.grad]`` of the LM loss for MLP ``m``.

    The loss is cross-entropy of ``(XIN + m(XIN)) @ Whead`` against ``y``.
    Existing parameter gradients are cleared first so nothing accumulates.
    """
    for p in m.parameters():
        p.grad = None
    out = m(XIN)
    F.cross_entropy(((XIN + out) @ Whead).reshape(-1, vocab), y.reshape(-1)).backward()
    return [m.fc.weight.grad, m.proj.weight.grad]


gb = mlp_grads(bp)
gfa = mlp_grads(fa)
# Per-matrix cosine similarity between the 'fa' and 'bp' gradients.
cfa = [F.cosine_similarity(a.flatten(), b.flatten(), 0).item() for a, b in zip(gfa, gb)]


# ---------- EP-FFN (Hopfield memory E_mem, conservative) ----------
ALPHA = 2.0
W = (torch.randn(C, Mm, device=dev) * 0.3 / math.sqrt(C)).requires_grad_(True)


def E_mem(h, x):
    """Scalar energy 0.5*ALPHA*||h - x||^2 - sum(relu(h @ W)^2), summed over all entries."""
    return 0.5 * ALPHA * ((h - x) ** 2).sum() - (F.relu(h @ W) ** 2).sum()


def force(h, x, cg=False):
    """Return -dE_mem/dh evaluated at ``h``.

    Args:
        h: current state; detached and re-marked as requiring grad unless it
            already requires grad.
        x: clamped input passed through to ``E_mem``.
        cg: when True the graph is kept (``create_graph``) so the result can be
            differentiated again, as ``bptt_grad`` needs.
    """
    # enable_grad lets this be called from inside torch.no_grad() sections.
    with torch.enable_grad():
        hr = h if h.requires_grad else h.detach().requires_grad_(True)
        g, = torch.autograd.grad(E_mem(hr, x), hr, create_graph=cg)
    return -g


def relax(h, x, steps, eps):
    """Run ``steps`` explicit Euler steps ``h <- h + eps * force(h, x)``; return detached h."""
    for _ in range(steps):
        with torch.enable_grad():
            f = force(h, x).detach()
        with torch.no_grad():
            h = h + eps * f
    return h.detach()


def ce(h):
    """Cross-entropy of the readout ``h @ Whead`` against the batch targets ``y``."""
    return F.cross_entropy((h @ Whead).reshape(-1, vocab), y.reshape(-1))


def ep_grad(x, T1, T2, eps, beta):
    """Centered two-sided EP estimate of the loss gradient w.r.t. ``W``.

    Args:
        x: clamped input.
        T1: free-phase relaxation steps, starting from ``h = x``.
        T2: nudged-phase steps, each starting from the free-phase state.
        eps: Euler step size for both phases.
        beta: nudging strength.

    Returns:
        (grad, hs): ``(dE/dW at h(+beta) - dE/dW at h(-beta)) / (2 * beta)``,
        detached, and the free-phase state ``hs``.
    """
    hs = relax(x.clone(), x, T1, eps)

    def nudge(sign):
        """Relax from ``hs`` under ``force - sign * beta * d ce / d h`` for T2 steps."""
        h = hs.clone()
        for _ in range(T2):
            with torch.enable_grad():
                hh = h.detach().requires_grad_(True)
                g, = torch.autograd.grad(ce(hh), hh)
            with torch.no_grad():
                h = h + eps * (force(h, x).detach() - sign * beta * g)
        return h.detach()

    hp, hm = nudge(+1), nudge(-1)
    with torch.enable_grad():
        # The nudged states are detached, so only the explicit W-dependence
        # of the energy is differentiated here.
        gp, = torch.autograd.grad(E_mem(hp, x), W)
        gm, = torch.autograd.grad(E_mem(hm, x), W)
    return ((gp - gm) / (2 * beta)).detach(), hs


def bptt_grad(x, T1, eps):
    """Reference gradient: backprop ``ce`` through ``T1`` unrolled relaxation steps to ``W``."""
    h = x.clone().requires_grad_(True)
    for _ in range(T1):
        h = h + eps * force(h, x, cg=True)
    return torch.autograd.grad(ce(h), W)[0]


# Relative size of one further relaxation step after 200 steps.
# NOTE(review): this residual comes from a 200-step relaxation, while ep_grad
# and bptt_grad below use 120 free-phase steps - confirm that is intended.
hs = relax(XIN.clone(), XIN, 200, 0.1)
r = (relax(hs, XIN, 1, 0.1) - hs).norm().item() / (hs.norm().item() + 1e-9)
gep, _ = ep_grad(XIN, 120, 20, 0.1, 0.02)
gbp_ep = bptt_grad(XIN, 120, 0.1)
cep = F.cosine_similarity(gep.flatten(), gbp_ep.flatten(), 0).item()

print("H1: FFN gradient quality vs true BP, on Shakespeare LM block")
print(f" FA-FFN (LocalMLP, method=fa) : fc={cfa[0]:+.3f} proj={cfa[1]:+.3f} mean={sum(cfa)/2:+.3f}")
print(f" EP-FFN (Hopfield E_mem) : W_mem cosine = {cep:+.3f} (free-phase residual {r:.1e})")
print(f"\n -> FA fails on the upstream FFN layer (fc); EP-memory gives a faithful local gradient.")
