diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/lt_ep_anderson.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/lt_ep_anderson.py')
| -rw-r--r-- | ep_run/lt_ep_anderson.py | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/ep_run/lt_ep_anderson.py b/ep_run/lt_ep_anderson.py new file mode 100644 index 0000000..7682c50 --- /dev/null +++ b/ep_run/lt_ep_anderson.py @@ -0,0 +1,54 @@ +"""Decisive test for the Anderson idea: at LOW damping (expressive attention), can a fixed-point +SOLVER (Anderson acceleration, DEQ-style) converge the free phase where plain fixed-step relaxation +cannot? If yes -> we get convergence from the solver, not from suppressing attention with damping.""" +import math, torch +from lt_ep_train import EQBlock, get_batch +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +torch.manual_seed(0) +B, T, C, H = 16, 64, 128, 4 +blk = EQBlock(C, H, 256, T, attn_mode='real') +idx, y = get_batch('train', B, T) +xin = blk.embed(idx).detach() +eps = 0.05 + + +def gmap(z): # relaxation map; its fixed point = the equilibrium + with torch.no_grad(): + return z + eps * blk.force(z, xin).detach() + + +def plain(z0, steps=200): + z = z0.clone() + for _ in range(steps): + z = gmap(z) + return ((gmap(z) - z).norm() / (z.norm() + 1e-9)).item() + + +def anderson(z0, m=6, max_iter=120, tol=1e-6, lam=1e-4): + Bs, d = z0.shape[0], z0[0].numel() + X = torch.zeros(Bs, m, d, device=dev); Fb = torch.zeros(Bs, m, d, device=dev) + X[:, 0] = z0.reshape(Bs, d); Fb[:, 0] = gmap(z0).reshape(Bs, d) + X[:, 1] = Fb[:, 0]; Fb[:, 1] = gmap(X[:, 1].view_as(z0)).reshape(Bs, d) + Hm = torch.zeros(Bs, m + 1, m + 1, device=dev); Hm[:, 0, 1:] = 1; Hm[:, 1:, 0] = 1 + yv = torch.zeros(Bs, m + 1, 1, device=dev); yv[:, 0] = 1 + r, k = 1.0, 2 + for k in range(2, max_iter): + n = min(k, m) + Gm = Fb[:, :n] - X[:, :n] + Hm[:, 1:n + 1, 1:n + 1] = torch.bmm(Gm, Gm.transpose(1, 2)) + lam * torch.eye(n, device=dev)[None] + alpha = torch.linalg.solve(Hm[:, :n + 1, :n + 1], yv[:, :n + 1])[:, 1:n + 1, 0] + X[:, k % m] = torch.bmm(alpha[:, None], Fb[:, :n])[:, 0] + Fb[:, k % m] = gmap(X[:, k % m].view_as(z0)).reshape(Bs, d) + r = ((Fb[:, k % m] - X[:, k % m]).norm() / (Fb[:, k % m].norm() + 1e-9)).item() + if r < tol or not math.isfinite(r): + break + return r, k + 1 + + +print("free-phase convergence: plain relax (200 steps) vs Anderson — real attention, eps=0.05") +print(f"{'damp c':>7} {'plain_res':>11} {'anderson_res':>13} {'and_iters':>10}") +for c in [0.0, 0.25, 0.5, 1.0, 2.0, 4.0]: + blk.c = c + pr = plain(xin.clone()) + ar, ak = anderson(xin.clone()) + print(f"{c:>7.2f} {pr:>11.2e} {ar:>13.2e} {ak:>10d}") |
