diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /scripts/aep_projected.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'scripts/aep_projected.py')
| -rw-r--r-- | scripts/aep_projected.py | 125 |
1 files changed, 125 insertions, 0 deletions
# Reconstructed from the collapsed diff view of scripts/aep_projected.py
# (new file, blob af8891e); the leading '+' diff markers have been removed.
"""C / option 1: PROJECTED AEP — non-conservative EP on the token-norm constraint manifold.

Two fixes over the unconstrained version:
  (1) STABILITY: relax with the token-norm projection z <- Pi(z + eps F) (bounds z;
      this is what made plain CET stable). Lets large-s / deep attention stop diverging.
  (2) CORRECT GRADIENT under the constraint: the VF contraction must be projected onto the
      TANGENT space of the manifold. The tangent projector at a normalized token z is
          P_z(v) = v - mean(v) - mean(v*z) * z
      (exactly the local-transformer's LayerNormProjectedSurrogate). Without it the VF
      estimator picks up the normal force and collapses (energy-mode cosine ~0.002).

Param-gradient: dL/dtheta = <a_z, P_z*( dF_z/dtheta )> + <a_y, dF_y/dtheta>,
    a = (state_-b - state_+b)/(2 beta).
AEP correction (nudged phase, on z): -s (J v - J^T v) of RealAttn, then projected.

Implementation notes:
  * ``P_tan`` divides the z-coefficient by mean(z*z) (clamped at 1e-6); this reduces to
    the formula above when mean(z*z) == 1.
  * ``relax_nudged`` and ``bptt_grad`` read the module globals ``X`` and ``M``; ``measure``
    reads ``XBAR``. All three are assigned by ``main()`` and do not exist before it runs.
"""
import argparse
import math

import torch
import torch.nn.functional as F

from cet_mvp import token_norm, make_patch_mask, masked_cost, get_loaders
from cet_aep import CETReal

dev = 'cuda' if torch.cuda.is_available() else 'cpu'
# Parameter names whose per-tensor cosines are averaged into the "attn" column.
ATTN = ('WQ', 'WK', 'WV', 'WO')


def P_tan(z, v):
    """Project ``v`` onto the tangent space at token ``z``.

    Subtracts the mean of ``v`` over the last axis, then removes the component
    along ``z``, using the mean over the last axis as inner product. The
    denominator mean(z*z) is clamped at 1e-6 to avoid dividing by zero.
    """
    v = v - v.mean(-1, keepdim=True)
    zz = (z * z).mean(-1, keepdim=True).clamp_min(1e-6)
    return v - ((v * z).mean(-1, keepdim=True) / zz) * z


def force(model, xbar, z, y, s):
    """Return the force pair ``(-dE/dz + s * real_attn(z), -dE/dy)``.

    ``E`` is ``model.E_rest(xbar, z, y)``. The gradient is taken with
    ``create_graph=True`` so the returned forces can be differentiated again
    (``vf_grad`` and ``bptt_grad`` differentiate them w.r.t. the parameters).

    NOTE: ``requires_grad_`` is applied in place, so the tensors passed in as
    ``z`` and ``y`` are flagged as requiring grad on return.
    """
    z = z.requires_grad_(True)
    y = y.requires_grad_(True)
    gz, gy = torch.autograd.grad(model.E_rest(xbar, z, y), [z, y], create_graph=True)
    return -gz + s * model.real_attn(z), -gy


def relax_free(model, xbar, z, y, s, T1, eps):
    """Run ``T1`` free-phase steps and return the detached ``(z, y)``.

    Each step is ``z <- token_norm(z + eps * fz)``, ``y <- y + eps * fy`` with
    the forces detached, so no graph is kept across steps.
    """
    for _ in range(T1):
        with torch.enable_grad():
            fz, fy = force(model, xbar, z, y, s)
            fz, fy = fz.detach(), fy.detach()
        with torch.no_grad():
            z = token_norm(z + eps * fz)
            y = y + eps * fy
    return z.detach(), y.detach()


def relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, sign, aep):
    """Run ``T2`` nudged-phase steps starting from the free state ``(zs, ys)``.

    The y-force is nudged by ``-sign * beta * d masked_cost(y, X, M) / dy``.
    When ``aep`` is true the z-force additionally receives
    ``-s * (J v - J^T v)`` with ``v = z - zs`` and ``J`` the Jacobian of
    ``model.real_attn`` evaluated at ``zs``.

    Reads the module globals ``X`` and ``M`` (assigned in ``main``).
    Returns the detached ``(z, y)``.
    """
    z, y = zs.clone(), ys.clone()
    for _ in range(T2):
        with torch.enable_grad():
            fz, fy = force(model, xbar, z, y, s)
            fz, fy = fz.detach(), fy.detach()
            # Cost gradient w.r.t. a fresh leaf copy of y, so it is independent
            # of whatever graph y currently belongs to.
            yy = y.detach().requires_grad_(True)
            gy, = torch.autograd.grad(masked_cost(yy, X, M), yy)
            fy = fy - sign * beta * gy
            if aep:
                v = (z - zs).detach()
                Jv = torch.autograd.functional.jvp(model.real_attn, zs, v)[1]
                JTv = torch.autograd.functional.vjp(model.real_attn, zs, v)[1]
                fz = fz - s * (Jv - JTv)
        with torch.no_grad():
            z = token_norm(z + eps * fz)
            y = y + eps * fy
    return z.detach(), y.detach()


