"""Validate lead_rho/spec_penalty (eigreg v2, TRUE map-eigenvalue) against ARPACK on the same operator.

Checks, per operator (s2000 warm source + rand-init):
  1. lead_rho's (rho, Re mu) vs scipy eigs LM top eigenvalues of M = I + eps*J_F (gold standard)
  2. cross-check vs eig_probe's known s2000 number (Re mu ~ -0.02)
  3. spec_penalty fires (non-empty grads, finite) when rho > target, stays off when below

Run as a script: prints one report block per operator and finishes with
the line EIG_V2_SMOKE_DONE.
"""
import numpy as np
import scipy.sparse.linalg as sla
import torch
from torch.autograd.functional import jvp

import lt_ep_train as L
from eig_control import lead_rho, spec_penalty

T1, EPS, B, C = 150, 0.1, 6, 1.0

# (name, checkpoint path or None for random init, eig_probe reference Re mu or None).
# The s2000 reference is the approximate figure quoted in check 2 of the docstring.
OPERATORS = [
    ('s2000', 'runs/redx_traj/s2000.pt', -0.02),
    ('rand-init', None, None),
]


def load_op(path):
    """Build the EQBlock operator, optionally loading weights from a checkpoint.

    The torch seed is reset first so the rand-init operator is reproducible.

    Args:
        path: checkpoint path whose ``'allp'`` entry holds one tensor per
            parameter, in ``blk.allp`` order; falsy (e.g. ``None``) keeps
            the random initialization.

    Returns:
        The EQBlock with ``qknorm`` enabled.

    Raises:
        ValueError: if the checkpoint's parameter count differs from the block's.
    """
    torch.manual_seed(0)
    blk = L.EQBlock(512, 16, 256, 256, c=C, attn_mode='thick')
    blk.qknorm = True
    if path:
        ck = torch.load(path, map_location=L.dev)
        params, saved = list(blk.allp), list(ck['allp'])
        # zip() would silently stop at the shorter list and leave some
        # parameters at their random init, invalidating the comparison.
        if len(params) != len(saved):
            raise ValueError(
                f"checkpoint {path!r} has {len(saved)} tensors, "
                f"block expects {len(params)}"
            )
        with torch.no_grad():
            for p, w in zip(params, saved):
                p.copy_(w.to(L.dev))
    return blk


def arpack_map(blk, z, k=4):
    """Return the k largest-magnitude eigenvalues of the linearized map at ``z``.

    The operator is applied matrix-free as
        M v = (1 - EPS * (1 + C)) * v + EPS * (J_nc v),
    where ``J_nc v`` is the Jacobian-vector product of ``blk.nc_force`` at ``z``.

    Args:
        blk: operator whose ``nc_force`` is differentiated.
        z: linearization point; its shape defines the vector space.
        k: number of eigenvalues requested from ARPACK.

    Returns:
        List of (complex) eigenvalues sorted by decreasing magnitude. If ARPACK
        does not converge, the eigenvalues it did converge are returned (with a
        warning), so the list may be shorter than ``k``.

    Raises:
        scipy.sparse.linalg.ArpackNoConvergence: if no eigenvalue converged.
    """
    sh, n = z.shape, z.numel()
    kk = 1.0 - EPS * (1.0 + C)

    def mv(x):
        # ARPACK hands us float64 numpy vectors; the model runs in float32.
        v = torch.from_numpy(np.asarray(x, dtype=np.float32)).to(L.dev).view(sh)
        with torch.no_grad():
            Mv = kk * v + EPS * jvp(blk.nc_force, z, v)[1]
        return Mv.reshape(-1).double().cpu().numpy()

    A = sla.LinearOperator((n, n), matvec=mv, dtype=np.float64)
    try:
        vals = sla.eigs(A, k=k, which='LM', return_eigenvectors=False,
                        maxiter=2000, tol=1e-4)
    except sla.ArpackNoConvergence as err:
        # Keep the smoke run going with whatever converged.
        vals = err.eigenvalues
        if len(vals) == 0:
            raise
        print(f"  WARNING: ARPACK converged only {len(vals)}/{k} eigenvalues",
              flush=True)
    return sorted(vals, key=lambda x: -abs(x))


def main():
    """Run checks 1-3 from the module docstring on every entry of OPERATORS."""
    for name, path, ref_remu in OPERATORS:
        blk = load_op(path)
        torch.manual_seed(42)
        idx, _ = L.get_batch('train', B, 256)
        xin = blk.embed(idx).detach()
        zs = L.relax(blk, xin.clone(), xin, T1, EPS)

        # Check 1: lead_rho vs ARPACK on the same operator.
        _, rho, mu = lead_rho(blk, zs, EPS, C, {}, iters=40)
        rho, mu = float(rho), float(mu)  # accept python floats or 0-dim tensors
        top = arpack_map(blk, zs)
        lam0 = top[0]
        # Map eigenvalue lam = 1 + EPS * mu  =>  mu = (lam - 1) / EPS.
        remu_arpack = (lam0.real - 1) / EPS
        print(f"[{name}] lead_rho: rho={rho:.5f} Re_mu={mu:+.4f}", flush=True)
        print(f"[{name}] ARPACK : " + " ".join(
            f"|l|={abs(l):.5f}({l.real:+.4f}{l.imag:+.4f}j)" for l in top[:3]),
            flush=True)
        print(f"[{name}] agree : d_rho={abs(rho - abs(lam0)):.4f} "
              f"d_Remu={abs(mu - remu_arpack):.4f}", flush=True)

        # Check 2: compare against eig_probe's reference value, when known.
        if ref_remu is not None:
            print(f"[{name}] eig_probe: ref_Re_mu={ref_remu:+.4f} "
                  f"d_lead={abs(mu - ref_remu):.4f} "
                  f"d_arpack={abs(remu_arpack - ref_remu):.4f}", flush=True)

        # Check 3: penalty gradients. 0.1 and 0.995 are passed positionally as
        # before; presumably (penalty weight, rho target) -- TODO confirm
        # against eig_control.spec_penalty.
        g, r0, _ = spec_penalty(blk, zs, EPS, C, 0.1, 0.995, {}, iters=40)
        fin = all(torch.isfinite(v).all().item() for v in g.values()) if g else True
        print(f"[{name}] penalty : rho={r0:.5f} fired={bool(g)} "
              f"n_grads={len(g)} finite={fin}\n", flush=True)

    print("EIG_V2_SMOKE_DONE", flush=True)


if __name__ == "__main__":
    main()