summaryrefslogtreecommitdiff
path: root/ep_run/lt_ep_stack.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/lt_ep_stack.py')
-rw-r--r--ep_run/lt_ep_stack.py165
1 files changed, 165 insertions, 0 deletions
diff --git a/ep_run/lt_ep_stack.py b/ep_run/lt_ep_stack.py
new file mode 100644
index 0000000..327b143
--- /dev/null
+++ b/ep_run/lt_ep_stack.py
@@ -0,0 +1,165 @@
+"""Spring-coupled equilibrium STACK — EP through depth with no protocol.
+Inter-block coupling is a conservative spring energy sum_k gamma/2 ||z_k - z_{k-1}||^2 (z_0
+sprung to the input clamp x). The cost pulls z_K; spring REACTION forces (Newton's 3rd law)
+carry the tension down the chain; EP/VF + AEP correction on the non-conservative block internals
+is unchanged — the stack is just one bigger force field.
+Probe: per-block gradient cosine vs BPTT-through-the-joint-relaxation. The decisive number is
+block-0's cosine: did the tension reach the bottom?"""
import math
import time

import torch
import torch.nn.functional as F

from lt_ep_train import get_batch, vocab

# Device every tensor in this file is created on: CUDA when available, else CPU.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
class EQStack:
    """Stack of K attention+MLP blocks whose states are coupled by springs.

    The state is a tensor ``z`` of shape (K, B, T, C): one (B, T, C) slab per
    block. ``force`` returns the vector field that ``relax`` integrates: a
    spring term tying each block to the one below (block 0 is tied to the
    input clamp), the spring reaction from the block above, the block's own
    attention + feed-forward output, and a linear decay ``-c * z``.

    Attributes:
        blocks: list of K dicts holding each block's parameter tensors.
        block:  flat list ``[tok, pos, *all block parameters]``.
        allp:   ``block`` plus the read-out matrix ``Wh``.
    """

    def __init__(self, K, C, H, T, gamma=1.0, c=1.0):
        """Create parameters for K blocks of width C with H heads and context T.

        Raises:
            ValueError: if H is not positive or C is not divisible by H
                (the per-head split in ``battn`` would otherwise fail later).
        """
        if H <= 0 or C % H != 0:
            raise ValueError(f"C={C} must be divisible by a positive H (got H={H})")

        def g(*sh, sc):
            # Gaussian-initialised leaf tensor scaled by ``sc``.
            return (torch.randn(*sh, device=dev) * sc).requires_grad_(True)

        def z1(n, v):
            # Constant-initialised leaf vector of length ``n``.
            return torch.full((n,), float(v), device=dev).requires_grad_(True)

        self.K, self.C, self.H, self.dh, self.T = K, C, H, C // H, T
        self.gamma, self.c = gamma, c
        self.tok = g(vocab, C, sc=0.02)
        self.pos = g(T, C, sc=0.02)
        self.blocks = []
        for _ in range(K):
            # Keyword order fixes both the RNG draw order and the parameter
            # order inside ``self.block``; keep it stable.
            self.blocks.append(dict(
                WQ=g(C, C, sc=1 / math.sqrt(C)), WK=g(C, C, sc=1 / math.sqrt(C)),
                WV=g(C, C, sc=1 / math.sqrt(C)), WO=g(C, C, sc=1 / math.sqrt(C)),
                ln1g=z1(C, 1), ln1b=z1(C, 0), ln2g=z1(C, 1), ln2b=z1(C, 0),
                fc=g(C, 4 * C, sc=1 / math.sqrt(C)), fcb=z1(4 * C, 0),
                pj=g(4 * C, C, sc=1 / math.sqrt(4 * C)), pjb=z1(C, 0)))
        self.Wh = g(C, vocab, sc=1 / math.sqrt(C))
        # Lower-triangular mask: position t may attend to positions <= t.
        self.cmask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=dev))
        self.block = [self.tok, self.pos] + [p for b in self.blocks for p in b.values()]
        self.allp = self.block + [self.Wh]

    def embed(self, idx):
        """Return token + positional embeddings for ``idx`` (last dim = sequence).

        Sequences shorter than ``T`` use the leading positional rows.
        """
        return self.tok[idx] + self.pos[:idx.size(-1)][None]

    def battn(self, b, z):
        """Causal multi-head self-attention of block ``b`` on ``z`` (B, T', C), T' <= T."""
        B, T = z.size(0), z.size(1)
        H, dh, C = self.H, self.dh, self.C
        h = F.layer_norm(z, (C,), b['ln1g'], b['ln1b'])
        q = (h @ b['WQ']).view(B, T, H, dh).transpose(1, 2)
        k = (h @ b['WK']).view(B, T, H, dh).transpose(1, 2)
        v = (h @ b['WV']).view(B, T, H, dh).transpose(1, 2)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(dh)
        # Slice the mask to the actual length; identical to the full mask when T' == T.
        blocked = ~self.cmask[:T, :T]
        a = torch.softmax(scores.masked_fill(blocked, float('-inf')), -1)
        return (a @ v).transpose(1, 2).reshape(B, T, C) @ b['WO']

    def bffn(self, b, z):
        """Layer-normed GELU feed-forward of block ``b`` applied to ``z``."""
        h = F.layer_norm(z, (self.C,), b['ln2g'], b['ln2b'])
        return F.gelu(h @ b['fc'] + b['fcb']) @ b['pj'] + b['pjb']

    def nc_force(self, zc):  # non-conservative internals, state (K,B,T,C)
        """Return attention + feed-forward output of every block, stacked on dim 0."""
        return torch.stack([self.battn(b, zc[k]) + self.bffn(b, zc[k])
                            for k, b in enumerate(self.blocks)], 0)

    def force(self, zc, xin, cg=False):
        """Return the total force on state ``zc`` given the input clamp ``xin``.

        Args:
            zc:  state tensor, one slab per block along dim 0.
            xin: clamp that block 0 is sprung to (same shape as one slab).
            cg:  when True and ``zc`` already requires grad, ``zc`` is used
                 as-is so the result stays attached to its graph; otherwise
                 ``zc`` is detached and re-marked as requiring grad.
        """
        zr = zc if (cg and zc.requires_grad) else zc.detach().requires_grad_(True)
        below = torch.cat([xin[None], zr[:-1]], 0)
        f = -self.gamma * (zr - below) + self.nc_force(zr) - self.c * zr
        up = self.gamma * (zr[1:] - zr[:-1])  # reaction of the spring above (Newton's 3rd law)
        # The top block has no spring above it, hence the zero slab.
        return f + torch.cat([up, torch.zeros_like(zr[:1])], 0)
