1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
|
"""F: make REAL attention EP-able by damping it into a contraction (keep it non-conservative).
Attention term in the force becomes s*(attn(z) - c*z). The -c*z is damping that grows with s,
pushing Re(eig(J_F)) < 0 (a stable fixed point) WITHOUT symmetrizing the Jacobian (the antisymmetric
part is unchanged, so it stays non-conservative -> AEP still needed AND now applicable).
We sweep (s, c) and report, using the validated projected-adjoint (option 1):
  fwd resid     : does a stable fixed point exist? (small = yes)
  adj cos(attn) : projected-adjoint gradient fidelity vs BPTT on attention params
  glob          : second value returned by O.cosines (presumably the same cosine over all parameters)
Expected: c=0 breaks at high s (no fixed point, as before); c>=1 keeps resid small + fidelity high.
"""
import torch, aep_option1 as O
from cet_aep import CETReal
from cet_mvp import token_norm, make_patch_mask, get_loaders
# Run on the GPU when torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    dev = 'cuda'
else:
    dev = 'cpu'

# Seed first: everything below (model construction, loader, mask) runs after
# this call, presumably so that the sweep is reproducible.
torch.manual_seed(0)
model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev)

# Parameter names in named_parameters() order; handed to O.cosines by the sweep.
names = [param_name for param_name, _param in model.named_parameters()]

# Handle on the original (undamped) attention, so set_damp can wrap or restore it.
orig_attn = model.real_attn

# A single batch of 32 from the first loader returned by get_loaders is used
# for the whole experiment.
trl, _ = get_loaders(32, dataset='fashionmnist')
X, _ = next(iter(trl))
X = X.to(dev)

# NOTE(review): 0.5 looks like the masked fraction and M a 0/1 patch mask --
# confirm against cet_mvp.make_patch_mask.
M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev)
XBAR = X * (1 - M)

# masked_cost in option1 uses these globals
O.X = X
O.M = M
def set_damp(c):
    """Select which attention callable ``model.real_attn`` points at.

    With ``c == 0`` the original attention (``orig_attn``) is put back
    unchanged. For any other ``c`` it is replaced by a wrapper computing
    ``orig_attn(z) - c * z``, i.e. the original output minus a linear damping
    term in the state ``z``.

    Mutates the module-level ``model`` in place and returns ``None``.
    """
    if c == 0:
        model.real_attn = orig_attn
        return

    def damped_attn(z):
        # Always wrap the saved original, never the currently installed
        # callable, so repeated calls do not stack damping terms.
        return orig_attn(z) - c * z

    model.real_attn = damped_attn
def resid(s, T1, eps=0.2):
    """Return the relative change produced by one more update at the relaxed state.

    The state ``(zs, ys)`` comes from ``O.relax_free`` started at
    ``model.init_state(XBAR)`` and called with ``s``, ``T1`` and ``eps``
    (presumably ``T1`` relaxation steps of size ``eps`` -- confirm against
    aep_option1). One further update ``zn = token_norm(zs + eps * fz)`` is then
    formed from the force ``fz`` at that state, and the function returns

        ||zn - zs|| / (||zs|| + 1e-9)

    as a Python float. A small value means the relaxed state barely moves
    under another step.
    """
    init = model.init_state(XBAR)
    zs, ys = O.relax_free(model, XBAR, *init, s, T1, eps)

    with torch.enable_grad():
        # requires_grad_ flags the tensor in place and returns that same
        # tensor, so zs / ys themselves are what O.force receives.
        zs.requires_grad_(True)
        ys.requires_grad_(True)
        fz, _ = O.force(model, XBAR, zs, ys, s)

    # The force is detached: only its value is used for the extra step.
    step = eps * fz.detach()
    zn = token_norm(zs + step)

    moved = (zn - zs).norm()
    scale = zs.norm() + 1e-9  # guards against division by zero
    return (moved / scale).item()
# Sweep driver: prints a header, then one table row per (s, c) combination.
print("Contractive (damped) non-conservative attention — does it restore a fixed point + EP fidelity?")
print(f"{'s':>5} {'c':>4} | {'fwd resid':>9} {'adj cos(attn)':>13} {'glob':>7}")
for s in (1.0, 2.0, 4.0, 8.0):
    for c in (0.0, 1.0, 2.0):
        # Install the damping first so the residual and both gradients are
        # all measured with the same attention.
        set_damp(c)
        r = resid(s, 250)
        gb = O.bptt_grad(model, XBAR, s, 250, 0.2)          # BPTT reference gradient
        ga = O.adjoint_grad(model, XBAR, s, 250, 0.2, 250)  # projected-adjoint gradient
        a, g = O.cosines(ga, gb, names)
        print(f"{s:>5.1f} {c:>4.1f} | {r:>9.2e} {a:>13.3f} {g:>7.3f}")
    # NOTE(review): the original indentation was lost; this blank line is
    # assumed to separate the rows of each s value -- confirm.
    print()
|