summaryrefslogtreecommitdiff
path: root/scripts/aep_contractive.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
committerYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
commitb83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
treeb9cc01d7adda691d9156d9d04f4fb2f644674e96 /scripts/aep_contractive.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'scripts/aep_contractive.py')
-rw-r--r--scripts/aep_contractive.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/scripts/aep_contractive.py b/scripts/aep_contractive.py
new file mode 100644
index 0000000..670feae
--- /dev/null
+++ b/scripts/aep_contractive.py
@@ -0,0 +1,52 @@
+"""F: make REAL attention EP-able by damping it into a contraction (keep it non-conservative).
+
+Attention term in the force becomes s*(attn(z) - c*z). The -c*z is damping that grows with s,
+pushing Re(eig(J_F)) < 0 (a stable fixed point) WITHOUT symmetrizing the Jacobian (the antisymmetric
+part is unchanged, so it stays non-conservative -> AEP still needed AND now applicable).
+
+We sweep (s, c) and report, using the validated projected-adjoint (option 1):
+ fwd resid : does a stable fixed point exist? (small = yes)
+ adj cos : projected-adjoint gradient fidelity vs BPTT on attention params
+ glob : second value returned by O.cosines (presumably the same cosine taken over all parameters)
+Expected: c=0 breaks at high s (no fixed point, as before); c>=1 keeps resid small + fidelity high.
+"""
+import torch, aep_option1 as O
+from cet_aep import CETReal
+from cet_mvp import token_norm, make_patch_mask, get_loaders
+
+# Run on the GPU when torch can see one, otherwise on the CPU.
+if torch.cuda.is_available():
+    dev = 'cuda'
+else:
+    dev = 'cpu'
+
+# Seed torch's RNG before the model is built.
+torch.manual_seed(0)
+model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128)
+model = model.to(dev)
+names = [pname for pname, _ in model.named_parameters()]
+orig_attn = model.real_attn  # handle on the original (undamped) attention
+
+# A single batch is drawn once and reused for every (s, c) setting; the second
+# element of the batch is discarded.
+trl, _ = get_loaders(32, dataset='fashionmnist')
+X, _ = next(iter(trl))
+X = X.to(dev)
+M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev)
+XBAR = X * (1 - M)
+# masked_cost in aep_option1 reads X and M as module globals, so publish them there.
+O.X = X
+O.M = M
+
+
+def set_damp(c):
+    """Select the attention function installed on ``model``.
+
+    With ``c == 0`` the original ``orig_attn`` is put back on
+    ``model.real_attn``. For any other ``c`` it is replaced by a wrapper that
+    returns ``orig_attn(z) - c * z``.
+    """
+    if c == 0:
+        model.real_attn = orig_attn
+        return
+
+    def damped_attn(z):
+        return orig_attn(z) - c * z
+
+    model.real_attn = damped_attn
+
+
+def resid(s, T1, eps=0.2):
+    """Return the relative change of ``z`` under one extra update after relaxation.
+
+    The state is first relaxed with ``O.relax_free``. The force is then
+    evaluated once at the relaxed state and the function returns
+    ``||token_norm(z + eps * f_z) - z|| / (||z|| + 1e-9)`` as a Python float.
+    A small value means the relaxed state is close to a fixed point of the
+    update.
+
+    Args:
+        s: scale forwarded to ``O.relax_free`` and ``O.force``.
+        T1: forwarded to ``O.relax_free`` (presumably the number of relaxation
+            steps -- confirm against aep_option1).
+        eps: step size, forwarded to ``O.relax_free`` and used for the extra
+            update measured here.
+
+    Returns:
+        The relative residual as a float.
+    """
+    zs, ys = O.relax_free(model, XBAR, *model.init_state(XBAR), s, T1, eps)
+    # Detach so the relaxed state is never flagged as requiring grad in place;
+    # the force below is evaluated on separate grad-enabled leaves instead.
+    zs, ys = zs.detach(), ys.detach()
+    with torch.enable_grad():
+        zr = zs.detach().requires_grad_(True)
+        yr = ys.detach().requires_grad_(True)
+        fz, _ = O.force(model, XBAR, zr, yr, s)
+    # The residual is only reported as a number, so no autograd graph is needed.
+    with torch.no_grad():
+        zn = token_norm(zs + eps * fz.detach())
+        return ((zn - zs).norm() / (zs.norm() + 1e-9)).item()
+
+
+# Report table: one row per (s, c) pair. The header's column widths
+# (5, 4, 9, 13, 7) match the widths in the per-row format string below.
+print("Contractive (damped) non-conservative attention — does it restore a fixed point + EP fidelity?")
+print(f"{'s':>5} {'c':>4} | {'fwd resid':>9} {'adj cos(attn)':>13} {'glob':>7}")
+# NOTE(review): indentation was flattened in this rendering; the statements
+# from set_damp(c) through the row print presumably form the body of the
+# inner `for c` loop -- confirm against the real file.
+for s in [1.0, 2.0, 4.0, 8.0]:
+ for c in [0.0, 1.0, 2.0]:
+ # Install the attention variant for this damping coefficient (c == 0 restores the original).
+ set_damp(c)
+ # Forward residual from resid() with T1=250 and its default eps=0.2.
+ r = resid(s, 250)
+ # Reference gradient (the module docstring compares against BPTT), called with 250 and 0.2.
+ gb = O.bptt_grad(model, XBAR, s, 250, 0.2)
+ # Adjoint gradient called with the same 250 / 0.2; the meaning of the
+ # trailing 250 is not visible here -- TODO confirm in aep_option1.
+ ga = O.adjoint_grad(model, XBAR, s, 250, 0.2, 250)
+ # Two values comparing ga with gb; they are printed under 'adj cos(attn)' and 'glob'.
+ a, g = O.cosines(ga, gb, names)
+ print(f"{s:>5.1f} {c:>4.1f} | {r:>9.2e} {a:>13.3f} {g:>7.3f}")
+ # NOTE(review): presumably this blank line belongs to the outer loop (one
+ # per value of s); the flattened indentation hides its level -- confirm.
+ print()