summaryrefslogtreecommitdiff
path: root/ep_run/eig_v2_depth.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/eig_v2_depth.py')
-rw-r--r--ep_run/eig_v2_depth.py39
1 files changed, 39 insertions, 0 deletions
diff --git a/ep_run/eig_v2_depth.py b/ep_run/eig_v2_depth.py
new file mode 100644
index 0000000..ee597f1
--- /dev/null
+++ b/ep_run/eig_v2_depth.py
@@ -0,0 +1,39 @@
+"""Path-effect check: is the '|lam|~1.008 at s2000' reading an unconverged-state artifact?
+Measure ARPACK leading map-eigenvalues at z after 150 vs 400 vs 800 relax steps (same batch).
+If |lam| drops below 1 with depth -> the >1 reading is transient-state contamination and the
+fixed-point operator is stable (consistent with eig_probe's Re mu = -0.02 at 250 steps)."""
+import numpy as np, torch, scipy.sparse.linalg as sla
+from torch.autograd.functional import jvp
+import lt_ep_train as L
+
+EPS, B, C = 0.1, 6, 1.0
+torch.manual_seed(0)
+blk = L.EQBlock(512, 16, 256, 256, c=C, attn_mode='thick'); blk.qknorm = True
+ck = torch.load('runs/redx_traj/s2000.pt', map_location=L.dev)
+with torch.no_grad():
+ for p, w in zip(blk.allp, ck['allp']):
+ p.copy_(w.to(L.dev))
+torch.manual_seed(42)
+idx, _ = L.get_batch('train', B, 256)
+xin = blk.embed(idx).detach()
+kk = 1.0 - EPS * (1.0 + C)
+
+z = xin.clone()
+done = 0
+for steps in (150, 400, 800):
+ z = L.relax(blk, z, xin, steps - done, EPS); done = steps
+ res = (L.relax(blk, z, xin, 1, EPS) - z).norm().item()
+ sh, n = z.shape, z.numel()
+
+ def mv(x, z=z):
+ v = torch.from_numpy(np.asarray(x, dtype=np.float32)).to(L.dev).view(sh)
+ with torch.no_grad():
+ Mv = kk * v + EPS * jvp(blk.nc_force, z, v)[1]
+ return Mv.reshape(-1).double().cpu().numpy()
+
+ A = sla.LinearOperator((n, n), matvec=mv, dtype=np.float64)
+ vals = sorted(sla.eigs(A, k=4, which='LM', return_eigenvectors=False, maxiter=2000, tol=1e-4),
+ key=lambda x: -abs(x))
+ print(f"[s2000 @ {steps:4d} steps] res={res:.3e} " +
+ " ".join(f"|l|={abs(l):.5f}({l.real:+.4f}{l.imag:+.4f}j)" for l in vals[:3]), flush=True)
+print("EIG_V2_DEPTH_DONE", flush=True)