"""Option 2 / MVP-A: AEP for the LM's CAUSAL attention (equilibrium reformulation), on real text.

The LM's feedforward causal attention is reformulated as a damped equilibrium block:

    state      z of shape (B, T, C); the input embedding x_in is clamped (boundary condition)
    force      F(z) = -(z - x_in) + s * (causal_attn(z) - c * z)
               with s = S and c = DAMP below. c > 0 is meant to make the map contractive,
               so that a stable fixed point exists. The script checks the free-phase
               residual empirically.
    attention  causal_attn(z) = softmax(Q K^T / sqrt(dh), causal mask) V, projected by W_O
               Non-conservative: Q/K/V/O are independent, so the force Jacobian is not
               symmetric in general.

Relax to z*, read out logits = z* @ Whead, cost = next-token cross-entropy.
Estimate the Q/K/V/O gradient two ways:
  - AEP: free phase, then centred +/-beta nudged phases with a clipped
    antisymmetric-Jacobian correction;
  - naive EP: the same phases without the correction.
Compare each estimate to the ground-truth BPTT gradient (cosine similarity) on the
attention params, on Shakespeare.
"""
import math
import pickle
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

dev = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)

DD = Path('/tmp/lt_ep/data/shakespeare_char')
# NOTE: pickle.load can execute arbitrary code -- only point DD at a trusted dataset directory.
with open(DD / 'meta.pkl', 'rb') as fh:
    vocab = pickle.load(fh)['vocab_size']

B, T, C, H = 16, 64, 128, 4      # batch size, context length, model width, attention heads
dh = C // H                      # per-head width
ATTN = ('WQ', 'WK', 'WV', 'WO')  # attention-param names, in the same order as P


def get_batch():
    """Sample B random length-T windows from train.bin.

    Returns (x, y): int64 tensors of shape (B, T) on `dev`, where y is x shifted by one
    position (the next-token targets).
    """
    data = np.memmap(DD / 'train.bin', dtype=np.uint16, mode='r')  # re-opened on every call
    ix = torch.randint(len(data) - T - 1, (B,))
    x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix])
    return x.to(dev), y.to(dev)


# Attention params whose gradients are compared. The creation order below fixes the
# RNG draws under seed 0, so keep it stable.
P = dict(WQ=torch.randn(C, C, device=dev) / math.sqrt(C),
         WK=torch.randn(C, C, device=dev) / math.sqrt(C),
         WV=torch.randn(C, C, device=dev) / math.sqrt(C),
         WO=torch.randn(C, C, device=dev) / math.sqrt(C))
for _w in P.values():
    _w.requires_grad_(True)
Whead = (torch.randn(C, vocab, device=dev) / math.sqrt(C)).requires_grad_(True)
tok_emb = torch.randn(vocab, C, device=dev) * 0.02  # fixed embedding (we test Q/K/V/O grads)
pos_emb = torch.randn(T, C, device=dev) * 0.02      # fixed position embedding
CMASK = torch.tril(torch.ones(T, T, dtype=torch.bool, device=dev))  # True = may attend
DAMP, S = 1.0, 1.0  # c (damping) and s (attention scale) in the force F(z)


def embed(idx):
    """Token + position embedding for an index tensor of shape (batch, seq), seq <= T."""
    return tok_emb[idx] + pos_emb[None, :idx.shape[1], :]


def causal_attn(z):
    """Multi-head causal self-attention of z, shape (batch, seq, C) with seq <= T.

    Uses the current P['WQ'], P['WK'], P['WV'], P['WO']; returns a tensor shaped like z.
    """
    nb, nt, _ = z.shape
    q = (z @ P['WQ']).view(nb, nt, H, dh).transpose(1, 2)
    k = (z @ P['WK']).view(nb, nt, H, dh).transpose(1, 2)
    v = (z @ P['WV']).view(nb, nt, H, dh).transpose(1, 2)
    a = (q @ k.transpose(-2, -1)) / math.sqrt(dh)
    a = a.masked_fill(~CMASK[:nt, :nt], float('-inf'))
    a = torch.softmax(a, dim=-1)
    return (a @ v).transpose(1, 2).reshape(nb, nt, C) @ P['WO']


def force(z, x_in):
    """Equilibrium force F(z) = -(z - x_in) + S * (causal_attn(z) - DAMP * z)."""
    return -(z - x_in) + S * (causal_attn(z) - DAMP * z)


def relax(z, x_in, steps, eps):
    """Run `steps` explicit-Euler steps z <- z + eps * F(z) without autograd.

    Returns the final state, detached.
    """
    for _ in range(steps):
        with torch.no_grad():
            z = z + eps * force(z, x_in)
    return z.detach()


def ce(z, y):
    """Mean next-token cross-entropy of the logits z @ Whead against targets y."""
    return F.cross_entropy((z @ Whead).reshape(-1, vocab), y.reshape(-1))


