summaryrefslogtreecommitdiff
path: root/ep_run/lt_ep_stack.py
blob: 327b143e38fc2c417e5f80013fa98265cfe38442 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""Spring-coupled equilibrium STACK — EP through depth with no protocol.
Inter-block coupling is a conservative spring energy  sum_k gamma/2 ||z_k - z_{k-1}||^2, where the
bottom block's spring is attached to the clamped input embedding x (every block state also feels a
confining -c*z force).  The cost pulls only the top state; spring REACTION forces (Newton's 3rd law)
carry the tension down the chain.  The EP/VF + AEP correction on the non-conservative block
internals is unchanged — the stack is just one bigger force field.
Probe: per-block gradient cosine vs BPTT-through-the-joint-relaxation. The decisive number is
block-0's cosine: did the tension reach the bottom?"""
import math, time, torch, torch.nn.functional as F
from lt_ep_train import get_batch, vocab
# Device for every tensor this module allocates: CUDA when available, otherwise the CPU.
if torch.cuda.is_available():
    dev = 'cuda'
else:
    dev = 'cpu'


class EQStack:
    """K attention+MLP blocks whose states are coupled by springs and relaxed jointly.

    The joint state has layout (K, B, T, C): one (B, T, C) activation per block.  Block k is
    pulled toward the state below it (the input embedding for block 0) by a spring of
    stiffness ``gamma``.  It also receives the reaction of the spring above it, its own
    attention + MLP output as a non-conservative force, and a confining ``-c * z`` term.

    Attributes:
        blocks: one dict of parameter tensors per block.
        block: every trainable tensor except the readout ``Wh`` (tok, pos, then each
            block's tensors in dict order).
        allp: ``block`` followed by ``Wh``.
        cmask: (T, T) lower-triangular boolean causal mask.
    """

    def __init__(self, K, C, H, T, gamma=1.0, c=1.0):
        # NOTE: the order of the torch.randn draws below decides which random numbers land in
        # which tensor (tok, pos, per block WQ/WK/WV/WO/fc/pj, then Wh); keep it fixed.
        def rand_leaf(*shape, sc):
            return (torch.randn(*shape, device=dev) * sc).requires_grad_(True)

        def const_leaf(n, value):
            return torch.full((n,), float(value), device=dev).requires_grad_(True)

        self.K, self.C, self.H, self.T = K, C, H, T
        self.dh = C // H
        self.gamma = gamma
        self.c = c
        self.tok = rand_leaf(vocab, C, sc=0.02)
        self.pos = rand_leaf(T, C, sc=0.02)

        in_scale = 1 / math.sqrt(C)
        hid_scale = 1 / math.sqrt(4 * C)
        self.blocks = []
        for _ in range(K):
            blk = {}
            for name in ('WQ', 'WK', 'WV', 'WO'):
                blk[name] = rand_leaf(C, C, sc=in_scale)
            blk['ln1g'] = const_leaf(C, 1)
            blk['ln1b'] = const_leaf(C, 0)
            blk['ln2g'] = const_leaf(C, 1)
            blk['ln2b'] = const_leaf(C, 0)
            blk['fc'] = rand_leaf(C, 4 * C, sc=in_scale)
            blk['fcb'] = const_leaf(4 * C, 0)
            blk['pj'] = rand_leaf(4 * C, C, sc=hid_scale)
            blk['pjb'] = const_leaf(C, 0)
            self.blocks.append(blk)
        self.Wh = rand_leaf(C, vocab, sc=in_scale)
        self.cmask = torch.ones(T, T, dtype=torch.bool, device=dev).tril()
        self.block = [self.tok, self.pos]
        for blk in self.blocks:
            self.block.extend(blk.values())
        self.allp = self.block + [self.Wh]

    def embed(self, idx):
        """Token embedding of ``idx`` plus the positional table, broadcast over the batch."""
        tok_vec = self.tok[idx]
        return tok_vec + self.pos.unsqueeze(0)

    def battn(self, b, z):
        """Causal multi-head self-attention of block ``b`` on its (B, T, C) state ``z``."""
        nb = z.size(0)
        C, T, H, dh = self.C, self.T, self.H, self.dh
        normed = F.layer_norm(z, (C,), b['ln1g'], b['ln1b'])

        def split_heads(W):
            # (B, T, C) @ (C, C) -> (B, H, T, dh)
            return (normed @ W).view(nb, T, H, dh).transpose(1, 2)

        q, k, v = split_heads(b['WQ']), split_heads(b['WK']), split_heads(b['WV'])
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(dh)
        scores = scores.masked_fill(~self.cmask, float('-inf'))
        weights = torch.softmax(scores, -1)
        merged = (weights @ v).transpose(1, 2).reshape(nb, T, C)
        return merged @ b['WO']

    def bffn(self, b, z):
        """GELU MLP of block ``b`` applied to the layer-normed state ``z``."""
        normed = F.layer_norm(z, (self.C,), b['ln2g'], b['ln2b'])
        hidden = F.gelu(normed @ b['fc'] + b['fcb'])
        return hidden @ b['pj'] + b['pjb']

    def nc_force(self, zc):
        """Non-conservative part: each block's attention + MLP on its own slice of (K,B,T,C) ``zc``."""
        per_block = []
        for k, b in enumerate(self.blocks):
            zk = zc[k]
            per_block.append(self.battn(b, zk) + self.bffn(b, zk))
        return torch.stack(per_block, 0)

    def force(self, zc, xin, cg=False):
        """Total force on the stacked state ``zc`` given the input clamp ``xin``.

        The force is the sum of four terms:
        - the spring to the state below (``xin`` for block 0);
        - the reaction of the spring above (none for the top block);
        - each block's attention + MLP output;
        - ``-c * z``.

        With ``cg=True`` and a ``zc`` that already requires grad, ``zc`` is used as-is, so the
        result stays on the caller's autograd graph. Otherwise the force is evaluated on a
        fresh detached leaf copy of ``zc``.
        """
        if cg and zc.requires_grad:
            z = zc
        else:
            z = zc.detach().requires_grad_(True)
        lower = torch.cat([xin[None], z[:-1]], 0)           # what each block's spring hangs from
        f = -self.gamma * (z - lower) + self.nc_force(z) - self.c * z
        reaction = self.gamma * (z[1:] - z[:-1])             # pull of the spring above (Newton's 3rd law)
        top_pad = torch.zeros_like(z[:1])                    # the top block has no spring above it
        return f + torch.cat([reaction, top_pad], 0)


def relax(st, z, xin, steps, eps):
    """Take ``steps`` explicit-Euler steps ``z <- z + eps * st.force(z, xin)`` without autograd.

    ``z`` is never modified in place; the final state is returned detached from any graph.
    """
    with torch.no_grad():
        for _ in range(steps):
            push = st.force(z, xin).detach()
            z = z + eps * push
    return z.detach()


def ce(st, z, y):
    """Mean cross-entropy of the top block's readout ``z[-1] @ st.Wh`` against targets ``y``.

    The class count is taken from the logits themselves rather than the module-level
    ``vocab``. This matches ``grad_ce_state`` and works for any readout width.
    """
    logits = z[-1] @ st.Wh
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))


