diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /scripts/aep_contractive.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'scripts/aep_contractive.py')
| -rw-r--r-- | scripts/aep_contractive.py | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/scripts/aep_contractive.py b/scripts/aep_contractive.py new file mode 100644 index 0000000..670feae --- /dev/null +++ b/scripts/aep_contractive.py @@ -0,0 +1,52 @@ +"""F: make REAL attention EP-able by damping it into a contraction (keep it non-conservative). + +Attention term in the force becomes s*(attn(z) - c*z). The -c*z is damping that grows with s, +pushing Re(eig(J_F)) < 0 (a stable fixed point) WITHOUT symmetrizing the Jacobian (the antisymmetric +part is unchanged, so it stays non-conservative -> AEP still needed AND now applicable). + +We sweep (s, c) and report, using the validated projected-adjoint (option 1): + fwd resid : does a stable fixed point exist? (small = yes) + adj cos : projected-adjoint gradient fidelity vs BPTT on attention params +Expected: c=0 breaks at high s (no fixed point, as before); c>=1 keeps resid small + fidelity high. +""" +import torch, aep_option1 as O +from cet_aep import CETReal +from cet_mvp import token_norm, make_patch_mask, get_loaders + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +torch.manual_seed(0) +model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev) +names = [n for n, _ in model.named_parameters()] +orig_attn = model.real_attn # original (undamped) attention + +trl, _ = get_loaders(32, dataset='fashionmnist') +X, _ = next(iter(trl)); X = X.to(dev) +M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev) +XBAR = X * (1 - M) +O.X, O.M = X, M # masked_cost in option1 uses these globals + + +def set_damp(c): + model.real_attn = orig_attn if c == 0 else (lambda z: orig_attn(z) - c * z) + + +def resid(s, T1, eps=0.2): + zs, ys = O.relax_free(model, XBAR, *model.init_state(XBAR), s, T1, eps) + with torch.enable_grad(): + zr, yr = zs.requires_grad_(True), ys.requires_grad_(True) + fz, _ = O.force(model, XBAR, zr, yr, s) + zn = token_norm(zs + eps * fz.detach()) + return ((zn - zs).norm() / (zs.norm() + 1e-9)).item() + + +print("Contractive (damped) non-conservative attention — does it restore a fixed point + EP fidelity?") +print(f"{'s':>5} {'c':>4} | {'fwd resid':>9} {'adj cos(attn)':>13} {'glob':>7}") +for s in [1.0, 2.0, 4.0, 8.0]: + for c in [0.0, 1.0, 2.0]: + set_damp(c) + r = resid(s, 250) + gb = O.bptt_grad(model, XBAR, s, 250, 0.2) + ga = O.adjoint_grad(model, XBAR, s, 250, 0.2, 250) + a, g = O.cosines(ga, gb, names) + print(f"{s:>5.1f} {c:>4.1f} | {r:>9.2e} {a:>13.3f} {g:>7.3f}") + print() |
