diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /scripts/aep_characterize.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'scripts/aep_characterize.py')
| -rw-r--r-- | scripts/aep_characterize.py | 157 |
1 files changed, 157 insertions, 0 deletions
"""
Characterize AEP (non-conservative EP) on CET's attention, before porting to the LM.

Controlled knob: attention scale s in force_z = -dE_rest/dz + s * RealAttn(z).
  s=0  -> pure conservative reconstruction (A_J=0; EP exact)
  s up -> attention dominates the force -> more non-conservative -> naive EP biased.
Metric: cosine(EP-grad, BPTT-grad) on the attention params {WQ,WK,WV,WO} (the global
cosine is diluted by the dominant conservative params, so we look at attention itself).
The AEP correction added to the z-force during the nudged phase is
  -s * (J v - J^T v) = -2 * s * (A_J v),   A_J = 0.5 * (J - J^T),
where J is the Jacobian of RealAttn at the free equilibrium and v = z - z_free.

Sweeps: (1) s [non-conservativeness], (2) beta [nudge size], (3) T2 [nudge steps],
        (4) T1 [free-phase convergence]. Plus: free-eq identical naive vs AEP, and cost.

NOTE: `relax_nudged`, `bptt_grad` and `measure` read the module-level globals
X, M and XBAR, which are only assigned inside `main()`.
"""
import argparse
import math
import time

import torch
import torch.nn.functional as F

from cet_mvp import make_patch_mask, masked_cost, get_loaders
from cet_aep import CETReal

dev = 'cuda' if torch.cuda.is_available() else 'cpu'
# Parameter names treated as "the attention parameters" by attn_cos.
ATTN = ('WQ', 'WK', 'WV', 'WO')


def force(model, xbar, z, y, s):
    """Return the force pair (f_z, f_y) at state (z, y).

    f_z = -dE_rest/dz + s * model.real_attn(z) and f_y = -dE_rest/dy.

    The energy gradient is taken with ``create_graph=True`` so the returned
    forces stay differentiable (vf_grad and bptt_grad differentiate through
    them). Side effect: ``requires_grad_(True)`` is applied in place to the
    ``z`` and ``y`` tensors that are passed in.
    """
    z = z.requires_grad_(True)
    y = y.requires_grad_(True)
    gz, gy = torch.autograd.grad(model.E_rest(xbar, z, y), [z, y], create_graph=True)
    return -gz + s * model.real_attn(z), -gy


def relax_free(model, xbar, z, y, s, T1, eps):
    """Run T1 explicit Euler steps of size eps along the un-nudged force.

    Returns the final (z, y), detached from any autograd graph.
    """
    for _ in range(T1):
        with torch.enable_grad():
            fz, fy = force(model, xbar, z, y, s)
        fz, fy = fz.detach(), fy.detach()
        with torch.no_grad():
            z, y = z + eps * fz, y + eps * fy
    return z.detach(), y.detach()


def relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, sign, aep):
    """Run T2 Euler steps of the nudged dynamics starting from (zs, ys).

    The y-force gets the extra term ``-sign * beta * d masked_cost(y, X, M)/dy``
    (X and M are module globals set in main). When ``aep`` is true, the z-force
    additionally gets ``-s * (J v - J^T v)`` with ``v = z - zs`` and J the
    Jacobian of ``model.real_attn`` evaluated at ``zs``.

    Returns the final (z, y), detached.
    """
    z, y = zs.clone(), ys.clone()
    for _ in range(T2):
        with torch.enable_grad():
            fz, fy = force(model, xbar, z, y, s)
            fz, fy = fz.detach(), fy.detach()
            # Gradient of the cost w.r.t. a fresh leaf copy of y (the nudge direction).
            yy = y.detach().requires_grad_(True)
            gy, = torch.autograd.grad(masked_cost(yy, X, M), yy)
        fy = fy - sign * beta * gy
        if aep:
            v = (z - zs).detach()
            Jv = torch.autograd.functional.jvp(model.real_attn, zs, v)[1]
            JTv = torch.autograd.functional.vjp(model.real_attn, zs, v)[1]
            fz = fz - s * (Jv - JTv)  # -2 * s * 0.5 (J v - J^T v)
        with torch.no_grad():
            z, y = z + eps * fz, y + eps * fy
    return z.detach(), y.detach()