def grad_ce_state(st, z, y):
    """Closed-form gradient of the mean readout cross-entropy with respect to the state ``z``.

    Only the top block's state feeds the readout, so every other slice is zero.  The top
    slice is ``(softmax(z[-1] @ Wh) - onehot(y)) @ Wh^T / N`` with ``N = y.numel()``.
    """
    probs = torch.softmax(z[-1] @ st.Wh, -1)
    targets = F.one_hot(y, probs.size(-1)).to(z.dtype)
    top = (probs - targets) @ st.Wh.T / y.numel()
    out = torch.zeros_like(z)
    out[-1] = top
    return out


def ep_a(st, zs, xin, y, r, T2max, eps, K=10, exit_mult=5.0):
    """N=2 real phases, clamp-free, AEP corr on the stack's nc part, hindsight selection.

    Two copies of the free fixed point ``zs`` are relaxed with the cost gradient added at
    strength +r (``Zp``) and -r (``Zm``).  Every step also subtracts ``(J - J^T) v``, where J
    is the Jacobian of ``st.nc_force`` evaluated at ``zs`` and ``v = Z - zs``.  This is the AEP
    correction; the spring and ``-c*z`` terms are symmetric, so only the nc part needs it.

    Every ``K`` steps, and at ``T2max``, the finite difference ``a_t = (Zm - Zp) / (2r)`` is
    formed.  The estimate whose change from the previous check was smallest is kept
    (hindsight selection).  The loop stops early in two cases:
    - ``a_t`` is non-finite;
    - once ``t >= 3 * K``, the change exceeds ``exit_mult`` times the smallest change seen.

    Args:
        st: the ``EQStack`` being probed.
        zs: free-phase state, layout (K_blocks, B, T, C); not modified.
        xin: input clamp passed through to ``st.force``.
        y: integer targets for ``grad_ce_state``.
        r: nudge strength (must be non-zero).
        T2max: maximum number of nudged-phase steps.
        eps: Euler step size.
        K: steps between estimate checks (not the number of blocks).
        exit_mult: growth factor that triggers the early exit.

    Returns:
        ``(a, t_best)``: ``a`` is the detached adjoint estimate, shaped like ``zs``.
        ``t_best`` is the step at which ``a`` was taken, or 0 when fewer than two finite
        estimates were formed.

    Raises:
        RuntimeError: if no finite estimate was formed at all, either because
            ``T2max < 1`` or because the nudged phases went non-finite before the first check.
    """
    Zp, Zm = zs.clone(), zs.clone()
    a_prev = a_best = None
    inc_min, t_best = float('inf'), 0
    for t in range(1, T2max + 1):
        with torch.no_grad():
            for Z, sg in ((Zp, +r), (Zm, -r)):
                f = st.force(Z, xin) - sg * grad_ce_state(st, Z, y)
                v = (Z - zs).contiguous()
                # jvp/vjp enable grad internally, so they still work inside no_grad
                Jv = torch.autograd.functional.jvp(st.nc_force, zs, v)[1]
                JTv = torch.autograd.functional.vjp(st.nc_force, zs, v)[1]
                Z += eps * (f - (Jv - JTv))       # in place: advances Zp / Zm themselves
        if t % K == 0 or t == T2max:
            a_t = (Zm - Zp) / (2 * r)
            if not torch.isfinite(a_t).all():
                break                             # keep whatever finite estimate we already have
            if a_prev is not None:
                inc = (a_t - a_prev).norm().item()
                if inc < inc_min:
                    inc_min, a_best, t_best = inc, a_t, t
                elif inc > exit_mult * inc_min and t >= 3 * K:
                    break                         # estimate is drifting away again: stop
            a_prev = a_t
    a = a_best if a_best is not None else a_prev
    if a is None:
        raise RuntimeError(
            f"ep_a: no finite adjoint estimate (T2max={T2max}, K={K}); the nudged phases "
            "diverged before the first check or T2max < 1")
    return a.detach(), t_best


