diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /scripts/aep_option1.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'scripts/aep_option1.py')
| -rw-r--r-- | scripts/aep_option1.py | 115 |
1 files changed, 115 insertions, 0 deletions
diff --git a/scripts/aep_option1.py b/scripts/aep_option1.py new file mode 100644 index 0000000..65583d4 --- /dev/null +++ b/scripts/aep_option1.py @@ -0,0 +1,115 @@ +"""option 1: CORRECT gradient for non-conservative attention UNDER the token-norm constraint. + +Implicit differentiation of the projected fixed-point map G(x) = Pi(x + eps F(x)): + adjoint a <- J_G^T a + g , J_G^T = (I + eps J_F^T) Pi'^T , g = dC/dx* + gradient dL/dtheta = eps * < Pi'^T a , dF/dtheta(x*) > + +Built from LOCAL pieces only (this is the projected analogue of EP's nudged adjoint): + Pi'^T : vjp(token_norm, u, .) (the LN/projection Jacobian = LayerNormProjectedSurrogate) + J_F^T : -Hess(E_rest).b (symmetric, via HVP) + s * vjp(real_attn, z*, .) (the non-conservative bit) +Validation: cosine vs BPTT-through-the-projected-relaxation (ground truth). C lost fidelity; this should recover it. +""" +import torch, torch.nn.functional as F, math +from cet_mvp import token_norm, make_patch_mask, masked_cost, get_loaders +from cet_aep import CETReal + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +ATTN = ('WQ', 'WK', 'WV', 'WO') + + +def force(model, xbar, z, y, s, cg=False): + gz, gy = torch.autograd.grad(model.E_rest(xbar, z, y), [z, y], create_graph=cg) + return -gz + s * model.real_attn(z), -gy + + +def relax_free(model, xbar, z, y, s, T1, eps): + for _ in range(T1): + with torch.enable_grad(): + zr, yr = z.requires_grad_(True), y.requires_grad_(True) + fz, fy = force(model, xbar, zr, yr, s) + fz, fy = fz.detach(), fy.detach() + with torch.no_grad(): + z, y = token_norm(z + eps * fz), y + eps * fy + return z.detach(), y.detach() + + +def adjoint_grad(model, xbar, s, T1, eps, Tadj): + zs, ys = relax_free(model, xbar, *model.init_state(xbar), s, T1, eps) + # pre-projection point u for Pi' ; cost grad g=(0, dC/dy) + zr, yr = zs.detach().requires_grad_(True), ys.detach().requires_grad_(True) + fz, fy = force(model, xbar, zr, yr, s) + uz = (zs + eps * fz).detach() + yc = ys.detach().requires_grad_(True) + gy_c, = torch.autograd.grad(masked_cost(yc, X, M) / M.sum(), yc) + gy_c = gy_c.detach() + + az, ay = torch.zeros_like(zs), gy_c.clone() # init adjoint at g (cost grad) + for _ in range(Tadj): + bz = torch.autograd.functional.vjp(token_norm, uz, az)[1] # Pi'^T a (z); y identity + by = ay + # J_F^T b = -Hess(E_rest).b + s * vjp(real_attn, zs, bz) + zr2, yr2 = zs.detach().requires_grad_(True), ys.detach().requires_grad_(True) + gz2, gy2 = torch.autograd.grad(model.E_rest(xbar, zr2, yr2), [zr2, yr2], create_graph=True) + hz, hy = torch.autograd.grad((gz2 * bz).sum() + (gy2 * by).sum(), [zr2, yr2]) + av = torch.autograd.functional.vjp(model.real_attn, zs, bz)[1] + JFt_z, JFt_y = -hz + s * av, -hy + az = (bz + eps * JFt_z + torch.zeros_like(zs)).detach() + ay = (by + eps * JFt_y + gy_c).detach() + + # gradient: eps * d/dtheta < Pi'^T a , F(x*, theta) > + bz = torch.autograd.functional.vjp(token_norm, uz, az)[1].detach() + by = ay.detach() + zr3, yr3 = zs.detach().requires_grad_(True), ys.detach().requires_grad_(True) + gz3, gy3 = torch.autograd.grad(model.E_rest(xbar, zr3, yr3), [zr3, yr3], create_graph=True) + Fz = -gz3 + s * model.real_attn(zr3) + Fy = -gy3 + contr = eps * ((bz * Fz).sum() + (by * Fy).sum()) + return torch.autograd.grad(contr, list(model.parameters()), allow_unused=True) + + +def bptt_grad(model, xbar, s, T1, eps): + z, y = model.init_state(xbar); z, y = z.requires_grad_(True), y.requires_grad_(True) + for _ in range(T1): + fz, fy = force(model, xbar, z, y, s, cg=True) + z, y = token_norm(z + eps * fz), y + eps * fy + return torch.autograd.grad(masked_cost(y, X, M) / M.sum(), + list(model.parameters()), allow_unused=True) + + +def cosines(g, gb, names): + c = lambda a, b: F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() + at = [c(a, b) for n, a, b in zip(names, g, gb) if n in ATTN and a is not None and b is not None] + A = torch.cat([x.flatten() for x in g if x is not None]) + B = torch.cat([y.flatten() for x, y in zip(g, gb) if x is not None and y is not None]) + return (sum(at) / len(at) if at else float('nan')), c(A, B) + + +def main(): + torch.manual_seed(0) + model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev) + names = [n for n, _ in model.named_parameters()] + trl, _ = get_loaders(32, dataset='fashionmnist') + global X, M, XBAR + X, _ = next(iter(trl)); X = X.to(dev) + M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev) + XBAR = X * (1 - M) + def resid(s, T1, eps=0.2): + zs, ys = relax_free(model, XBAR, *model.init_state(XBAR), s, T1, eps) + with torch.enable_grad(): + zr, yr = zs.requires_grad_(True), ys.requires_grad_(True) + fz, fy = force(model, XBAR, zr, yr, s) + zn = token_norm(zs + eps * fz.detach()) + return ((zn - zs).norm() / (zs.norm() + 1e-9)).item() + + print("PROJECTED-ADJOINT (option 1) vs BPTT — is the s>=2 break convergence or no-fixed-point?") + print(f"{'s':>5} {'T1=Tadj':>8} | {'attn cos':>9} {'glob cos':>9} | {'fwd resid':>9}") + for s in [0.5, 1.0, 2.0]: + for it in [120, 400]: + gb = bptt_grad(model, XBAR, s, it, 0.2) + ga = adjoint_grad(model, XBAR, s, it, 0.2, it) + a, g = cosines(ga, gb, names) + print(f"{s:5.1f} {it:>8} | {a:>9.3f} {g:>9.3f} | {resid(s, it):>9.2e}") + + +if __name__ == '__main__': + main() |
