"""F: make REAL attention EP-able by damping it into a contraction (keep it non-conservative). Attention term in the force becomes s*(attn(z) - c*z). The -c*z is damping that grows with s, pushing Re(eig(J_F)) < 0 (a stable fixed point) WITHOUT symmetrizing the Jacobian (the antisymmetric part is unchanged, so it stays non-conservative -> AEP still needed AND now applicable). We sweep (s, c) and report, using the validated projected-adjoint (option 1): fwd resid : does a stable fixed point exist? (small = yes) adj cos : projected-adjoint gradient fidelity vs BPTT on attention params Expected: c=0 breaks at high s (no fixed point, as before); c>=1 keeps resid small + fidelity high. """ import torch, aep_option1 as O from cet_aep import CETReal from cet_mvp import token_norm, make_patch_mask, get_loaders dev = 'cuda' if torch.cuda.is_available() else 'cpu' torch.manual_seed(0) model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev) names = [n for n, _ in model.named_parameters()] orig_attn = model.real_attn # original (undamped) attention trl, _ = get_loaders(32, dataset='fashionmnist') X, _ = next(iter(trl)); X = X.to(dev) M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev) XBAR = X * (1 - M) O.X, O.M = X, M # masked_cost in option1 uses these globals def set_damp(c): model.real_attn = orig_attn if c == 0 else (lambda z: orig_attn(z) - c * z) def resid(s, T1, eps=0.2): zs, ys = O.relax_free(model, XBAR, *model.init_state(XBAR), s, T1, eps) with torch.enable_grad(): zr, yr = zs.requires_grad_(True), ys.requires_grad_(True) fz, _ = O.force(model, XBAR, zr, yr, s) zn = token_norm(zs + eps * fz.detach()) return ((zn - zs).norm() / (zs.norm() + 1e-9)).item() print("Contractive (damped) non-conservative attention — does it restore a fixed point + EP fidelity?") print(f"{'s':>5} {'c':>4} | {'fwd resid':>9} {'adj cos(attn)':>13} {'glob':>7}") for s in [1.0, 2.0, 4.0, 8.0]: for c in [0.0, 1.0, 2.0]: set_damp(c) r = resid(s, 250) gb = O.bptt_grad(model, XBAR, s, 250, 0.2) ga = O.adjoint_grad(model, XBAR, s, 250, 0.2, 250) a, g = O.cosines(ga, gb, names) print(f"{s:>5.1f} {c:>4.1f} | {r:>9.2e} {a:>13.3f} {g:>7.3f}") print()