def aep_grad(x_in, y, T1, T2, eps, beta, aep, z0=None):
    """Equilibrium-propagation estimate of d(ce)/dP for the attention params.

    Args:
        x_in: clamped input embedding, also the free-phase initial state by default.
        y: next-token targets.
        T1: Euler steps for the free phase.
        T2: Euler steps for each nudged phase.
        eps: Euler step size.
        beta: nudging strength.
        aep: True applies the clipped antisymmetric-Jacobian correction (AEP);
            False runs plain (naive) EP.
        z0: optional free-phase initial state; defaults to x_in.

    Returns:
        Tuple of gradient tensors, in the order of P.values().
    """
    init = x_in if z0 is None else z0
    zs = relax(init.clone(), x_in, T1, eps)  # free-phase equilibrium z*

    def nudged(sign):
        """Relax from z* for T2 steps under F(z) - sign * beta * d(ce)/dz."""
        z = zs.clone()
        for _ in range(T2):
            with torch.enable_grad():
                zz = z.detach().requires_grad_(True)
                g, = torch.autograd.grad(ce(zz, y), zz)
            with torch.no_grad():
                f = force(z, x_in) - sign * beta * g
                if aep:
                    # Linearising around z*, the attention term contributes S * J v, where
                    # v = z - z* and J is the attention Jacobian at z*. Subtracting
                    # S * (J v - J^T v) turns that into S * J^T v. To first order in v, the
                    # nudged phase therefore follows the transposed (adjoint) dynamics.
                    v = (z - zs).detach()
                    Jv = torch.autograd.functional.jvp(causal_attn, zs, v)[1]
                    JTv = torch.autograd.functional.vjp(causal_attn, zs, v)[1]
                    corr = S * (Jv - JTv)
                    # Clip so the correction's norm never exceeds the force's norm.
                    cn, fn = corr.norm(), f.norm() + 1e-8
                    corr = corr * (fn / cn) if cn > fn else corr
                    f = f - corr
                z = z + eps * f
        return z.detach()

    zp, zm = nudged(+1.0), nudged(-1.0)
    # Centred difference of the two nudged equilibria, held fixed (detached) below.
    a = ((zm - zp) / (2 * beta)).detach()
    with torch.enable_grad():
        # Gradient estimate: d/dP of sum(a * F(z*)) with a and z* held fixed.
        s = (a * force(zs.detach(), x_in)).sum()
    return torch.autograd.grad(s, list(P.values()), allow_unused=True)


def bptt_grad(x_in, y, T1, eps, z0=None):
    """Reference gradient d(ce)/dP by backprop through T1 unrolled Euler steps.

    z0 is an optional initial state and defaults to x_in. No gradient is needed w.r.t.
    the initial state; autograd tracks the params through force().
    Returns a tuple of gradients, in the order of P.values().
    """
    z = (x_in if z0 is None else z0).clone()
    for _ in range(T1):
        z = z + eps * force(z, x_in)
    return torch.autograd.grad(ce(z, y), list(P.values()), allow_unused=True)


def cos(g, gb):
    """Cosine similarity of each gradient in g with its counterpart in gb.

    Returns (mean over params, list of per-param cosines in the same order).
    """
    cs = [F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() for a, b in zip(g, gb)]
    return sum(cs) / len(cs), cs


idx, y = get_batch()
x_in = embed(idx).detach()
embed_in = x_in.clone()  # init state = embedding

# free-eq residual (is there a stable fixed point?)
# Convergence check. Relax for 200 Euler steps, then take one extra step from z*.
# The relative size of that extra step, eps * ||F(z*)|| / ||z*||, is reported as the
# free-phase residual.
n_free_check, eps_step = 200, 0.1
zs = relax(embed_in.clone(), x_in, n_free_check, eps_step)
one_step = relax(zs, x_in, 1, eps_step) - zs
r = one_step.norm().item() / (zs.norm().item() + 1e-9)
verdict = 'stable fixed point' if r < 1e-2 else 'NOT converged'
print(f"LM causal damped-equilibrium attention on Shakespeare (B={B} T={T} C={C} H={H}, damp={DAMP})")
print(f" free-phase residual = {r:.2e} ({verdict})")

# Gradient comparison. BPTT through the 120 unrolled Euler steps is the reference;
# naive EP and AEP share the same free/nudged step counts, step size and beta.
n_free, n_nudge, beta_nudge = 120, 20, 0.02
gb = bptt_grad(x_in, y, n_free, eps_step)
gn = aep_grad(x_in, y, n_free, n_nudge, eps_step, beta_nudge, aep=False)
ga = aep_grad(x_in, y, n_free, n_nudge, eps_step, beta_nudge, aep=True)
mn, csn = cos(gn, gb)
ma, csa = cos(ga, gb)

print("\n attention-param gradient cosine vs BPTT:")
for label, mean_cos, per_param in (("naive-EP", mn, csn), ("AEP", ma, csa)):
    detail = " ".join(f"{name}={val:+.2f}" for name, val in zip(ATTN, per_param))
    print(f" {label} : mean {mean_cos:+.3f} per-param {detail}")