def ep_grads(st, idx, y, T1, eps, r, T2max):
    """EP gradient estimate for every tensor in ``st.block`` (the readout ``Wh`` is not included).

    Steps:
    1. Free phase: every block state starts at the (detached) input embedding and is
       relaxed for ``T1`` steps.
    2. Residual: ``||relax_1(zs) - zs|| / ||zs||``, the relative size of one more step,
       as a convergence diagnostic.
    3. Adjoint: ``ep_a`` from the two nudged phases.
    4. Gradient: autograd of ``sum(a * force(zs))`` with respect to the parameters.  The
       embedding is recomputed with grad, so the input clamp also reaches ``tok``/``pos``.

    Returns:
        (dict id(param) -> gradient, residual, t_best reported by ``ep_a``).
    """
    x_clamp = st.embed(idx).detach()
    z_init = x_clamp.unsqueeze(0).repeat(st.K, 1, 1, 1)
    zs = relax(st, z_init, x_clamp, T1, eps)
    one_more = relax(st, zs, x_clamp, 1, eps)
    residual = (one_more - zs).norm().item() / (zs.norm().item() + 1e-9)
    adjoint, t_best = ep_a(st, zs, x_clamp, y, r, T2max, eps)
    with torch.enable_grad():
        x_live = st.embed(idx)
        f_eq = st.force(zs.detach(), x_live, cg=True)
        grads = torch.autograd.grad((adjoint * f_eq).sum(), st.block, allow_unused=True)
    by_id = dict(zip(map(id, st.block), grads))
    return by_id, residual, t_best


