summaryrefslogtreecommitdiff
path: root/ep_run/solver_wall.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/solver_wall.py')
-rw-r--r--ep_run/solver_wall.py61
1 files changed, 61 insertions, 0 deletions
diff --git a/ep_run/solver_wall.py b/ep_run/solver_wall.py
new file mode 100644
index 0000000..ee9bf73
--- /dev/null
+++ b/ep_run/solver_wall.py
@@ -0,0 +1,61 @@
+"""Wall-breaking probe. The EP ceiling I measured comes from: rich (thick) block is
+non-contractive -> EP needs heavy damping c to converge the free phase -> damping suppresses
+the very expressivity that made the block good. ESCAPE ROUTE: get convergence from a SOLVER
+(Anderson accel, DEQ-style) instead of from damping. Decisive question: for the THICK block,
+at LOW damping (expressivity intact), can Anderson converge where plain relaxation cannot?
+If yes -> the wall is a solver problem, not fundamental. If no -> the rich block has no fixed
+point to find and the ceiling is intrinsic to the EP/fixed-point requirement."""
+import math, sys, torch
+from lt_ep_train import EQBlock, get_batch
+dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+torch.manual_seed(0)
+B, T, C, H = 16, 64, 128, 4
+eps = 0.05
+
+
+def gmap(blk, z, xin): # relaxation map; fixed point = equilibrium
+ with torch.no_grad():
+ return z + eps * blk.force(z, xin).detach()
+
+
+def plain(blk, z0, xin, steps=200):
+ z = z0.clone()
+ for _ in range(steps):
+ z = gmap(blk, z, xin)
+ return ((gmap(blk, z, xin) - z).norm() / (z.norm() + 1e-9)).item()
+
+
+def anderson(blk, z0, xin, m=6, max_iter=150, tol=1e-6, lam=1e-4):
+ Bs, d = z0.shape[0], z0[0].numel()
+ X = torch.zeros(Bs, m, d, device=dev); Fb = torch.zeros(Bs, m, d, device=dev)
+ X[:, 0] = z0.reshape(Bs, d); Fb[:, 0] = gmap(blk, z0, xin).reshape(Bs, d)
+ X[:, 1] = Fb[:, 0]; Fb[:, 1] = gmap(blk, X[:, 1].view_as(z0), xin).reshape(Bs, d)
+ Hm = torch.zeros(Bs, m + 1, m + 1, device=dev); Hm[:, 0, 1:] = 1; Hm[:, 1:, 0] = 1
+ yv = torch.zeros(Bs, m + 1, 1, device=dev); yv[:, 0] = 1
+ r, k = 1.0, 2
+ for k in range(2, max_iter):
+ n = min(k, m)
+ Gm = Fb[:, :n] - X[:, :n]
+ Hm[:, 1:n + 1, 1:n + 1] = torch.bmm(Gm, Gm.transpose(1, 2)) + lam * torch.eye(n, device=dev)[None]
+ alpha = torch.linalg.solve(Hm[:, :n + 1, :n + 1], yv[:, :n + 1])[:, 1:n + 1, 0]
+ X[:, k % m] = torch.bmm(alpha[:, None], Fb[:, :n])[:, 0]
+ Fb[:, k % m] = gmap(blk, X[:, k % m].view_as(z0), xin).reshape(Bs, d)
+ r = ((Fb[:, k % m] - X[:, k % m]).norm() / (Fb[:, k % m].norm() + 1e-9)).item()
+ if r < tol or not math.isfinite(r):
+ break
+ return r, k + 1
+
+
+for mode in ['real', 'thick']:
+ torch.manual_seed(0)
+ blk = EQBlock(C, H, 256, T, attn_mode=mode)
+ idx, y = get_batch('train', B, T)
+ xin = blk.embed(idx).detach()
+ print(f"\n=== attn_mode={mode} === free-phase convergence: plain relax(200) vs Anderson, eps={eps}")
+ print(f"{'damp c':>7} {'plain_res':>11} {'anderson_res':>13} {'and_iters':>10}")
+ for c in [0.0, 0.25, 0.5, 1.0, 2.0]:
+ blk.c = c
+ pr = plain(blk, xin.clone(), xin)
+ ar, ak = anderson(blk, xin.clone(), xin)
+ flag = ' <- solver converges where plain fails' if (ar < 1e-4 and pr > 1e-2) else ''
+ print(f"{c:>7.2f} {pr:>11.2e} {ar:>13.2e} {ak:>10d}{flag}")