summaryrefslogtreecommitdiff
path: root/scripts/aep_contractive.py
blob: 670feae132e61c04bedc5606a96e9d7ccb47db46 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""F: make REAL attention EP-able by damping it into a contraction (keep it non-conservative).

Attention term in the force becomes  s*(attn(z) - c*z).  The -c*z is damping that grows with s,
pushing Re(eig(J_F)) < 0 (a stable fixed point) WITHOUT symmetrizing the Jacobian (the antisymmetric
part is unchanged, so it stays non-conservative -> AEP still needed AND now applicable).

We sweep (s, c) and report, using the validated projected-adjoint (option 1):
  fwd resid : does a stable fixed point exist?  (small = yes)
  adj cos   : projected-adjoint gradient fidelity vs BPTT on attention params
Expected: c=0 breaks at high s (no fixed point, as before); c>=1 keeps resid small + fidelity high.
"""
import torch, aep_option1 as O
from cet_aep import CETReal
from cet_mvp import token_norm, make_patch_mask, get_loaders

# Run on CUDA when torch reports it as available, otherwise on the CPU.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed torch's global RNG before the model is built and the data/mask are drawn.
torch.manual_seed(0)
# NOTE(review): the positional arguments (28, 1, 7, 7) look like image size, channels and
# patch size — confirm against CETReal's signature.
model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev)
# Parameter names in named_parameters() order; handed to O.cosines in the sweep below.
names = [n for n, _ in model.named_parameters()]
orig_attn = model.real_attn                       # original (undamped) attention

# Use only the first loader returned by get_loaders, and only its first batch.
# NOTE(review): 32 is presumably the batch size — confirm against get_loaders.
trl, _ = get_loaders(32, dataset='fashionmnist')
# Keep the first element of the batch (moved to `dev`); the second element is discarded.
X, _ = next(iter(trl)); X = X.to(dev)
# NOTE(review): 0.5 is presumably the fraction of patches masked — confirm against make_patch_mask.
M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev)
# XBAR is X scaled elementwise by (1 - M); assuming M is a 0/1 mask, entries where M == 1
# are zeroed — TODO confirm.
XBAR = X * (1 - M)
O.X, O.M = X, M                                   # masked_cost in option1 uses these globals


def set_damp(c):
    """Install attention with damping coefficient ``c`` on the global ``model``.

    With ``c == 0`` the original attention callable (``orig_attn``) is put back
    unchanged. For any other ``c`` the model's ``real_attn`` attribute is replaced
    by a wrapper computing ``orig_attn(z) - c * z``.

    The wrapper always wraps ``orig_attn`` (never a previously damped version),
    so repeated calls do not stack damping terms.

    NOTE(review): this assumes ``real_attn`` is called with a single argument
    and that the attribute can hold a plain function (i.e. it is not a
    registered ``nn.Module`` child) — confirm against CETReal.
    """
    if c == 0:
        model.real_attn = orig_attn
        return

    def damped_attn(z):
        # Undamped attention output minus the linear damping term.
        return orig_attn(z) - c * z

    model.real_attn = damped_attn


def resid(s, T1, eps=0.2):
    """Return the relative change of ``z`` under one more update step.

    The state ``(z, y)`` is first obtained from ``O.relax_free`` started at
    ``model.init_state(XBAR)``. One further update
    ``z -> token_norm(z + eps * f_z)`` is then applied, where ``f_z`` is the
    first output of ``O.force`` at that state. The result is
    ``||z_next - z|| / (||z|| + 1e-9)`` as a Python float; a small value means
    the state barely moves under the update.

    Args:
        s: passed through to ``O.relax_free`` and ``O.force`` (the attention
            scale swept by the script).
        T1: passed to ``O.relax_free`` — presumably the number of relaxation
            steps; confirm against aep_option1.
        eps: step size, used both by ``O.relax_free`` and for the extra update.
    """
    start_state = model.init_state(XBAR)
    zs, ys = O.relax_free(model, XBAR, *start_state, s, T1, eps)

    # Grad mode is forced on and both states are flagged (in place) as requiring
    # grad before the force is evaluated.
    with torch.enable_grad():
        zs.requires_grad_(True)
        ys.requires_grad_(True)
        fz, _unused = O.force(model, XBAR, zs, ys, s)

    # One explicit step from the relaxed state, using the detached force.
    z_next = token_norm(zs + eps * fz.detach())

    step_size = (z_next - zs).norm()
    state_size = zs.norm() + 1e-9  # small offset guards against division by zero
    return (step_size / state_size).item()


# Sweep the attention scale s and the damping c, printing one table row per pair:
# the forward residual from resid() and the two values returned by O.cosines
# (shown under "adj cos(attn)" and "glob"). A blank line separates the groups for each s.
print("Contractive (damped) non-conservative attention — does it restore a fixed point + EP fidelity?")
table_header = f"{'s':>5} {'c':>4} | {'fwd resid':>9} {'adj cos(attn)':>13} {'glob':>7}"
print(table_header)
for scale in (1.0, 2.0, 4.0, 8.0):
    for damping in (0.0, 1.0, 2.0):
        # Swap the model's attention for the variant with this damping before measuring.
        set_damp(damping)
        fwd_resid = resid(scale, 250)
        grad_bptt = O.bptt_grad(model, XBAR, scale, 250, 0.2)
        grad_adjoint = O.adjoint_grad(model, XBAR, scale, 250, 0.2, 250)
        cos_attn, cos_glob = O.cosines(grad_adjoint, grad_bptt, names)
        table_row = (
            f"{scale:>5.1f} {damping:>4.1f} | "
            f"{fwd_resid:>9.2e} {cos_attn:>13.3f} {cos_glob:>7.3f}"
        )
        print(table_row)
    print()