def vf_grad(model, xbar, s, T1, T2, eps, beta, aep):
    """Estimate parameter gradients from a free phase plus +/- beta nudged phases.

    Returns ``(zs, g)``: the free-phase z state and a tuple of gradients aligned
    with ``list(model.parameters())`` (entries may be None, ``allow_unused=True``).
    """
    zs, ys = relax_free(model, xbar, *model.init_state(xbar), s, T1, eps)
    zp, yp = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, +1, aep)
    zm, ym = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, -1, aep)
    # Symmetric (two-sided) finite difference of the nudged states.
    az, ay = ((zm - zp) / (2 * beta)).detach(), ((ym - yp) / (2 * beta)).detach()
    with torch.enable_grad():
        fz, fy = force(model, xbar, zs.detach(), ys.detach(), s)
        # Contract the state difference with the force and differentiate w.r.t. params.
        g = torch.autograd.grad((az * fz).sum() + (ay * fy).sum(),
                                list(model.parameters()), allow_unused=True)
    return zs, g


def bptt_grad(model, xbar, s, T1, eps):
    """Reference gradient: backprop through T1 unrolled free-phase Euler steps.

    Differentiates ``masked_cost(y, X, M) / M.sum()`` (X, M are module globals)
    w.r.t. ``list(model.parameters())``; entries may be None.
    """
    z, y = model.init_state(xbar)
    z, y = z.requires_grad_(True), y.requires_grad_(True)
    for _ in range(T1):
        fz, fy = force(model, xbar, z, y, s)
        z, y = z + eps * fz, y + eps * fy
    return torch.autograd.grad(masked_cost(y, X, M) / M.sum(),
                               list(model.parameters()), allow_unused=True)


def attn_cos(g, gb, names):
    """Mean per-parameter cosine between g and gb over parameters named in ATTN.

    ``names``, ``g`` and ``gb`` are zipped together, so they must share the same
    ordering. Pairs with a missing (None) gradient are skipped. Returns NaN when
    no parameter qualifies instead of raising ZeroDivisionError.
    """
    cs = [F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
          for n, a, b in zip(names, g, gb)
          if n in ATTN and a is not None and b is not None]
    if not cs:
        return float('nan')
    return sum(cs) / len(cs)


def global_cos(g, gb):
    """Cosine between the concatenated gradients g and gb.

    Only positions where BOTH gradients are present are used, so the two
    flattened vectors always have equal length and stay aligned. (The previous
    version built the second vector from g as well, i.e. it compared g with
    itself.) Returns NaN when there is no common parameter.
    """
    pairs = [(a, b) for a, b in zip(g, gb) if a is not None and b is not None]
    if not pairs:
        return float('nan')
    a = torch.cat([p.flatten() for p, _ in pairs])
    b = torch.cat([q.flatten() for _, q in pairs])
    return F.cosine_similarity(a, b, dim=0).item()


def measure(model, names, s, T1, T2, eps, beta):
    """Compare naive and AEP gradient estimates against BPTT on the global XBAR.

    Returns a dict with keys ``naive``/``aep`` (attention-only cosine),
    ``gnaive``/``gaep`` (global cosine) and ``eq_id`` (relative difference of
    the free-phase z between the two vf_grad runs).
    """
    gb = bptt_grad(model, XBAR, s, T1, eps)
    zsn, gn = vf_grad(model, XBAR, s, T1, T2, eps, beta, aep=False)
    zsa, ga = vf_grad(model, XBAR, s, T1, T2, eps, beta, aep=True)
    eq_id = (zsn - zsa).norm().item() / (zsn.norm().item() + 1e-9)  # free eq identical?
    return dict(naive=attn_cos(gn, gb, names), aep=attn_cos(ga, gb, names),
                gnaive=global_cos(gn, gb), gaep=global_cos(ga, gb), eq_id=eq_id)


def _sync():
    """Wait for queued CUDA work so wall-clock timings cover the whole computation."""
    if dev == 'cuda':
        torch.cuda.synchronize()


