summaryrefslogtreecommitdiff
path: root/ep_run/lt_ep_attention.py
blob: c411de6e4d15469bc8ffb892706f34f5eb0936b8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""option 2 / MVP-A: AEP for the LM's CAUSAL attention (equilibrium reformulation), on real text.

The LM's feedforward causal attention is reformulated as a damped equilibrium block:
  state z (B,T,C); input embedding x_in is clamped (boundary);
  force  F(z) = -(z - x_in) + s*(causal_attn(z) - c*z)         [c>0: contraction -> stable fixed pt]
  causal_attn = concat_heads(softmax(Q K^T / sqrt(d_head), causal) V) W_O   [non-conservative: independent Q/K/V/O]
Relax to z*, read out logits = z* W_head, cost = next-token cross-entropy.
Train Q/K/V/O with AEP (free + +/-beta nudged, centered, correction clipped) vs naive-EP, and
compare each to the ground-truth BPTT gradient (cosine) on the attention params -- on Shakespeare.
"""
import math, pickle, numpy as np, torch, torch.nn.functional as F
from pathlib import Path

dev = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)
DD = Path('/tmp/lt_ep/data/shakespeare_char')   # expects meta.pkl and train.bin (both read in this file)
# Read the vocabulary size and close the metadata file deterministically
# (the previous bare open() leaked the file handle).
# NOTE(review): pickle.load can execute arbitrary code embedded in the file, and /tmp is
# commonly world-writable -- only point DD at a dataset directory you trust.
with open(DD / 'meta.pkl', 'rb') as meta_f:
    vocab = pickle.load(meta_f)['vocab_size']
B, T, C, H = 16, 64, 128, 4      # batch size, sequence length, model width, attention heads
dh = C // H                      # per-head width
ATTN = ('WQ', 'WK', 'WV', 'WO')  # labels for the gradients of P.values(); must match P's insertion order


def get_batch():
    """Sample a random batch of (input, next-token target) windows from ``train.bin``.

    Returns two int64 tensors of shape (B, T) on ``dev``. The target window is the
    input window shifted right by one token.
    """
    # The memmap is re-created on every call; nothing is cached between batches.
    tokens = np.memmap(DD / 'train.bin', dtype=np.uint16, mode='r')
    starts = torch.randint(len(tokens) - T - 1, (B,))
    inputs, targets = [], []
    for s in starts:
        inputs.append(torch.from_numpy(tokens[s:s + T].astype(np.int64)))
        targets.append(torch.from_numpy(tokens[s + 1:s + 1 + T].astype(np.int64)))
    return torch.stack(inputs).to(dev), torch.stack(targets).to(dev)


# Trainable attention projections (C x C), scaled by 1/sqrt(C). They are drawn in ATTN order
# (WQ, WK, WV, WO) so the seeded RNG stream is consumed in a fixed order.
P = {name: (torch.randn(C, C, device=dev) / math.sqrt(C)).requires_grad_(True)
     for name in ATTN}
# Readout to vocabulary logits. requires_grad is set, but this file only ever requests
# gradients w.r.t. P.values().
Whead = (torch.randn(C, vocab, device=dev) / math.sqrt(C)).requires_grad_(True)
tok_emb = torch.randn(vocab, C, device=dev) * 0.02           # fixed embedding (we test Q/K/V/O grads)
pos_emb = torch.randn(T, C, device=dev) * 0.02               # fixed positional embedding, one row per position
# Lower-triangular boolean mask: query i may attend to keys j <= i.
CMASK = torch.tril(torch.ones(T, T, dtype=torch.bool, device=dev))
DAMP, S = 1.0, 1.0   # c (damping) and s (attention scale) in the force F(z)


def embed(idx):
    """Return fixed token embeddings of ``idx`` plus the fixed positional embeddings.

    ``idx`` holds integer token ids. Its last dimension must equal T, because
    ``pos_emb`` has exactly T rows and is broadcast over the leading (batch) axis.
    """
    token_vecs = tok_emb[idx]
    return token_vecs + pos_emb.unsqueeze(0)


def causal_attn(z):
    """Multi-head causal self-attention using the projections in ``P``.

    Computes concat_heads(softmax(Q K^T / sqrt(dh), causal) V) @ WO, where
    Q = z @ WQ, K = z @ WK and V = z @ WV are each split into H heads of width dh.

    Args:
        z: state tensor of shape (batch, seq, C). The batch and sequence sizes are
            read from ``z`` instead of the module-level B and T, so other batch
            sizes and shorter sequences work. ``seq`` must not exceed T, the size of
            CMASK.

    Returns:
        Tensor of shape (batch, seq, C).
    """
    nb, nt, _ = z.shape

    def split_heads(w):
        # (nb, nt, C) @ (C, C) -> (nb, nt, H, dh) -> (nb, H, nt, dh)
        return (z @ w).view(nb, nt, H, dh).transpose(1, 2)

    q = split_heads(P['WQ'])
    k = split_heads(P['WK'])
    v = split_heads(P['WV'])
    a = (q @ k.transpose(-2, -1)) / math.sqrt(dh)
    # Block attention to future positions (CMASK is True where attention is allowed).
    a = a.masked_fill(~CMASK[:nt, :nt], float('-inf'))
    a = torch.softmax(a, dim=-1)
    # Merge heads back to (nb, nt, C), then apply the output projection.
    return (a @ v).transpose(1, 2).reshape(nb, nt, C) @ P['WO']


def force(z, x_in):
    """Equilibrium force F(z) = (x_in - z) + S * (causal_attn(z) - DAMP * z).

    The first term pulls the state toward the clamped input ``x_in``. The second is
    the damped attention update. A fixed point z* satisfies F(z*) = 0.
    """
    pull = x_in - z
    return pull + S * (causal_attn(z) - DAMP * z)


def relax(z, x_in, steps, eps):
    """Take ``steps`` explicit-Euler steps z <- z + eps * F(z, x_in) without autograd.

    Returns the final state, detached from any graph.
    """
    with torch.no_grad():
        for _ in range(steps):
            z = z + eps * force(z, x_in)
    return z.detach()


def ce(z, y):
    """Mean cross-entropy of the readout logits ``z @ Whead`` against integer targets ``y``.

    All leading dimensions are flattened, so each position counts as one sample.
    """
    logits = z @ Whead
    flat_logits = logits.reshape(-1, vocab)
    flat_targets = y.reshape(-1)
    return F.cross_entropy(flat_logits, flat_targets)


def _aep_correction(zs, z, f):
    """Return S * (J - J^T) (z - zs), with J = d causal_attn/dz at zs, norm-clipped to ||f||."""
    v = (z - zs).detach()
    # functional jvp/vjp enable grad internally, so they work inside the caller's no_grad.
    Jv = torch.autograd.functional.jvp(causal_attn, zs, v)[1]
    JTv = torch.autograd.functional.vjp(causal_attn, zs, v)[1]
    corr = S * (Jv - JTv)
    cn, fn = corr.norm(), f.norm() + 1e-8
    # Keep the correction from dominating the step: rescale it to ||f|| when larger.
    return corr * (fn / cn) if cn > fn else corr


def aep_grad(x_in, y, T1, T2, eps, beta, aep, z0=None):
    """Estimate d(ce)/dP from a free phase plus two symmetrically nudged phases.

    Free phase: ``T1`` Euler steps of F from ``z0`` give the (approximate) fixed point zs.
    Nudged phases: starting at zs, ``T2`` Euler steps of F(z) - sign * beta * d(ce)/dz
    for sign = +1 and -1. With ``aep=True`` each nudged step also subtracts
    S * (J - J^T)(z - zs), J being the Jacobian of ``causal_attn`` at zs. F's Jacobian
    is -I + S * (J - DAMP * I), so this replaces it by its transpose in the linearised
    step. The correction is norm-clipped to the rest of the step. With ``aep=False``
    the plain ("naive EP") nudged dynamics are used.
    The centred difference a = (z_- - z_+) / (2 * beta) is contracted with dF/dP at zs.

    Args:
        x_in: clamped input embedding (same shape as the state).
        y: targets passed to ``ce``.
        T1: free-phase steps.
        T2: steps per nudged phase.
        eps: Euler step size.
        beta: nudge strength.
        aep: apply the antisymmetric Jacobian correction.
        z0: initial free-phase state. Defaults to ``x_in``. The previous version read
            the module-level ``embed_in``, which the script sets to a copy of ``x_in``,
            so results for the script are unchanged.

    Returns:
        Tuple of gradients, one per tensor in ``P.values()``. Entries may be None for
        unused parameters because ``allow_unused=True``.
    """
    z_init = x_in if z0 is None else z0
    zs = relax(z_init.clone(), x_in, T1, eps)

    def nudged(sign):
        # Relax from zs under the force with the -sign * beta * d(ce)/dz nudge added.
        z = zs.clone()
        for _ in range(T2):
            with torch.enable_grad():
                zz = z.detach().requires_grad_(True)
                g, = torch.autograd.grad(ce(zz, y), zz)
            with torch.no_grad():
                f = force(z, x_in) - sign * beta * g
                if aep:
                    f = f - _aep_correction(zs, z, f)
                z = z + eps * f
        return z.detach()

    zp, zm = nudged(+1.0), nudged(-1.0)
    a = ((zm - zp) / (2 * beta)).detach()
    # d/dP of <a, F(zs)> = a^T dF/dP, the gradient estimate.
    with torch.enable_grad():
        s = (a * force(zs.detach(), x_in)).sum()
        return torch.autograd.grad(s, list(P.values()), allow_unused=True)


def bptt_grad(x_in, y, T1, eps, z0=None):
    """Reference gradient of ``ce`` w.r.t. P by backpropagating through ``T1`` Euler steps.

    Unrolls z <- z + eps * F(z, x_in) with autograd enabled and differentiates ``ce``
    at the final state. Grad mode is forced on, so this also works when called
    inside ``torch.no_grad()``.

    Args:
        x_in: clamped input embedding (same shape as the state).
        y: targets passed to ``ce``.
        T1: number of unrolled steps.
        eps: Euler step size.
        z0: initial state. Defaults to ``x_in``. The previous version read the
            module-level ``embed_in``, a copy of ``x_in`` in the script.

    Returns:
        Tuple of gradients, one per tensor in ``P.values()``. Entries may be None for
        unused parameters because ``allow_unused=True``.
    """
    with torch.enable_grad():
        # Gradients are only requested for P, so the initial state need not require grad.
        z = (x_in if z0 is None else z0).clone()
        for _ in range(T1):
            z = z + eps * force(z, x_in)
        return torch.autograd.grad(ce(z, y), list(P.values()), allow_unused=True)


def cos(g, gb):
    """Compare two gradient lists tensor by tensor using cosine similarity.

    Returns ``(mean, per_tensor)``. ``per_tensor[i]`` is the cosine (a Python float)
    between the flattened ``g[i]`` and ``gb[i]``, and ``mean`` is their arithmetic mean.
    """
    per_tensor = []
    for est, ref in zip(g, gb):
        sim = F.cosine_similarity(est.flatten(), ref.flatten(), dim=0)
        per_tensor.append(sim.item())
    mean = sum(per_tensor) / len(per_tensor)
    return mean, per_tensor


# Experiment: on one fixed batch, compare the naive-EP and AEP gradient estimates of the
# attention parameters against the BPTT reference gradient.
RELAX_STEPS, NUDGE_STEPS, EPS, BETA = 120, 20, 0.1, 0.02   # shared by all three estimators

idx, y = get_batch()
x_in = embed(idx).detach()
embed_in = x_in.clone()                      # init state = embedding

# Free-eq residual (is there a stable fixed point?): relative change produced by one
# more Euler step after 200 free-phase steps.
zs = relax(embed_in.clone(), x_in, 200, EPS)
r = (relax(zs, x_in, 1, EPS) - zs).norm().item() / (zs.norm().item() + 1e-9)
print(f"LM causal damped-equilibrium attention on Shakespeare  (B={B} T={T} C={C} H={H}, damp={DAMP})")
print(f"  free-phase residual = {r:.2e}  ({'stable fixed point' if r < 1e-2 else 'NOT converged'})")

gb = bptt_grad(x_in, y, RELAX_STEPS, EPS)
gn = aep_grad(x_in, y, RELAX_STEPS, NUDGE_STEPS, EPS, BETA, aep=False)
ga = aep_grad(x_in, y, RELAX_STEPS, NUDGE_STEPS, EPS, BETA, aep=True)
mn, csn = cos(gn, gb)
ma, csa = cos(ga, gb)
print("\n  attention-param gradient cosine vs BPTT:")
print(f"    naive-EP : mean {mn:+.3f}   per-param " + " ".join(f"{n}={c:+.2f}" for n, c in zip(ATTN, csn)))
print(f"    AEP      : mean {ma:+.3f}   per-param " + " ".join(f"{n}={c:+.2f}" for n, c in zip(ATTN, csa)))