diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/lt_ep_stack.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/lt_ep_stack.py')
| -rw-r--r-- | ep_run/lt_ep_stack.py | 165 |
1 files changed, 165 insertions, 0 deletions
# Reconstructed from a collapsed diff view. Original diff header:
# diff --git a/ep_run/lt_ep_stack.py b/ep_run/lt_ep_stack.py new file mode 100644
# index 0000000..327b143 --- /dev/null +++ b/ep_run/lt_ep_stack.py @@ -0,0 +1,165 @@
"""Spring-coupled equilibrium STACK — EP through depth with no protocol.
Inter-block coupling is a conservative spring energy sum_k gamma/2 ||z_k - z_{k-1}||^2 (z_0
sprung to the input clamp x). The cost pulls z_K; spring REACTION forces (Newton's 3rd law)
carry the tension down the chain; EP/VF + AEP correction on the non-conservative block internals
is unchanged — the stack is just one bigger force field.
Probe: per-block gradient cosine vs BPTT-through-the-joint-relaxation. The decisive number is
block-0's cosine: did the tension reach the bottom?"""
import math
import time

import torch
import torch.nn.functional as F

from lt_ep_train import get_batch, vocab

dev = 'cuda' if torch.cuda.is_available() else 'cpu'


class EQStack:
    """K transformer-style blocks whose states are coupled by springs.

    The state is one tensor ``z`` stacked over blocks on dim 0 (the original
    comment on ``nc_force`` gives its shape as (K, B, T, C)).  All parameters
    are plain leaf tensors created on ``dev`` with ``requires_grad=True``.

    Attributes:
        K, C, H, T: number of blocks, channel width, attention heads, sequence length.
        dh: per-head width, ``C // H``.
        gamma: spring stiffness between consecutive block states.
        c: coefficient of the linear restoring term ``-c * z``.
        tok, pos: token and position embedding tables.
        blocks: list of K dicts holding each block's weights.
        Wh: readout matrix applied to the top block state.
        cmask: lower-triangular (T, T) boolean causal mask.
        block: ``[tok, pos]`` followed by every per-block weight (no ``Wh``).
        allp: ``block + [Wh]``.
    """

    def __init__(self, K, C, H, T, gamma=1.0, c=1.0):
        # Fail early: battn() reshapes the channel dim into (H, C // H), which
        # only works when H divides C exactly.
        if C % H != 0:
            raise ValueError(f"C ({C}) must be divisible by H ({H})")

        def randn_param(*shape, sc):
            # Gaussian-initialised trainable leaf, scaled by `sc`.
            return (torch.randn(*shape, device=dev) * sc).requires_grad_(True)

        def const_param(n, v):
            # Constant-initialised trainable leaf of length n.
            return torch.full((n,), float(v), device=dev).requires_grad_(True)

        self.K, self.C, self.H, self.dh, self.T = K, C, H, C // H, T
        self.gamma, self.c = gamma, c
        # NOTE: creation order below fixes the RNG stream; keep it stable so
        # seeded runs stay reproducible.
        self.tok = randn_param(vocab, C, sc=0.02)
        self.pos = randn_param(T, C, sc=0.02)
        self.blocks = []
        for _ in range(K):
            self.blocks.append(dict(
                WQ=randn_param(C, C, sc=1 / math.sqrt(C)),
                WK=randn_param(C, C, sc=1 / math.sqrt(C)),
                WV=randn_param(C, C, sc=1 / math.sqrt(C)),
                WO=randn_param(C, C, sc=1 / math.sqrt(C)),
                ln1g=const_param(C, 1), ln1b=const_param(C, 0),
                ln2g=const_param(C, 1), ln2b=const_param(C, 0),
                fc=randn_param(C, 4 * C, sc=1 / math.sqrt(C)), fcb=const_param(4 * C, 0),
                pj=randn_param(4 * C, C, sc=1 / math.sqrt(4 * C)), pjb=const_param(C, 0)))
        self.Wh = randn_param(C, vocab, sc=1 / math.sqrt(C))
        self.cmask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=dev))
        self.block = [self.tok, self.pos] + [p for b in self.blocks for p in b.values()]
        self.allp = self.block + [self.Wh]

    def embed(self, idx):
        """Return token embeddings of ``idx`` plus the position table (broadcast over dim 0)."""
        return self.tok[idx] + self.pos[None]

    def battn(self, b, z):
        """Causal multi-head self-attention of block ``b`` applied to state ``z``.

        ``z`` is layer-normalised first; the result is (B, T, C) after the
        output projection ``WO``.  No residual is added here.
        """
        B, T, H, dh, C = z.size(0), self.T, self.H, self.dh, self.C
        h = F.layer_norm(z, (C,), b['ln1g'], b['ln1b'])
        q = (h @ b['WQ']).view(B, T, H, dh).transpose(1, 2)
        k = (h @ b['WK']).view(B, T, H, dh).transpose(1, 2)
        v = (h @ b['WV']).view(B, T, H, dh).transpose(1, 2)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(dh)
        # Positions above the diagonal get -inf so softmax gives them zero weight.
        a = torch.softmax(scores.masked_fill(~self.cmask, float('-inf')), -1)
        return (a @ v).transpose(1, 2).reshape(B, T, C) @ b['WO']

    def bffn(self, b, z):
        """Layer-norm -> linear (C -> 4C) -> GELU -> linear (4C -> C) of block ``b``."""
        h = F.layer_norm(z, (self.C,), b['ln2g'], b['ln2b'])
        return F.gelu(h @ b['fc'] + b['fcb']) @ b['pj'] + b['pjb']

    def nc_force(self, zc):  # non-conservative internals, state (K,B,T,C)
        """Stack ``battn + bffn`` of every block, each evaluated on its own slice ``zc[k]``."""
        return torch.stack([self.battn(b, zc[k]) + self.bffn(b, zc[k])
                            for k, b in enumerate(self.blocks)], 0)

    def force(self, zc, xin, cg=False):
        """Total force on the stacked state.

        Per block k: spring to the state below (``xin`` for block 0), the
        block's non-conservative force, the ``-c * z`` term, and the reaction
        of the spring to block k+1 (zero for the top block).

        Args:
            zc: stacked state.
            xin: input clamp that block 0 is sprung to.
            cg: when True and ``zc`` already requires grad, ``zc`` is used
                as-is so the autograd graph through the state is kept;
                otherwise the state is detached into a fresh leaf.
        """
        zr = zc if (cg and zc.requires_grad) else zc.detach().requires_grad_(True)
        below = torch.cat([xin[None], zr[:-1]], 0)
        f = -self.gamma * (zr - below) + self.nc_force(zr) - self.c * zr
        up = self.gamma * (zr[1:] - zr[:-1])  # reaction of the spring above (Newton's 3rd law)
        return f + torch.cat([up, torch.zeros_like(zr[:1])], 0)