def _time_vf_grad(model, aep, reps=3):
    """Average seconds per vf_grad call at the fixed cost-benchmark settings."""
    _sync()  # do not charge previously queued GPU work to this measurement
    t0 = time.perf_counter()
    for _ in range(reps):
        vf_grad(model, XBAR, 2.0, 120, 20, 0.2, 0.02, aep=aep)
    _sync()
    return (time.perf_counter() - t0) / reps


def main():
    """Parse CLI options, build the model and one masked batch, run all sweeps."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--dataset', default='fashionmnist')
    ap.add_argument('--img', type=int, default=28)
    ap.add_argument('--ch', type=int, default=1)
    ap.add_argument('--patch', type=int, default=7)
    ap.add_argument('--stride', type=int, default=7)
    ap.add_argument('--batch', type=int, default=32)
    cfg = ap.parse_args()
    torch.manual_seed(0)
    model = CETReal(cfg.img, cfg.ch, cfg.patch, cfg.stride, D=64, heads=4, dh=16, mem=128).to(dev)
    # Same order as model.parameters(), which is what vf_grad / bptt_grad return.
    names = [n for n, _ in model.named_parameters()]
    trl, _ = get_loaders(cfg.batch, dataset=cfg.dataset)
    # A single fixed batch is shared with the helper functions through module globals.
    global X, M, XBAR
    X, _ = next(iter(trl))
    X = X.to(dev)
    M = make_patch_mask(X.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev)
    XBAR = X * (1 - M)

    # intrinsic non-conservativeness of the attention map itself
    zs, _ = relax_free(model, XBAR, *model.init_state(XBAR), 1.0, 120, 0.2)
    v = torch.randn_like(zs)
    Jv = torch.autograd.functional.jvp(model.real_attn, zs, v)[1]
    JTv = torch.autograd.functional.vjp(model.real_attn, zs, v)[1]
    print(f"intrinsic attention-map antisymmetry ||A_J v||/||J v|| = "
          f"{(0.5*(Jv-JTv)).norm().item()/(Jv.norm().item()+1e-9):.3f}")

    base = dict(T1=120, T2=20, eps=0.2, beta=0.02)
    print("\n[1] ATTENTION SCALE s (s=0 conservative -> larger = more non-conservative)")
    print(f"{'s':>6} | {'naive(attn)':>11} {'AEP(attn)':>10} | {'naive(glob)':>11} {'AEP(glob)':>10} | free-eq id")
    for s in [0.25, 0.5, 1.0, 2.0, 4.0, 8.0]:
        r = measure(model, names, s, base['T1'], base['T2'], base['eps'], base['beta'])
        print(f"{s:6.2f} | {r['naive']:>11.3f} {r['aep']:>10.3f} | {r['gnaive']:>11.4f} {r['gaep']:>10.4f} | {r['eq_id']:.1e}")

    print("\n[2] NUDGE STRENGTH beta (s=2, T2=20)")
    print(f"{'beta':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}")
    for beta in [0.005, 0.01, 0.02, 0.05, 0.1, 0.2]:
        r = measure(model, names, 2.0, 120, 20, 0.2, beta)
        print(f"{beta:6.3f} | {r['naive']:>11.3f} {r['aep']:>10.3f}")

    print("\n[3] NUDGE STEPS T2 (s=2, beta=0.02)")
    print(f"{'T2':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}")
    for T2 in [3, 5, 10, 20, 40]:
        r = measure(model, names, 2.0, 120, T2, 0.2, 0.02)
        print(f"{T2:6d} | {r['naive']:>11.3f} {r['aep']:>10.3f}")

    print("\n[4] FREE-PHASE STEPS T1 (s=2; AEP uses A_J at the free eq)")
    print(f"{'T1':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}")
    for T1 in [20, 40, 80, 120, 200]:
        r = measure(model, names, 2.0, T1, 20, 0.2, 0.02)
        print(f"{T1:6d} | {r['naive']:>11.3f} {r['aep']:>10.3f}")

    print("\n[5] COST (s=2, T1=120, T2=20)")
    tn = _time_vf_grad(model, aep=False)
    ta = _time_vf_grad(model, aep=True)
    print(f" naive {tn*1000:.0f} ms/grad AEP {ta*1000:.0f} ms/grad overhead {ta/tn:.2f}x")


if __name__ == '__main__':
    main()
