diff options
Diffstat (limited to 'ep_run/lt_ep_attention.py')
| -rw-r--r-- | ep_run/lt_ep_attention.py | 129 |
1 files changed, 129 insertions, 0 deletions
diff --git a/ep_run/lt_ep_attention.py b/ep_run/lt_ep_attention.py new file mode 100644 index 0000000..c411de6 --- /dev/null +++ b/ep_run/lt_ep_attention.py @@ -0,0 +1,129 @@ +"""option 2 / MVP-A: AEP for the LM's CAUSAL attention (equilibrium reformulation), on real text. + +The LM's feedforward causal attention is reformulated as a damped equilibrium block: + state z (B,T,C); input embedding x_in is clamped (boundary); + force F(z) = -(z - x_in) + s*(causal_attn(z) - c*z) [c>0: contraction -> stable fixed pt] + causal_attn = O softmax(QK^T/sqrt d , causal) V [non-conservative: independent Q/K/V/O] +Relax to z*, read out logits = z* W_head, cost = next-token cross-entropy. +Train Q/K/V/O with AEP (free + +/-beta nudged, centered, correction clipped) vs naive-EP, and +compare each to the ground-truth BPTT gradient (cosine) on the attention params -- on Shakespeare. +""" +import math, pickle, numpy as np, torch, torch.nn.functional as F +from pathlib import Path + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +torch.manual_seed(0) +DD = Path('/tmp/lt_ep/data/shakespeare_char') +vocab = pickle.load(open(DD / 'meta.pkl', 'rb'))['vocab_size'] +B, T, C, H = 16, 64, 128, 4 +dh = C // H +ATTN = ('WQ', 'WK', 'WV', 'WO') + + +def get_batch(): + data = np.memmap(DD / 'train.bin', dtype=np.uint16, mode='r') + ix = torch.randint(len(data) - T - 1, (B,)) + x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix]) + return x.to(dev), y.to(dev) + + +P = dict(WQ=torch.randn(C, C, device=dev) / math.sqrt(C), + WK=torch.randn(C, C, device=dev) / math.sqrt(C), + WV=torch.randn(C, C, device=dev) / math.sqrt(C), + WO=torch.randn(C, C, device=dev) / math.sqrt(C)) +for v in P.values(): + v.requires_grad_(True) +Whead = (torch.randn(C, vocab, device=dev) / math.sqrt(C)).requires_grad_(True) +tok_emb = torch.randn(vocab, C, device=dev) * 0.02 # fixed embedding (we test Q/K/V/O grads) +pos_emb = torch.randn(T, C, device=dev) * 0.02 +CMASK = torch.tril(torch.ones(T, T, dtype=torch.bool, device=dev)) +DAMP, S = 1.0, 1.0 + + +def embed(idx): + return tok_emb[idx] + pos_emb[None, :, :] + + +def causal_attn(z): + q = (z @ P['WQ']).view(B, T, H, dh).transpose(1, 2) + k = (z @ P['WK']).view(B, T, H, dh).transpose(1, 2) + v = (z @ P['WV']).view(B, T, H, dh).transpose(1, 2) + a = (q @ k.transpose(-2, -1)) / math.sqrt(dh) + a = a.masked_fill(~CMASK[:T, :T], float('-inf')) + a = torch.softmax(a, dim=-1) + return (a @ v).transpose(1, 2).reshape(B, T, C) @ P['WO'] + + +def force(z, x_in): + return -(z - x_in) + S * (causal_attn(z) - DAMP * z) + + +def relax(z, x_in, steps, eps): + for _ in range(steps): + with torch.no_grad(): + z = z + eps * force(z, x_in) + return z.detach() + + +def ce(z, y): + return F.cross_entropy((z @ Whead).reshape(-1, vocab), y.reshape(-1)) + + +def aep_grad(x_in, y, T1, T2, eps, beta, aep): + zs = relax(embed_in.clone(), x_in, T1, eps) + def nudged(sign): + z = zs.clone() + for _ in range(T2): + with torch.enable_grad(): + zz = z.detach().requires_grad_(True) + g, = torch.autograd.grad(ce(zz, y), zz) + with torch.no_grad(): + f = force(z, x_in) - sign * beta * g + if aep: + v = (z - zs).detach() + Jv = torch.autograd.functional.jvp(causal_attn, zs, v)[1] + JTv = torch.autograd.functional.vjp(causal_attn, zs, v)[1] + corr = S * (Jv - JTv) + cn, fn = corr.norm(), f.norm() + 1e-8 + corr = corr * (fn / cn) if cn > fn else corr + f = f - corr + z = z + eps * f + return z.detach() + zp, zm = nudged(+1.0), nudged(-1.0) + a = ((zm - zp) / (2 * beta)).detach() + with torch.enable_grad(): + s = (a * force(zs.detach(), x_in)).sum() + return torch.autograd.grad(s, list(P.values()), allow_unused=True) + + +def bptt_grad(x_in, y, T1, eps): + z = embed_in.clone().requires_grad_(True) + for _ in range(T1): + z = z + eps * force(z, x_in) + return torch.autograd.grad(ce(z, y), list(P.values()), allow_unused=True) + + +def cos(g, gb): + cs = [F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() for a, b in zip(g, gb)] + return sum(cs) / len(cs), cs + + +idx, y = get_batch() +x_in = embed(idx).detach() +embed_in = x_in.clone() # init state = embedding + +# free-eq residual (is there a stable fixed point?) +zs = relax(embed_in.clone(), x_in, 200, 0.1) +r = (relax(zs, x_in, 1, 0.1) - zs).norm().item() / (zs.norm().item() + 1e-9) +print(f"LM causal damped-equilibrium attention on Shakespeare (B={B} T={T} C={C} H={H}, damp={DAMP})") +print(f" free-phase residual = {r:.2e} ({'stable fixed point' if r < 1e-2 else 'NOT converged'})") + +gb = bptt_grad(x_in, y, 120, 0.1) +gn = aep_grad(x_in, y, 120, 20, 0.1, 0.02, aep=False) +ga = aep_grad(x_in, y, 120, 20, 0.1, 0.02, aep=True) +mn, csn = cos(gn, gb) +ma, csa = cos(ga, gb) +print(f"\n attention-param gradient cosine vs BPTT:") +print(f" naive-EP : mean {mn:+.3f} per-param " + " ".join(f"{n}={c:+.2f}" for n, c in zip(ATTN, csn))) +print(f" AEP : mean {ma:+.3f} per-param " + " ".join(f"{n}={c:+.2f}" for n, c in zip(ATTN, csa))) |