def relax(st, z, xin, steps, eps):
    """Run ``steps`` explicit Euler updates ``z <- z + eps * force(z)`` without autograd.

    Returns the detached final state (``z`` itself, detached, when ``steps`` is 0).
    """
    for _ in range(steps):
        with torch.no_grad():
            z = z + eps * st.force(z, xin).detach()
    return z.detach()


def ce(st, z, y):
    """Mean cross-entropy of the top block's readout ``z[-1] @ Wh`` against targets ``y``."""
    return F.cross_entropy((z[-1] @ st.Wh).reshape(-1, vocab), y.reshape(-1))


def grad_ce_state(st, z, y):  # closed-form dCE/dz: only the top block feels y
    """Closed-form gradient of the mean cross-entropy w.r.t. the stacked state.

    Only the last slice is non-zero.  Dividing by ``y.numel()`` matches the
    mean reduction used by ``ce``.
    """
    p = torch.softmax(z[-1] @ st.Wh, -1)
    gK = (p - F.one_hot(y, p.size(-1)).to(z.dtype)) @ st.Wh.t() / y.numel()
    g = torch.zeros_like(z)
    g[-1] = gK
    return g


def ep_a(st, zs, xin, y, r, T2max, eps, K=10, exit_mult=5.0):
    """N=2 real phases, clamp-free, AEP corr on the stack's nc part, hindsight selection.

    Two copies of the free equilibrium ``zs`` are nudged with ``-r`` and ``+r``
    times the cost gradient; the estimate is ``(Zm - Zp) / (2 * r)``.  Every
    ``K`` steps (and at ``T2max``) the estimate is checkpointed; the checkpoint
    with the smallest change from the previous one is kept.

    Args:
        st: the EQStack.
        zs: free-phase equilibrium state (left unmodified).
        xin: input clamp.
        y: target token ids.
        r: nudge strength.
        T2max: maximum number of nudged steps.
        eps: Euler step size.
        K: checkpoint interval in steps (NOT the stack depth).
        exit_mult: stop once the checkpoint increment exceeds this multiple of
            the best increment so far (only from step ``3 * K`` on).

    Returns:
        ``(a, t_best)``: the detached estimate and the step of the selected
        checkpoint (0 when fewer than two checkpoints were taken).

    Raises:
        RuntimeError: if no finite checkpoint was ever produced (the first
            checkpoint was non-finite, or ``T2max < 1``).
    """
    Zp, Zm = zs.clone(), zs.clone()
    a_prev = a_best = None
    inc_min, t_best = float('inf'), 0
    for t in range(1, T2max + 1):
        with torch.no_grad():
            for Z, sg in ((Zp, +r), (Zm, -r)):
                f = st.force(Z, xin) - sg * grad_ce_state(st, Z, y)
                v = (Z - zs).contiguous()
                # (J - J^T) v at the equilibrium: the antisymmetric part of the
                # non-conservative Jacobian acting on the displacement.
                # jvp/vjp re-enable grad internally, so they work under no_grad.
                Jv = torch.autograd.functional.jvp(st.nc_force, zs, v)[1]
                JTv = torch.autograd.functional.vjp(st.nc_force, zs, v)[1]
                Z += eps * (f - (Jv - JTv))
        if t % K == 0 or t == T2max:
            a_t = (Zm - Zp) / (2 * r)
            if not torch.isfinite(a_t).all():
                break  # diverged: fall back to the last finite checkpoint
            if a_prev is not None:
                inc = (a_t - a_prev).norm().item()
                if inc < inc_min:
                    inc_min, a_best, t_best = inc, a_t, t
                elif inc > exit_mult * inc_min and t >= 3 * K:
                    break
            a_prev = a_t
    chosen = a_best if a_best is not None else a_prev
    if chosen is None:
        # Previously this surfaced as AttributeError on None.detach().
        raise RuntimeError(
            f"ep_a produced no finite estimate (T2max={T2max}, K={K}): "
            "the nudged phase diverged before the first checkpoint or never ran")
    return chosen.detach(), t_best


