summaryrefslogtreecommitdiff
path: root/ep_run/eig_v2_depth.py
blob: ee597f1e53b0c8073804139a97fc12fb1658fb94 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""Path-effect check: is the '|lam|~1.008 at s2000' reading an unconverged-state artifact?
Measure ARPACK leading map-eigenvalues at z after 150 vs 400 vs 800 relax steps (same batch).
If |lam| drops below 1 with depth -> the >1 reading is transient-state contamination and the
fixed-point operator is stable (consistent with eig_probe's Re mu = -0.02 at 250 steps)."""
import numpy as np, torch, scipy.sparse.linalg as sla
from torch.autograd.functional import jvp
import lt_ep_train as L

# Relaxation step size, batch size, and the block's `c` coefficient.
EPS, B, C = 0.1, 6, 1.0
torch.manual_seed(0)
blk = L.EQBlock(512, 16, 256, 256, c=C, attn_mode='thick'); blk.qknorm = True

# Load the step-2000 checkpoint parameters into the freshly built block.
ck = torch.load('runs/redx_traj/s2000.pt', map_location=L.dev)
params, saved = list(blk.allp), list(ck['allp'])
# zip() stops silently at the shorter sequence and Tensor.copy_ broadcasts its
# source, so a mismatched checkpoint could load only partially (or be smeared
# across a larger tensor) without any error. Check count and shapes up front.
if len(params) != len(saved):
    raise ValueError(f"checkpoint has {len(saved)} parameter tensors, "
                     f"block expects {len(params)}")
with torch.no_grad():
    for i, (p, w) in enumerate(zip(params, saved)):
        if p.shape != w.shape:
            raise ValueError(f"parameter {i}: checkpoint shape {tuple(w.shape)} "
                             f"!= block shape {tuple(p.shape)}")
        p.copy_(w.to(L.dev))

# Reseed so the sampled batch does not depend on how much RNG the model init
# consumed (assumes L.get_batch draws from torch's RNG -- confirm).
torch.manual_seed(42)
idx, _ = L.get_batch('train', B, 256)
xin = blk.embed(idx).detach()
# Identity coefficient of the linearized relaxation map whose spectrum is
# measured. NOTE(review): presumably matches the (1 - EPS*(1+C)) * z term
# inside L.relax -- confirm against lt_ep_train.
kk = 1.0 - EPS * (1.0 + C)

# Relax a single trajectory to increasing depths and, at each depth, estimate
# the leading eigenvalues of the linearized relaxation map around the current z.
z = xin.clone()
done = 0
for steps in (150, 400, 800):
    # Continue from the previous depth instead of restarting from xin.
    z = L.relax(blk, z, xin, steps - done, EPS); done = steps
    # Norm of one more relax step from z: how far z still is from a fixed point.
    res = (L.relax(blk, z, xin, 1, EPS) - z).norm().item()
    sh, n = z.shape, z.numel()

    def mv(x, z=z, sh=sh):
        """Matrix-free matvec for ARPACK: flat x -> flat kk*x + EPS * J x.

        J is the Jacobian of ``blk.nc_force`` at ``z``, applied through a
        Jacobian-vector product. ``z`` and ``sh`` are bound as defaults so the
        operator built at each depth uses that depth's state.
        """
        v = torch.from_numpy(np.asarray(x, dtype=np.float32)).to(L.dev).view(sh)
        with torch.no_grad():
            Mv = kk * v + EPS * jvp(blk.nc_force, z, v)[1]
        return Mv.reshape(-1).double().cpu().numpy()

    A = sla.LinearOperator((n, n), matvec=mv, dtype=np.float64)
    note = ""
    try:
        lams = sla.eigs(A, k=4, which='LM', return_eigenvectors=False,
                        maxiter=2000, tol=1e-4)
    except sla.ArpackNoConvergence as err:
        # Keep the sweep going: report ARPACK's partial result and flag it,
        # rather than aborting before the deeper depths are measured.
        lams = err.eigenvalues
        note = f"  [ARPACK did not converge; {len(lams)} partial eigenvalue(s)]"
    # Largest modulus first.
    vals = sorted(lams, key=abs, reverse=True)
    print(f"[s2000 @ {steps:4d} steps] res={res:.3e}  " +
          "  ".join(f"|l|={abs(l):.5f}({l.real:+.4f}{l.imag:+.4f}j)" for l in vals[:3]) +
          note, flush=True)
print("EIG_V2_DEPTH_DONE", flush=True)