def vf_grad(model, xbar, s, T1, T2, eps, beta, aep):
    """Estimate parameter gradients from the projected vector-field contraction.

    Relaxes to the free state, runs the +beta and -beta nudged phases, forms
    the adjoints ``a = (state_-b - state_+b) / (2 beta)`` (the z-adjoint is
    tangent-projected), and differentiates ``<a_z, P(F_z)> + <a_y, F_y>`` with
    respect to the model parameters.

    Returns ``(zs, g)`` where ``zs`` is the free-phase z state and ``g`` is the
    tuple from ``torch.autograd.grad(..., allow_unused=True)``, in
    ``model.parameters()`` order; entries may be ``None``.
    """
    zs, ys = relax_free(model, xbar, *model.init_state(xbar), s, T1, eps)
    zp, yp = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, +1, aep)
    zm, ym = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, -1, aep)
    az = P_tan(zs, ((zm - zp) / (2 * beta))).detach()  # adjoint in tangent space
    ay = ((ym - yp) / (2 * beta)).detach()
    with torch.enable_grad():
        fz, fy = force(model, xbar, zs.detach(), ys.detach(), s)
        s_ = (az * P_tan(zs, fz)).sum() + (ay * fy).sum()  # projected contraction
    g = torch.autograd.grad(s_, list(model.parameters()), allow_unused=True)
    return zs, g


def bptt_grad(model, xbar, s, T1, eps):
    """Reference gradient: backpropagate through ``T1`` unrolled relaxation steps.

    The loss is ``masked_cost(y, X, M) / M.sum()`` on the final ``y``. Reads the
    module globals ``X`` and ``M`` (assigned in ``main``). Returns the tuple
    from ``torch.autograd.grad(..., allow_unused=True)``, in
    ``model.parameters()`` order; entries may be ``None``.
    """
    z, y = model.init_state(xbar)
    z, y = z.requires_grad_(True), y.requires_grad_(True)
    for _ in range(T1):
        fz, fy = force(model, xbar, z, y, s)
        z = token_norm(z + eps * fz)
        y = y + eps * fy
    return torch.autograd.grad(masked_cost(y, X, M) / M.sum(),
                               list(model.parameters()), allow_unused=True)


def cosines(g, gb, names):
    """Compare two per-parameter gradient lists by cosine similarity.

    Args:
        g, gb: gradient sequences in the same parameter order; entries may be
            ``None`` (both producers call ``autograd.grad`` with
            ``allow_unused=True``).
        names: parameter names aligned with ``g`` / ``gb``.

    Returns:
        ``(attn_mean, global_cos)``: the mean per-tensor cosine over parameters
        whose name is in ``ATTN``, and the cosine of the concatenated
        gradients. Either value is NaN when there is nothing to compare.
    """
    def c(a, b):
        return F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()

    at = [c(a, b) for n, a, b in zip(names, g, gb)
          if n in ATTN and a is not None and b is not None]
    attn_mean = sum(at) / len(at) if at else float('nan')

    # Use the SAME filter for both sides. Previously A kept every non-None
    # entry of g while B also required gb to be non-None, so a parameter unused
    # by only one estimator produced vectors of different length/alignment.
    both = [(a, b) for a, b in zip(g, gb) if a is not None and b is not None]
    if not both:
        return attn_mean, float('nan')
    A = torch.cat([a.flatten() for a, _ in both])
    B = torch.cat([b.flatten() for _, b in both])
    return attn_mean, c(A, B)


def measure(model, names, s, T1, T2, eps, beta):
    """Compare the naive and AEP vector-field gradients against BPTT on ``XBAR``.

    Returns ``(an, aa, gng, gag, fin)``: attention-mean cosine for naive / AEP,
    global cosine for naive / AEP, and whether the free-phase z state of the
    AEP run is entirely finite.
    """
    gb = bptt_grad(model, XBAR, s, T1, eps)
    _, gn = vf_grad(model, XBAR, s, T1, T2, eps, beta, False)
    zs, ga = vf_grad(model, XBAR, s, T1, T2, eps, beta, True)
    an, gng = cosines(gn, gb, names)
    aa, gag = cosines(ga, gb, names)
    fin = torch.isfinite(zs).all().item()
    return an, aa, gng, gag, fin


def main():
    """Build the model and one masked batch, then print the cosine table."""
    # X, M, XBAR are module globals consumed by relax_nudged / bptt_grad / measure.
    global X, M, XBAR
    torch.manual_seed(0)
    model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev)
    names = [n for n, _ in model.named_parameters()]
    trl, _ = get_loaders(32, dataset='fashionmnist')
    X, _ = next(iter(trl))
    X = X.to(dev)
    M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev)
    XBAR = X * (1 - M)

    print("SANITY s=0 (pure conservative): projected-VF global cosine should be ~1")
    _, _, gnaive, _, _ = measure(model, names, 0.0, 120, 20, 0.2, 0.02)
    print(f" s=0 global cosine = {gnaive:.4f}\n")

    print("PROJECTED AEP across attention scale s (T1=120 T2=30 beta=0.02)")
    print(f"{'s':>6} | {'naive(attn)':>11} {'AEP(attn)':>10} | {'finite?':>7} (unproj. broke at s>=4)")
    for s in [0.5, 1.0, 2.0, 4.0, 8.0, 16.0]:
        an, aa, _, _, fin = measure(model, names, s, 120, 30, 0.2, 0.02)
        print(f"{s:6.2f} | {an:>11.3f} {aa:>10.3f} | {str(bool(fin)):>7}")


if __name__ == '__main__':
    main()