def ep_grads(st, idx, y, T1, eps, r, T2max):
    """Estimate parameter gradients for ``st.block`` without backprop through the relaxation.

    Returns:
        ``(grads, res, tb)`` where ``grads`` maps ``id(param)`` to its gradient
        (``None`` for unused params), ``res`` is the relative change of the
        state under one extra relaxation step, and ``tb`` is ``ep_a``'s
        selected step.
    """
    xin = st.embed(idx).detach()
    z0 = xin[None].repeat(st.K, 1, 1, 1)  # every block starts at the input embedding
    zs = relax(st, z0, xin, T1, eps)
    res = (relax(st, zs, xin, 1, eps) - zs).norm().item() / (zs.norm().item() + 1e-9)
    a, tb = ep_a(st, zs, xin, y, r, T2max, eps)
    with torch.enable_grad():
        # Re-embed with grad so tok/pos receive a gradient through the clamp.
        x2 = st.embed(idx)
        f = st.force(zs.detach(), x2, cg=True)
        g = torch.autograd.grad((a * f).sum(), st.block, allow_unused=True)
    return {id(p): gv for p, gv in zip(st.block, g)}, res, tb


def bptt_grads(st, idx, y, T1, eps):
    """Reference gradients: backprop the cross-entropy through ``T1`` unrolled Euler steps.

    Returns a dict mapping ``id(param)`` to its gradient for every tensor in ``st.allp``.
    """
    xin = st.embed(idx)
    # The `* 0 + xin` construction is kept verbatim from the original; it yields
    # xin's values while guaranteeing the initial state requires grad, which
    # force(..., cg=True) needs in order to keep the graph through z.
    z = (xin.detach().requires_grad_(True) * 0 + xin)[None].repeat(st.K, 1, 1, 1)
    for _ in range(T1):
        z = z + eps * st.force(z, xin, cg=True)
    g = torch.autograd.grad(ce(st, z, y), st.allp, allow_unused=True)
    return {id(p): gv for p, gv in zip(st.allp, g)}


if __name__ == '__main__':
    torch.manual_seed(0)
    K, B, T, C, H = 2, 16, 64, 128, 4
    st = EQStack(K, C, H, T, gamma=1.0, c=1.0)
    opt = torch.optim.AdamW(st.allp, lr=1e-3, weight_decay=1e-4)
    for step in range(200):  # short BPTT pretrain -> realistic operating point
        idx, y = get_batch('train', B, T)
        g = bptt_grads(st, idx, y, 120, 0.1)
        opt.zero_grad(set_to_none=True)
        for p in st.allp:
            p.grad = g.get(id(p))
        torch.nn.utils.clip_grad_norm_(st.allp, 5.0)
        opt.step()
    print(f"pretrained 200 BPTT steps (K={K} spring stack, gamma={st.gamma})", flush=True)

    # One column per parameter group.  Per-block groups are derived from
    # st.blocks so the probe works for any K (for K=2 this is exactly
    # all / blk0 / blk1 / emb, in that order).
    groups = {'all': st.block}
    groups.update({f'blk{k}': list(b.values()) for k, b in enumerate(st.blocks)})
    groups['emb'] = [st.tok, st.pos]

    def cos(ga, gb, ps):
        """Cosine between two gradient dicts over params ``ps`` that have a gradient in both."""
        keep = [p for p in ps if ga.get(id(p)) is not None and gb.get(id(p)) is not None]
        if not keep:
            return float('nan')
        va = torch.cat([ga[id(p)].reshape(-1) for p in keep])
        vb = torch.cat([gb[id(p)].reshape(-1) for p in keep])
        return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item()

    hdr = f"{'config':>22} {'res':>9} {'t_best':>7} " + " ".join(f"{k:>6}" for k in groups)
    for bi in range(3):
        idx, y = get_batch('train', B, T)
        ref = bptt_grads(st, idx, y, 400, 0.1)
        print(("\n" if bi else "") + hdr, flush=True)
        for T1 in (150, 400):
            gep, res, tb = ep_grads(st, idx, y, T1, 0.1, 0.02, 120)
            print(f"{f'ep T1={T1}':>22} {res:>9.1e} {tb:>7} " + " ".join(f"{cos(gep, ref, ps):>6.3f}" for ps in groups.values()), flush=True)
