# Source: ep_run/eig_v2_smoke.py (blob 241be7d03c2527e5e8deed0b88062c9b02258921)
"""Validate lead_rho/spec_penalty (eigreg v2, TRUE map-eigenvalue) against ARPACK on the same operator.
Checks, per operator (s2000 warm source + rand-init):
  1. lead_rho's (rho, Re mu) vs scipy eigs LM top eigenvalues of M = I + eps*J_F  (gold standard)
  2. cross-check vs eig_probe's known s2000 number (Re mu ~ -0.02)
  3. spec_penalty fires (non-empty grads, finite) when rho > target, stays off when below
"""
import numpy as np, torch, scipy.sparse.linalg as sla
from torch.autograd.functional import jvp
import lt_ep_train as L
from eig_control import lead_rho, spec_penalty

# Shared settings for both operators exercised below.
T1 = 150   # fourth argument to L.relax (presumably the number of relaxation steps -- confirm)
EPS = 0.1  # the eps in M = I + eps*J_F; also handed to L.relax, lead_rho and spec_penalty
B = 6      # first size argument to L.get_batch (presumably the batch size -- confirm)
C = 1.0    # passed as c= to L.EQBlock; enters the map's diagonal term 1 - EPS*(1 + C)


def load_op(path):
    """Build the EQBlock under test, optionally restoring checkpointed parameters.

    Args:
        path: Checkpoint file read with ``torch.load``. A falsy value
            (``None`` or ``""``) skips loading and keeps the block as
            constructed.

    Returns:
        The ``L.EQBlock`` instance, with ``qknorm`` set to ``True``.

    Raises:
        ValueError: If the checkpoint's ``'allp'`` sequence and ``blk.allp``
            have different lengths. A bare ``zip`` would silently stop at the
            shorter one and leave the operator only partially restored.
    """
    # Seed before construction (presumably so the random init is reproducible).
    torch.manual_seed(0)
    blk = L.EQBlock(512, 16, 256, 256, c=C, attn_mode='thick')
    blk.qknorm = True
    if path:
        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
        ck = torch.load(path, map_location=L.dev)
        params = list(blk.allp)
        weights = list(ck['allp'])
        if len(params) != len(weights):
            raise ValueError(
                f"checkpoint {path!r} holds {len(weights)} tensors in 'allp' "
                f"but the block has {len(params)} parameters"
            )
        # In-place copy under no_grad so the restore is not recorded by autograd.
        with torch.no_grad():
            for p, w in zip(params, weights):
                p.copy_(w.to(L.dev))
    return blk


def arpack_map(blk, z, k=4):
    """Return the ``k`` largest-magnitude eigenvalues of the linearised map at ``z``.

    The map is applied matrix-free as
    ``v -> (1 - EPS*(1 + C)) * v + EPS * J_nc(z) v``, where ``J_nc(z) v`` is the
    Jacobian-vector product of ``blk.nc_force`` at ``z``. (Presumably this equals
    ``I + eps*J_F`` for the full force -- confirm against lt_ep_train.)

    Args:
        blk: Object exposing ``nc_force``.
        z: Tensor at which the Jacobian is taken; its shape fixes the operator size.
        k: Number of eigenvalues requested from ``scipy.sparse.linalg.eigs``.

    Returns:
        A list of the eigenvalues found, ordered by decreasing modulus.
    """
    shape = z.shape
    dim = z.numel()
    diag = 1.0 - EPS * (1.0 + C)

    def apply_map(x):
        # The operator is declared float64; vectors are cast to float32 before going to torch.
        vec = torch.from_numpy(np.asarray(x, dtype=np.float32)).to(L.dev).view(shape)
        with torch.no_grad():
            _, jac_vec = jvp(blk.nc_force, z, vec)
            out = diag * vec + EPS * jac_vec
        # Hand a flat float64 NumPy vector back to ARPACK.
        return out.reshape(-1).double().cpu().numpy()

    operator = sla.LinearOperator((dim, dim), matvec=apply_map, dtype=np.float64)
    eigvals = sla.eigs(
        operator,
        k=k,
        which='LM',
        return_eigenvectors=False,
        maxiter=2000,
        tol=1e-4,
    )
    return sorted(eigvals, key=lambda ev: -abs(ev))


# Run the same comparison on a checkpointed operator and on a freshly constructed one.
cases = [('s2000', 'runs/redx_traj/s2000.pt'), ('rand-init', None)]
for label, ckpt in cases:
    op = load_op(ckpt)

    # Re-seed so both operators draw their batch from the same RNG state.
    torch.manual_seed(42)
    tokens, _ = L.get_batch('train', B, 256)
    x_in = op.embed(tokens).detach()
    z_relaxed = L.relax(op, x_in.clone(), x_in, T1, EPS)

    # Estimate under test first, then the ARPACK reference on the same state.
    _, rho, mu = lead_rho(op, z_relaxed, EPS, C, {}, iters=40)
    top = arpack_map(op, z_relaxed)
    lead = top[0]
    print(f"[{label}] lead_rho: rho={rho:.5f} Re_mu={mu:+.4f}", flush=True)

    arpack_line = "  ".join(
        f"|l|={abs(lam):.5f}({lam.real:+.4f}{lam.imag:+.4f}j)" for lam in top[:3]
    )
    print(f"[{label}] ARPACK  : " + arpack_line)

    # Map eigenvalue lam corresponds to mu = (lam - 1) / EPS, since M = I + eps*J_F.
    d_rho = abs(rho - abs(lead))
    d_re_mu = abs(mu - (lead.real - 1) / EPS)
    print(f"[{label}] agree   : d_rho={d_rho:.4f} d_Remu={d_re_mu:.4f}")

    # An empty/falsy result from spec_penalty is reported as "not fired".
    grads, rho_pen, _mu_pen = spec_penalty(op, z_relaxed, EPS, C, 0.1, 0.995, {}, iters=40)
    if grads:
        finite = all(torch.isfinite(t).all().item() for t in grads.values())
    else:
        finite = True
    print(
        f"[{label}] penalty : rho={rho_pen:.5f} fired={bool(grads)} "
        f"n_grads={len(grads)} finite={finite}\n",
        flush=True,
    )
print("EIG_V2_SMOKE_DONE", flush=True)