diff options
Diffstat (limited to 'scripts/aep_depth.py')
| -rw-r--r-- | scripts/aep_depth.py | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/scripts/aep_depth.py b/scripts/aep_depth.py new file mode 100644 index 0000000..c202a0c --- /dev/null +++ b/scripts/aep_depth.py @@ -0,0 +1,30 @@ +"""B: does AEP gradient fidelity degrade as the non-conservative attention gets DEEPER? +Stack K residual attention sub-layers (weight-tied) inside the force; measure naive vs +AEP attention-param cosine vs BPTT, at fixed scale s.""" +import torch, aep_characterize as A +from cet_aep import CETReal +from cet_mvp import make_patch_mask, get_loaders + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +torch.manual_seed(0) +model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev) +names = [n for n, _ in model.named_parameters()] +trl, _ = get_loaders(32, dataset='fashionmnist') +X, _ = next(iter(trl)); X = X.to(dev) +M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev) +A.X, A.M, A.XBAR = X, M, X * (1 - M) + +base = model.real_attn +def deep(K): + def f(z): + h = z + for _ in range(K): + h = h + base(h) + return h - z + return f + +print(f"{'depth K':>8} | {'naive(attn)':>11} {'AEP(attn)':>10}") +for K in [1, 2, 3, 4]: + model.real_attn = deep(K) + r = A.measure(model, names, 1.0, 120, 30, 0.2, 0.02) # s=1, T2=30 (enough per [3]) + print(f"{K:>8} | {r['naive']:>11.3f} {r['aep']:>10.3f}") |