def bptt_grads(st, idx, y, T1, eps):
    """Reference gradients: backprop through ``T1`` Euler steps of the joint relaxation.

    The initial state and dynamics match the EP free phase. The graph is kept at every step
    (``cg=True``), so ``ce`` at the final state is differentiated with respect to every
    tensor in ``st.allp``, including the readout ``Wh``.

    Returns:
        dict id(param) -> gradient (None for tensors the loss does not reach).
    """
    x_live = st.embed(idx)
    # `anchor * 0 + x_live` has x_live's values but always requires grad through the fresh
    # leaf `anchor`. So force(..., cg=True) keeps the graph from the very first step even
    # if the embedding itself did not require grad.
    anchor = x_live.detach().requires_grad_(True)
    z = (anchor * 0 + x_live)[None].repeat(st.K, 1, 1, 1)
    for _ in range(T1):
        z = z + eps * st.force(z, x_live, cg=True)
    loss = ce(st, z, y)
    grads = torch.autograd.grad(loss, st.allp, allow_unused=True)
    return dict(zip(map(id, st.allp), grads))


if __name__ == '__main__':
    # Probe flow:
    # 1. Pretrain a K-block spring stack with BPTT.
    # 2. On a few fresh batches, compare EP gradients against a long-horizon BPTT
    #    reference, reporting the cosine per parameter group.
    torch.manual_seed(0)
    K, B, T, C, H = 2, 16, 64, 128, 4
    n_pretrain = 200                              # BPTT warm-up steps (also reported below)
    st = EQStack(K, C, H, T, gamma=1.0, c=1.0)
    opt = torch.optim.AdamW(st.allp, lr=1e-3, weight_decay=1e-4)
    for _ in range(n_pretrain):                   # short BPTT pretrain -> realistic operating point
        idx, y = get_batch('train', B, T)
        g = bptt_grads(st, idx, y, 120, 0.1)
        opt.zero_grad(set_to_none=True)
        for p in st.allp:
            p.grad = g.get(id(p))
        torch.nn.utils.clip_grad_norm_(st.allp, 5.0)
        opt.step()
    print(f"pretrained {n_pretrain} BPTT steps (K={K} spring stack, gamma={st.gamma})", flush=True)

    # One column per block (blk0 = bottom block, sprung to the input), valid for any K.
    groups = {'all': st.block,
              **{f'blk{k}': list(b.values()) for k, b in enumerate(st.blocks)},
              'emb': [st.tok, st.pos]}

    def cos(ga, gb, ps):
        """Cosine between two id(param)-keyed gradient dicts over the tensors in ``ps``.

        Tensors whose gradient is missing/None in either dict are skipped; nan if none remain.
        """
        keep = [p for p in ps if ga.get(id(p)) is not None and gb.get(id(p)) is not None]
        if not keep:
            return float('nan')
        va = torch.cat([ga[id(p)].reshape(-1) for p in keep])
        vb = torch.cat([gb[id(p)].reshape(-1) for p in keep])
        return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item()

    hdr = f"{'config':>22} {'res':>9} {'t_best':>7} " + " ".join(f"{k:>6}" for k in groups)
    for bi in range(3):
        idx, y = get_batch('train', B, T)
        ref = bptt_grads(st, idx, y, 400, 0.1)    # reference: BPTT through 400 relaxation steps
        print(("\n" if bi else "") + hdr, flush=True)
        for T1 in (150, 400):
            gep, res, tb = ep_grads(st, idx, y, T1, 0.1, 0.02, 120)
            print(f"{f'ep T1={T1}':>22} {res:>9.1e} {tb:>7} "
                  + " ".join(f"{cos(gep, ref, ps):>6.3f}" for ps in groups.values()), flush=True)