summaryrefslogtreecommitdiff
path: root/ep_run/lt_ep_attention.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/lt_ep_attention.py')
-rw-r--r--ep_run/lt_ep_attention.py129
1 files changed, 129 insertions, 0 deletions
diff --git a/ep_run/lt_ep_attention.py b/ep_run/lt_ep_attention.py
new file mode 100644
index 0000000..c411de6
--- /dev/null
+++ b/ep_run/lt_ep_attention.py
@@ -0,0 +1,129 @@
+"""option 2 / MVP-A: AEP for the LM's CAUSAL attention (equilibrium reformulation), on real text.
+
+The LM's feedforward causal attention is reformulated as a damped equilibrium block:
+ state z (B,T,C); input embedding x_in is clamped (boundary);
+ force F(z) = -(z - x_in) + s*(causal_attn(z) - c*z) [c>0: contraction -> stable fixed pt]
+ causal_attn = O softmax(QK^T/sqrt d , causal) V [non-conservative: independent Q/K/V/O]
+Relax to z*, read out logits = z* W_head, cost = next-token cross-entropy.
+Train Q/K/V/O with AEP (free + +/-beta nudged, centered, correction clipped) vs naive-EP, and
+compare each to the ground-truth BPTT gradient (cosine) on the attention params -- on Shakespeare.
+"""
+import math, pickle, numpy as np, torch, torch.nn.functional as F
+from pathlib import Path
+
+dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+torch.manual_seed(0)
+DD = Path('/tmp/lt_ep/data/shakespeare_char')
+vocab = pickle.load(open(DD / 'meta.pkl', 'rb'))['vocab_size']
+B, T, C, H = 16, 64, 128, 4
+dh = C // H
+ATTN = ('WQ', 'WK', 'WV', 'WO')
+
+
+def get_batch():
+ data = np.memmap(DD / 'train.bin', dtype=np.uint16, mode='r')
+ ix = torch.randint(len(data) - T - 1, (B,))
+ x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix])
+ y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix])
+ return x.to(dev), y.to(dev)
+
+
+P = dict(WQ=torch.randn(C, C, device=dev) / math.sqrt(C),
+ WK=torch.randn(C, C, device=dev) / math.sqrt(C),
+ WV=torch.randn(C, C, device=dev) / math.sqrt(C),
+ WO=torch.randn(C, C, device=dev) / math.sqrt(C))
+for v in P.values():
+ v.requires_grad_(True)
+Whead = (torch.randn(C, vocab, device=dev) / math.sqrt(C)).requires_grad_(True)
+tok_emb = torch.randn(vocab, C, device=dev) * 0.02 # fixed embedding (we test Q/K/V/O grads)
+pos_emb = torch.randn(T, C, device=dev) * 0.02
+CMASK = torch.tril(torch.ones(T, T, dtype=torch.bool, device=dev))
+DAMP, S = 1.0, 1.0
+
+
+def embed(idx):
+ return tok_emb[idx] + pos_emb[None, :, :]
+
+
+def causal_attn(z):
+ q = (z @ P['WQ']).view(B, T, H, dh).transpose(1, 2)
+ k = (z @ P['WK']).view(B, T, H, dh).transpose(1, 2)
+ v = (z @ P['WV']).view(B, T, H, dh).transpose(1, 2)
+ a = (q @ k.transpose(-2, -1)) / math.sqrt(dh)
+ a = a.masked_fill(~CMASK[:T, :T], float('-inf'))
+ a = torch.softmax(a, dim=-1)
+ return (a @ v).transpose(1, 2).reshape(B, T, C) @ P['WO']
+
+
+def force(z, x_in):
+ return -(z - x_in) + S * (causal_attn(z) - DAMP * z)
+
+
+def relax(z, x_in, steps, eps):
+ for _ in range(steps):
+ with torch.no_grad():
+ z = z + eps * force(z, x_in)
+ return z.detach()
+
+
+def ce(z, y):
+ return F.cross_entropy((z @ Whead).reshape(-1, vocab), y.reshape(-1))
+
+
+def aep_grad(x_in, y, T1, T2, eps, beta, aep):
+ zs = relax(embed_in.clone(), x_in, T1, eps)
+ def nudged(sign):
+ z = zs.clone()
+ for _ in range(T2):
+ with torch.enable_grad():
+ zz = z.detach().requires_grad_(True)
+ g, = torch.autograd.grad(ce(zz, y), zz)
+ with torch.no_grad():
+ f = force(z, x_in) - sign * beta * g
+ if aep:
+ v = (z - zs).detach()
+ Jv = torch.autograd.functional.jvp(causal_attn, zs, v)[1]
+ JTv = torch.autograd.functional.vjp(causal_attn, zs, v)[1]
+ corr = S * (Jv - JTv)
+ cn, fn = corr.norm(), f.norm() + 1e-8
+ corr = corr * (fn / cn) if cn > fn else corr
+ f = f - corr
+ z = z + eps * f
+ return z.detach()
+ zp, zm = nudged(+1.0), nudged(-1.0)
+ a = ((zm - zp) / (2 * beta)).detach()
+ with torch.enable_grad():
+ s = (a * force(zs.detach(), x_in)).sum()
+ return torch.autograd.grad(s, list(P.values()), allow_unused=True)
+
+
+def bptt_grad(x_in, y, T1, eps):
+ z = embed_in.clone().requires_grad_(True)
+ for _ in range(T1):
+ z = z + eps * force(z, x_in)
+ return torch.autograd.grad(ce(z, y), list(P.values()), allow_unused=True)
+
+
+def cos(g, gb):
+ cs = [F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() for a, b in zip(g, gb)]
+ return sum(cs) / len(cs), cs
+
+
+idx, y = get_batch()
+x_in = embed(idx).detach()
+embed_in = x_in.clone() # init state = embedding
+
+# free-eq residual (is there a stable fixed point?)
+zs = relax(embed_in.clone(), x_in, 200, 0.1)
+r = (relax(zs, x_in, 1, 0.1) - zs).norm().item() / (zs.norm().item() + 1e-9)
+print(f"LM causal damped-equilibrium attention on Shakespeare (B={B} T={T} C={C} H={H}, damp={DAMP})")
+print(f" free-phase residual = {r:.2e} ({'stable fixed point' if r < 1e-2 else 'NOT converged'})")
+
+gb = bptt_grad(x_in, y, 120, 0.1)
+gn = aep_grad(x_in, y, 120, 20, 0.1, 0.02, aep=False)
+ga = aep_grad(x_in, y, 120, 20, 0.1, 0.02, aep=True)
+mn, csn = cos(gn, gb)
+ma, csa = cos(ga, gb)
+print(f"\n attention-param gradient cosine vs BPTT:")
+print(f" naive-EP : mean {mn:+.3f} per-param " + " ".join(f"{n}={c:+.2f}" for n, c in zip(ATTN, csn)))
+print(f" AEP : mean {ma:+.3f} per-param " + " ".join(f"{n}={c:+.2f}" for n, c in zip(ATTN, csa)))