summaryrefslogtreecommitdiff
path: root/scripts/aep_contractive2.py
blob: f1d38a82c261af81729888471a867dce93a4220c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""F (v2): make real attention EP-able via UNCONSTRAINED dynamics + damping (no projection).

The projection (C/F-v1) fought radial damping and broke the VF. Drop it: unconstrained AEP
already has clean theory (0.99 fidelity) but diverges at high s for lack of confinement.
Add damping that scales with s:  attention term = s*(attn(z) - c*z).  Fixed point
z* = [s*attn(z*) + enc]/(4 + s*c) -> attention still sets the direction, but -(4+sc)z makes
it a contraction (so a stable fixed point exists). Small eps needed (the linear part is stiff).

Reuses aep_characterize's UNCONSTRAINED, AEP-validated machinery; monkeypatches attention to the
damped version. Reports naive vs AEP attention-param cosine vs BPTT, and whether it stayed finite.
"""
import math, torch, aep_characterize as A
from cet_aep import CETReal
from cet_mvp import make_patch_mask, get_loaders

# Shared experiment setup: one model and one masked input batch, reused for every
# (s, c) setting in the sweep.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed before building the model (and before drawing the batch, in case the loader
# shuffles with torch's RNG) so runs are reproducible; statement order matters here.
torch.manual_seed(0)
# NOTE(review): the positional args (28, 1, 7, 7) look like image size, channels and
# patch geometry -- confirm against CETReal's signature.
model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev)
# Names of every model parameter; passed through to A.measure.
names = [n for n, _ in model.named_parameters()]
# Handle on the unpatched attention; setc() either reinstalls it or wraps it with damping.
orig = model.real_attn
# Presumably 32 is the batch size -- TODO confirm. Only the first batch of trl is used.
trl, _ = get_loaders(32, dataset='fashionmnist')
X, _ = next(iter(trl)); X = X.to(dev)  # second element (presumably labels) is discarded
# NOTE(review): 0.5 presumably sets the masked-patch fraction -- confirm in make_patch_mask.
M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev)
# Inject the batch into aep_characterize's module globals, which A.measure presumably
# reads instead of taking them as arguments. XBAR = X * (1 - M) keeps X where M == 0
# (assuming M is a 0/1 mask).
A.X, A.M, A.XBAR = X, M, X * (1 - M)


def setc(c):
    """Install the damped attention ``orig(z) - c * z`` on ``model``.

    When ``c == 0`` the untouched original attention (``orig``) is put back
    directly instead of being wrapped, so no extra closure is created.

    Args:
        c: damping coefficient; ``c * z`` is subtracted from the attention output.
    """
    if c == 0:
        model.real_attn = orig
        return

    def damped_attn(z):
        # `c` belongs to this call of setc, so each installed wrapper keeps its own
        # coefficient; `orig` is resolved from module globals at call time.
        return orig(z) - c * z

    model.real_attn = damped_attn


# Sweep attention strength s and damping c, comparing naive vs AEP attention-param
# cosine against BPTT. Small eps for the stiff damped linear part; more free steps
# to converge.
EPS, T1, T2, BETA = 0.05, 400, 40, 0.02
S_VALUES = (2.0, 4.0, 8.0)  # attention strength s passed to A.measure
C_VALUES = (0.0, 1.0, 2.0)  # damping coefficient installed via setc (0 = undamped)

print(f"UNCONSTRAINED + damping, eps={EPS} T1={T1} T2={T2}")
print(f"{'s':>5} {'c':>4} | {'naive(attn)':>11} {'AEP(attn)':>10} | {'finite?':>7}")
for s in S_VALUES:
    for c in C_VALUES:
        setc(c)
        r = A.measure(model, names, s, T1, T2, EPS, BETA)
        # "finite?" must reject +/-inf as well as NaN; math.isnan alone let
        # infinities through and reported them as finite.
        fin = math.isfinite(r['naive']) and math.isfinite(r['aep'])
        print(f"{s:>5.1f} {c:>4.1f} | {r['naive']:>11.3f} {r['aep']:>10.3f} | {str(fin):>7}")
    print()