1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
|
"""Decisive test for the Anderson idea: at LOW damping (expressive attention), can a fixed-point
SOLVER (Anderson acceleration, DEQ-style) converge the free phase where plain fixed-step relaxation
cannot? If yes -> we get convergence from the solver, not from suppressing attention with damping."""
import math, torch
from lt_ep_train import EQBlock, get_batch
# --- experiment setup: module-level state read by gmap and the driver loop ---
# Preferred device: CUDA when available, otherwise CPU.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed torch's global RNG before building the block and sampling the batch so runs
# repeat (assumes EQBlock init and get_batch draw from torch's RNG -- TODO confirm).
torch.manual_seed(0)
# Presumably batch size, context length, model width, attention heads -- confirm
# against the EQBlock / get_batch signatures.
B, T, C, H = 16, 64, 128, 4
# NOTE(review): blk is never explicitly moved to `dev`; confirm EQBlock and get_batch
# place their tensors on a consistent device, otherwise device mismatches can occur.
blk = EQBlock(C, H, 256, T, attn_mode='real')  # 256: unlabelled size argument -- confirm meaning
idx, y = get_batch('train', B, T)  # y (presumably targets) is not used by this relaxation test
# Input embedding that gmap passes to blk.force; detached from the autograd graph.
xin = blk.embed(idx).detach()
# Explicit relaxation step size used by gmap (z -> z + eps * force).
eps = 0.05
def gmap(z):  # relaxation map; its fixed point = the equilibrium
    """Take one explicit relaxation step ``z + eps * blk.force(z, xin)``.

    Runs under ``torch.no_grad`` and detaches the force, so no autograd graph is
    built. Because ``eps`` is non-zero, ``gmap(z) == z`` exactly when
    ``blk.force(z, xin)`` is zero.
    """
    with torch.no_grad():
        drive = blk.force(z, xin).detach()
        return z + eps * drive
def plain(z0, steps=200):
    """Relax ``z0`` with ``steps`` fixed-size gmap updates; return the final relative residual.

    The residual ``||gmap(z) - z|| / (||z|| + 1e-9)`` is measured on the last iterate
    (one extra gmap call) and returned as a Python float.
    """
    state = z0.clone()  # work on a copy of the caller's tensor
    for _step in range(steps):
        state = gmap(state)
    gap = gmap(state) - state
    scale = state.norm() + 1e-9
    return (gap.norm() / scale).item()
def anderson(z0, m=6, max_iter=120, tol=1e-6, lam=1e-4, beta=1.0, fmap=None):
    """Find a fixed point ``z = fmap(z)`` from ``z0`` with batched Anderson acceleration.

    DEQ-style (type-II) Anderson:
    - Keep a ring buffer of the last ``m`` iterates X_i and their images
      F_i = fmap(X_i), one flattened row per batch element.
    - Choose weights alpha with sum(alpha) = 1 that minimise
      ||sum_i alpha_i (F_i - X_i)||^2 + lam * ||alpha||^2.
    - Step to beta * sum_i alpha_i F_i + (1 - beta) * sum_i alpha_i X_i.

    Args:
        z0: initial state; dim 0 is the batch, the remaining dims are flattened.
            Solver buffers follow z0's device and dtype.
        m: history window size; must be >= 2.
        max_iter: budget of map evaluations, counting the two seed evaluations.
        tol: stop once the relative residual falls below this value.
        lam: Tikhonov term added to the residual Gram matrix. It is absolute, so
            once ||F_i - X_i||^2 falls well below lam the weights approach uniform
            and the step degrades towards averaging the history.
        beta: mixing (damping) factor; 1.0 is undamped Anderson (original behaviour).
        fmap: map to solve; defaults to the module-level ``gmap``.

    Returns:
        (r, n_evals): the relative residual ||fmap(x) - x|| / (||x|| + 1e-9) of the
        last iterate (the same metric ``plain`` reports), and the number of map
        evaluations performed. Iteration also stops early when r becomes non-finite.

    Raises:
        ValueError: if m < 2 (two history slots are seeded before the first mix).
    """
    if m < 2:
        raise ValueError(f"anderson needs a history window m >= 2, got m={m}")
    f = gmap if fmap is None else fmap
    opts = dict(device=z0.device, dtype=z0.dtype)  # buffers live where z0 lives
    Bs, d = z0.shape[0], z0[0].numel()

    def step(x_flat):
        # Evaluate the map on a flattened (Bs, d) iterate; return it flattened.
        return f(x_flat.reshape(z0.shape)).reshape(Bs, d)

    def rel_res(x, fx):
        # Same relative residual as `plain`: ||g(x) - x|| / (||x|| + 1e-9).
        return ((fx - x).norm() / (x.norm() + 1e-9)).item()

    with torch.no_grad():
        X = torch.zeros(Bs, m, d, **opts)
        Fb = torch.zeros(Bs, m, d, **opts)
        # Seed the history with z0 and one plain relaxation step.
        X[:, 0] = z0.reshape(Bs, d)
        Fb[:, 0] = step(X[:, 0])
        X[:, 1] = Fb[:, 0]
        Fb[:, 1] = step(X[:, 1])
        n_evals = 2
        r = rel_res(X[:, 1], Fb[:, 1])

        # Bordered system [[0, 1^T], [1, G G^T + lam I]] [nu; alpha] = [1; 0]
        # enforces sum(alpha) = 1 while minimising the mixed residual.
        Hm = torch.zeros(Bs, m + 1, m + 1, **opts)
        Hm[:, 0, 1:] = 1
        Hm[:, 1:, 0] = 1
        yv = torch.zeros(Bs, m + 1, 1, **opts)
        yv[:, 0] = 1
        eye = torch.eye(m, **opts)

        for k in range(2, max_iter):
            if r < tol or not math.isfinite(r):
                break
            n = min(k, m)  # history slots filled so far
            G = Fb[:, :n] - X[:, :n]
            Hm[:, 1:n + 1, 1:n + 1] = torch.bmm(G, G.transpose(1, 2)) + lam * eye[:n, :n]
            alpha = torch.linalg.solve(Hm[:, :n + 1, :n + 1], yv[:, :n + 1])[:, 1:n + 1, 0]
            mixed = torch.bmm(alpha[:, None], Fb[:, :n])[:, 0]
            if beta != 1.0:
                mixed = beta * mixed + (1 - beta) * torch.bmm(alpha[:, None], X[:, :n])[:, 0]
            j = k % m  # next free slot, then the oldest iterate once the buffer is full
            X[:, j] = mixed
            Fb[:, j] = step(X[:, j])
            n_evals += 1
            r = rel_res(X[:, j], Fb[:, j])
    return r, n_evals
# Sweep the damping strength and compare the two free-phase solvers side by side.
print("free-phase convergence: plain relax (200 steps) vs Anderson — real attention, eps=0.05")
print(f"{'damp c':>7} {'plain_res':>11} {'anderson_res':>13} {'and_iters':>10}")
for c in (0.0, 0.25, 0.5, 1.0, 2.0, 4.0):
    blk.c = c  # set the block's damping attribute before each pair of solves
    # Both solvers start from a fresh copy of the same input embedding.
    pr, (ar, ak) = plain(xin.clone()), anderson(xin.clone())
    print(f"{c:>7.2f} {pr:>11.2e} {ar:>13.2e} {ak:>10d}")
|