+
+
def relax(st, z, xin, steps, eps):
    """Take ``steps`` explicit Euler steps of size ``eps`` along ``st.force``.

    Runs entirely without autograd tracking and returns a detached state.
    """
    state = z
    with torch.no_grad():
        for _ in range(steps):
            push = st.force(state, xin).detach()
            state = state + eps * push
    return state.detach()
+
+
def ce(st, z, y):
    """Cross-entropy of the top block's read-out ``z[-1] @ st.Wh`` against targets ``y``."""
    logits = z[-1] @ st.Wh
    flat_logits = logits.reshape(-1, vocab)
    flat_targets = y.reshape(-1)
    return F.cross_entropy(flat_logits, flat_targets)
+
+
def grad_ce_state(st, z, y):  # closed-form dCE/dz: only the top block feels y
    """Return d(mean cross-entropy)/dz, same shape as ``z``.

    Every slab except the last is zero; the last is
    ``(softmax(logits) - onehot(y)) @ Wh^T / y.numel()``.
    """
    probs = torch.softmax(z[-1] @ st.Wh, -1)
    target = F.one_hot(y, probs.size(-1)).to(z.dtype)
    out = torch.zeros_like(z)
    out[-1] = (probs - target) @ st.Wh.t() / y.numel()
    return out
