summaryrefslogtreecommitdiff
path: root/scripts/aep_depth.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
committerYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
commitb83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
treeb9cc01d7adda691d9156d9d04f4fb2f644674e96 /scripts/aep_depth.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'scripts/aep_depth.py')
-rw-r--r--scripts/aep_depth.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/scripts/aep_depth.py b/scripts/aep_depth.py
new file mode 100644
index 0000000..c202a0c
--- /dev/null
+++ b/scripts/aep_depth.py
@@ -0,0 +1,30 @@
+"""B: does AEP gradient fidelity degrade as the non-conservative attention gets DEEPER?
+Stack K residual attention sub-layers (weight-tied) inside the force; measure naive vs
+AEP attention-param cosine vs BPTT, at fixed scale s."""
+import torch, aep_characterize as A
+from cet_aep import CETReal
+from cet_mvp import make_patch_mask, get_loaders
+
+dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+torch.manual_seed(0)
+model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev)
+names = [n for n, _ in model.named_parameters()]
+trl, _ = get_loaders(32, dataset='fashionmnist')
+X, _ = next(iter(trl)); X = X.to(dev)
+M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev)
+A.X, A.M, A.XBAR = X, M, X * (1 - M)
+
+base = model.real_attn
+def deep(K):
+ def f(z):
+ h = z
+ for _ in range(K):
+ h = h + base(h)
+ return h - z
+ return f
+
+print(f"{'depth K':>8} | {'naive(attn)':>11} {'AEP(attn)':>10}")
+for K in [1, 2, 3, 4]:
+ model.real_attn = deep(K)
+ r = A.measure(model, names, 1.0, 120, 30, 0.2, 0.02) # s=1, T2=30 (enough per [3])
+ print(f"{K:>8} | {r['naive']:>11.3f} {r['aep']:>10.3f}")