1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
|
"""Path-effect check: is the '|lam|~1.008 at s2000' reading an unconverged-state artifact?
Measure ARPACK leading map-eigenvalues at z after 150 vs 400 vs 800 relax steps (same batch).
If |lam| drops below 1 with depth -> the >1 reading is transient-state contamination and the
fixed-point operator is stable (consistent with eig_probe's Re mu = -0.02 at 250 steps)."""
import numpy as np, torch, scipy.sparse.linalg as sla
from torch.autograd.functional import jvp
import lt_ep_train as L
# EPS: passed to L.relax and scales the nc_force term of the probed map (presumably the
# relax step size). B: second arg to L.get_batch (presumably batch size). C: the `c` of EQBlock.
EPS, B, C = 0.1, 6, 1.0

# Build the block under a fixed seed, then overwrite every parameter with the s2000 checkpoint.
torch.manual_seed(0)
blk = L.EQBlock(512, 16, 256, 256, c=C, attn_mode='thick')
blk.qknorm = True
ck = torch.load('runs/redx_traj/s2000.pt', map_location=L.dev)
params, saved = list(blk.allp), list(ck['allp'])
if len(params) != len(saved):
    # zip() would silently stop at the shorter list and leave some params at random init.
    raise ValueError(f"checkpoint has {len(saved)} tensors but block has {len(params)} params")
with torch.no_grad():
    for i, (p, w) in enumerate(zip(params, saved)):
        if p.shape != w.shape:
            # copy_() broadcasts, so a mismatched tensor could otherwise load without error.
            raise ValueError(f"param {i}: shape {tuple(p.shape)} != checkpoint {tuple(w.shape)}")
        p.copy_(w.to(L.dev))

# Separate fixed seed for data: the batch is drawn once and reused at every probed depth.
torch.manual_seed(42)
idx, _ = L.get_batch('train', B, 256)
xin = blk.embed(idx).detach()
# Identity coefficient of the map probed below (M v = kk*v + EPS * d(nc_force)/dz v).
# NOTE(review): presumably the Jacobian of one L.relax step — confirm against L.relax.
kk = 1.0 - EPS * (1.0 + C)
# Relaxation starts at the embedding and is continued incrementally; `done` counts steps so far.
z = xin.clone()
done = 0
N_EIG = 4  # Ritz values requested from ARPACK per depth; the top 3 by modulus are printed.

for steps in (150, 400, 800):
    # Continue relaxing from the previous depth (not restarted) up to `steps` total steps.
    z = L.relax(blk, z, xin, steps - done, EPS)
    done = steps
    # Size of one further relax step from z; z itself is not advanced by this probe.
    res = (L.relax(blk, z, xin, 1, EPS) - z).norm().item()
    sh, n = z.shape, z.numel()

    def mv(x, z=z, sh=sh):
        """Apply M v = kk*v + EPS * J v, with J the Jacobian of blk.nc_force at z.

        ARPACK supplies a flat float64 vector; it is cast to z's dtype/device and
        reshaped to z's shape. The product is returned as a flat float64 array.
        z and sh are bound as defaults so the operator is pinned to this depth.
        """
        v = torch.as_tensor(np.asarray(x), dtype=z.dtype, device=z.device).view(sh)
        with torch.no_grad():
            # jvp(...) returns (func_output, jvp); only the directional derivative is used.
            Mv = kk * v + EPS * jvp(blk.nc_force, z, v)[1]
        return Mv.reshape(-1).double().cpu().numpy()

    A = sla.LinearOperator((n, n), matvec=mv, dtype=np.float64)
    note = ""
    try:
        eigvals = sla.eigs(A, k=N_EIG, which='LM', return_eigenvectors=False,
                           maxiter=2000, tol=1e-4)
    except sla.ArpackNoConvergence as err:
        # Report whatever did converge instead of aborting the remaining depths.
        eigvals = err.eigenvalues
        note = f" [ARPACK unconverged: {len(eigvals)}/{N_EIG}]"
    # Largest modulus first; sort stays stable for conjugate pairs with equal |lam|.
    vals = sorted(eigvals, key=abs, reverse=True)
    eig_txt = " ".join(f"|l|={abs(lam):.5f}({lam.real:+.4f}{lam.imag:+.4f}j)" for lam in vals[:3])
    print(f"[s2000 @ {steps:4d} steps] res={res:.3e} " + eig_txt + note, flush=True)
print("EIG_V2_DEPTH_DONE", flush=True)
|