+
+
def ep_a(st, zs, xin, y, r, T2max, eps, K=10, exit_mult=5.0):
    """N=2 real phases, clamp-free, AEP corr on the stack's nc part, hindsight selection.

    Two copies of the relaxed state ``zs`` are integrated side by side, one
    nudged by ``-r * dCE/dz`` and one by ``+r * dCE/dz``. Each step also
    subtracts ``(J - J^T) v`` where ``J`` is the Jacobian of ``st.nc_force``
    at ``zs`` and ``v`` is the copy's displacement from ``zs``. Every ``K``
    steps (and at the last step) the estimate ``a_t = (Zm - Zp) / (2 r)`` is
    recorded; the estimate whose change from the previous check was smallest
    is kept.

    Args:
        st: stack providing ``force`` and ``nc_force``.
        zs: relaxed free-phase state (not modified).
        xin: input clamp passed through to ``st.force``.
        y: targets passed to ``grad_ce_state``.
        r: nudge strength.
        T2max: maximum number of nudged steps.
        eps: Euler step size.
        K: number of steps between checks (must be >= 1).
        exit_mult: stop once the change between checks exceeds
            ``exit_mult`` times the smallest change seen (after ``3 * K`` steps).

    Returns:
        ``(a, t_best)``: the selected detached estimate and the step it was
        taken at. ``t_best`` is 0 when fewer than two checks were available
        and the latest estimate is returned instead.

    Raises:
        ValueError: if ``K < 1``.
        RuntimeError: if no finite estimate was ever produced (``T2max < 1``
            or the very first check was NaN/Inf).
    """
    if K < 1:
        raise ValueError(f"K must be >= 1, got {K}")
    Zp, Zm = zs.clone(), zs.clone()
    a_prev = a_best = None
    inc_min, t_best = float('inf'), 0
    for t in range(1, T2max + 1):
        with torch.no_grad():
            for Z, sg in ((Zp, +r), (Zm, -r)):
                f = st.force(Z, xin) - sg * grad_ce_state(st, Z, y)
                v = (Z - zs).contiguous()
                # jvp/vjp re-enable grad internally, so they work under no_grad.
                Jv = torch.autograd.functional.jvp(st.nc_force, zs, v)[1]
                JTv = torch.autograd.functional.vjp(st.nc_force, zs, v)[1]
                Z += eps * (f - (Jv - JTv))
        if t % K == 0 or t == T2max:
            a_t = (Zm - Zp) / (2 * r)
            if not torch.isfinite(a_t).all():
                break  # diverged: fall back to the last finite estimate
            if a_prev is not None:
                inc = (a_t - a_prev).norm().item()
                if inc < inc_min:
                    inc_min, a_best, t_best = inc, a_t, t
                elif inc > exit_mult * inc_min and t >= 3 * K:
                    break  # estimate is drifting away from its best plateau
            a_prev = a_t
    chosen = a_best if a_best is not None else a_prev
    if chosen is None:
        raise RuntimeError(
            "ep_a produced no finite estimate "
            f"(T2max={T2max}); the nudged phase never ran or diverged at the first check")
    return chosen.detach(), t_best
