From b83947778e2c776f757a07d4719b7ce961d7ed55 Mon Sep 17 00:00:00 2001 From: Yuren Hao Date: Fri, 3 Jul 2026 05:56:50 -0500 Subject: =?UTF-8?q?Initial=20commit:=20ept=20=E2=80=94=20backprop-free=20e?= =?UTF-8?q?quilibrium=20transformer=20(EP)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn --- scripts/aep_contractive2.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 scripts/aep_contractive2.py (limited to 'scripts/aep_contractive2.py') diff --git a/scripts/aep_contractive2.py b/scripts/aep_contractive2.py new file mode 100644 index 0000000..f1d38a8 --- /dev/null +++ b/scripts/aep_contractive2.py @@ -0,0 +1,41 @@ +"""F (v2): make real attention EP-able via UNCONSTRAINED dynamics + damping (no projection). + +The projection (C/F-v1) fought radial damping and broke the VF. Drop it: unconstrained AEP +already has clean theory (0.99 fidelity) but diverges at high s for lack of confinement. +Add damping that scales with s: attention term = s*(attn(z) - c*z). Fixed point +z* = [s*attn(z*) + enc]/(4 + s*c) -> attention still sets the direction, but -(4+sc)z makes +it a contraction (so a stable fixed point exists). Small eps needed (the linear part is stiff). + +Reuses aep_characterize's UNCONSTRAINED, AEP-validated machinery; monkeypatches attention to the +damped version. Reports naive vs AEP attention-param cosine vs BPTT, and whether it stayed finite. +""" +import math, torch, aep_characterize as A +from cet_aep import CETReal +from cet_mvp import make_patch_mask, get_loaders + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +torch.manual_seed(0) +model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev) +names = [n for n, _ in model.named_parameters()] +orig = model.real_attn +trl, _ = get_loaders(32, dataset='fashionmnist') +X, _ = next(iter(trl)); X = X.to(dev) +M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev) +A.X, A.M, A.XBAR = X, M, X * (1 - M) + + +def setc(c): + model.real_attn = orig if c == 0 else (lambda z: orig(z) - c * z) + + +# small eps for the stiff damped linear part; more free steps to converge +EPS, T1, T2, BETA = 0.05, 400, 40, 0.02 +print(f"UNCONSTRAINED + damping, eps={EPS} T1={T1} T2={T2}") +print(f"{'s':>5} {'c':>4} | {'naive(attn)':>11} {'AEP(attn)':>10} | {'finite?':>7}") +for s in [2.0, 4.0, 8.0]: + for c in [0.0, 1.0, 2.0]: + setc(c) + r = A.measure(model, names, s, T1, T2, EPS, BETA) + fin = not (math.isnan(r['aep']) or math.isnan(r['naive'])) + print(f"{s:>5.1f} {c:>4.1f} | {r['naive']:>11.3f} {r['aep']:>10.3f} | {str(fin):>7}") + print() -- cgit v1.2.3