summaryrefslogtreecommitdiff
path: root/scripts/aep_depth.py
blob: c202a0cd56641fee5656586f6d32d833447c8c59 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""B: does AEP gradient fidelity degrade as the non-conservative attention gets DEEPER?
Stack K residual attention sub-layers (weight-tied) inside the force; measure naive vs
AEP attention-param cosine vs BPTT, at fixed scale s."""
import torch, aep_characterize as A
from cet_aep import CETReal
from cet_mvp import make_patch_mask, get_loaders

# Device, seed, model and one fixed batch that every depth in the sweep reuses.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed torch's RNG before the model is built. Presumably this also fixes the
# patch mask below, if make_patch_mask draws from torch's RNG — TODO confirm.
torch.manual_seed(0)
# NOTE(review): the positional args look like image side 28, 1 channel and
# 7x7 patches (they match make_patch_mask's 7, 7, 28, 28 below). Confirm
# against CETReal.
model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev)
# All parameter names in named_parameters() order; handed to A.measure below.
names = [n for n, _ in model.named_parameters()]
# Presumably returns (train_loader, test_loader) with batch size 32; only the
# first loader is used here.
trl, _ = get_loaders(32, dataset='fashionmnist')
X, _ = next(iter(trl)); X = X.to(dev)  # one batch, labels dropped, moved to dev
# NOTE(review): 0.5 is presumably the masked-patch ratio — confirm in cet_mvp.
M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev)
# Inject the batch into aep_characterize's module globals, where A.measure is
# assumed to read them. XBAR keeps X only where M == 0 (assuming M is 0/1).
A.X, A.M, A.XBAR = X, M, X * (1 - M)

# Keep a handle on the original single attention sub-layer before the sweep
# overwrites model.real_attn; deep() calls this callable repeatedly.
base = model.real_attn
def deep(K, attn=None):
    """Return a force term that applies a weight-tied residual stack K times.

    The returned ``f(z)`` starts from ``h = z`` and repeats
    ``h = h + attn(h)`` K times. The same ``attn`` is used at every step
    (weight-tied). It then returns ``h - z``, the total residual increment.
    K=1 therefore gives ``attn(z)`` up to floating-point rounding, and K=0
    gives ``z - z``.

    Args:
        K: number of stacked residual sub-layers; must be >= 0.
        attn: callable sub-layer to stack. Defaults to the module-level
            ``base`` (the model's original attention). The default is
            captured when ``deep`` is called, so rebinding ``base`` later
            does not change an already-built ``f``.

    Returns:
        Callable ``f(z)`` with the same call form as a single sub-layer, so
        it can replace ``model.real_attn``.

    Raises:
        ValueError: if K is negative. Without this check, range() would
            silently behave like K=0.
    """
    if K < 0:
        raise ValueError(f"depth K must be non-negative, got {K}")
    # Resolve the sub-layer once, instead of a global lookup on every call.
    layer = base if attn is None else attn

    def f(z):
        h = z
        for _ in range(K):
            h = h + layer(h)
        return h - z

    return f

# Depth sweep: for each K, replace the model's attention with a K-layer
# weight-tied residual stack. Report naive vs AEP attention-param cosine
# (vs BPTT) at fixed scale s.
# NOTE(review): assigning a plain function only works if `real_attn` is not a
# registered child nn.Module (nn.Module.__setattr__ rejects non-Module values
# for submodule names). Confirm in CETReal.
DEPTHS = (1, 2, 3, 4)
print(f"{'depth K':>8} | {'naive(attn)':>11} {'AEP(attn)':>10}")
try:
    for K in DEPTHS:
        model.real_attn = deep(K)
        # s=1, T2=30 per the original note. The meaning of the other
        # positional args (120, 0.2, 0.02) is defined by aep_characterize.measure.
        r = A.measure(model, names, 1.0, 120, 30, 0.2, 0.02)   # s=1, T2=30 (enough per [3])
        print(f"{K:>8} | {r['naive']:>11.3f} {r['aep']:>10.3f}")
finally:
    # Undo the monkey-patch so the model keeps its original attention, even
    # if a measurement raises part-way through the sweep.
    model.real_attn = base