+
+
def ep_grads(st, idx, y, T1, eps, r, T2max):
    """Estimate parameter gradients for ``st.block`` from the relaxed state.

    Returns ``(grads, res, t_best)``: a dict mapping ``id(param)`` to its
    gradient (``None`` where unused), the relative size of one further
    relaxation step at the relaxed state, and the step index reported by
    ``ep_a``.
    """
    clamp = st.embed(idx).detach()
    start = clamp[None].repeat(st.K, 1, 1, 1)  # every block starts at the clamp
    zs = relax(st, start, clamp, T1, eps)

    # Relative residual: how far one more Euler step still moves the state.
    one_more = relax(st, zs, clamp, 1, eps)
    res = (one_more - zs).norm().item() / (zs.norm().item() + 1e-9)

    a, tb = ep_a(st, zs, clamp, y, r, T2max, eps)

    with torch.enable_grad():
        # Re-embed with grad so the embedding tables also receive gradients.
        live_clamp = st.embed(idx)
        f = st.force(zs.detach(), live_clamp, cg=True)
        grads = torch.autograd.grad((a * f).sum(), st.block, allow_unused=True)
    return dict(zip(map(id, st.block), grads)), res, tb
+
+
def bptt_grads(st, idx, y, T1, eps):
    """Backpropagate the cross-entropy through ``T1`` unrolled relaxation steps.

    Returns a dict mapping ``id(param)`` to its gradient for every tensor in
    ``st.allp`` (``None`` where a parameter is unused).
    """
    xin = st.embed(idx)
    # Adding 0 * (a grad-requiring copy) leaves the values unchanged while
    # guaranteeing the state requires grad, so force(cg=True) keeps the graph.
    anchor = xin.detach().requires_grad_(True)
    z = (anchor * 0 + xin)[None].repeat(st.K, 1, 1, 1)
    for _ in range(T1):
        z = z + eps * st.force(z, xin, cg=True)
    loss = ce(st, z, y)
    grads = torch.autograd.grad(loss, st.allp, allow_unused=True)
    return dict(zip(map(id, st.allp), grads))
+
+
if __name__ == '__main__':
    # Probe script: briefly pretrain with BPTT, then compare the ep_grads
    # estimate against a long-unroll BPTT reference via per-group cosine.
    torch.manual_seed(0)
    K, B, T, C, H = 2, 16, 64, 128, 4
    st = EQStack(K, C, H, T, gamma=1.0, c=1.0)
    opt = torch.optim.AdamW(st.allp, lr=1e-3, weight_decay=1e-4)
    for _ in range(200):  # short BPTT pretrain -> realistic operating point
        idx, y = get_batch('train', B, T)
        g = bptt_grads(st, idx, y, 120, 0.1)
        opt.zero_grad(set_to_none=True)
        for p in st.allp:
            p.grad = g.get(id(p))
        torch.nn.utils.clip_grad_norm_(st.allp, 5.0)
        opt.step()
    print(f"pretrained 200 BPTT steps (K={K} spring stack, gamma={st.gamma})", flush=True)

    # One group per block, whatever K is (was hard-coded to blk0/blk1).
    groups = {'all': st.block}
    groups.update((f'blk{k}', list(b.values())) for k, b in enumerate(st.blocks))
    groups['emb'] = [st.tok, st.pos]

    def cos(ga, gb, ps):
        """Cosine between two gradient dicts over the params in ``ps`` present in both."""
        keep = [p for p in ps if ga.get(id(p)) is not None and gb.get(id(p)) is not None]
        if not keep:
            return float('nan')
        va = torch.cat([ga[id(p)].reshape(-1) for p in keep])
        vb = torch.cat([gb[id(p)].reshape(-1) for p in keep])
        return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item()

    hdr = f"{'config':>22} {'res':>9} {'t_best':>7} " + " ".join(f"{k:>6}" for k in groups)
    for bi in range(3):
        idx, y = get_batch('train', B, T)
        ref = bptt_grads(st, idx, y, 400, 0.1)
        print(("\n" if bi else "") + hdr, flush=True)
        for T1 in (150, 400):
            gep, res, tb = ep_grads(st, idx, y, T1, 0.1, 0.02, 120)
            print(f"{f'ep T1={T1}':>22} {res:>9.1e} {tb:>7} "
                  + " ".join(f"{cos(gep, ref, ps):>6.3f}" for ps in groups.values()), flush=True)