1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
|
"""Validate lead_rho/spec_penalty (eigreg v2, TRUE map-eigenvalue) against ARPACK on the same operator.
Checks, per operator (s2000 warm source + rand-init); results are printed for manual comparison, nothing is asserted:
1. lead_rho's (rho, Re mu) vs scipy eigs LM top eigenvalues of M = I + eps*J_F (gold standard)
2. cross-check vs eig_probe's known s2000 number (Re mu ~ -0.02)
3. spec_penalty fires (non-empty grads, finite) when rho > target, stays off when below
"""
import numpy as np, torch, scipy.sparse.linalg as sla
from torch.autograd.functional import jvp
import lt_ep_train as L
from eig_control import lead_rho, spec_penalty
# Run settings used by the helpers and the validation loop below.
T1 = 150   # passed to L.relax as its 4th argument (presumably an iteration count -- TODO confirm)
EPS = 0.1  # step size eps of the map M = I + eps*J_F
B = 6      # passed to L.get_batch as its 2nd argument (presumably the batch size)
C = 1.0    # EQBlock's `c`; also enters the diagonal term 1 - EPS*(1 + C) of the ARPACK operator
def load_op(path):
    """Build the EQBlock operator, optionally loading weights from a checkpoint.

    The torch RNG is seeded with 0 before construction, so the random-init
    operator (``path`` falsy) is the same on every run.

    Args:
        path: checkpoint path loaded with ``torch.load``; its ``'allp'`` entry
            is copied element-wise into ``blk.allp``. A falsy value (e.g.
            ``None``) keeps the freshly seeded random initialization.

    Returns:
        The ``L.EQBlock`` instance with ``qknorm`` set to ``True``.

    Raises:
        ValueError: if the checkpoint's ``'allp'`` does not contain exactly as
            many tensors as ``blk.allp``.
    """
    torch.manual_seed(0)
    blk = L.EQBlock(512, 16, 256, 256, c=C, attn_mode='thick')
    blk.qknorm = True
    if path:
        ck = torch.load(path, map_location=L.dev)
        params = list(blk.allp)
        saved = list(ck['allp'])
        # zip() would silently stop at the shorter list, leaving trailing
        # parameters at their random init -- fail loudly instead.
        if len(saved) != len(params):
            raise ValueError(
                f"checkpoint {path!r} has {len(saved)} tensors in 'allp' "
                f"but the block has {len(params)} parameters"
            )
        with torch.no_grad():
            for p, w in zip(params, saved):
                p.copy_(w.to(L.dev))
    return blk
def arpack_map(blk, z, k=4, eps=None, c=None):
    """Return the top-``k`` eigenvalues (by magnitude) of the linear map at ``z``.

    The operator is applied matrix-free as

        M v = (1 - eps*(1 + c)) * v + eps * (J v),

    where ``J v`` is the Jacobian-vector product of ``blk.nc_force`` at ``z``
    (``torch.autograd.functional.jvp``). It is handed to ARPACK through
    ``scipy.sparse.linalg.eigs(which='LM')``.

    Args:
        blk: object with an ``nc_force`` callable, expected to return a tensor
            shaped like ``z``.
        z: linearization point; its shape, dtype and device are used for
            every matvec.
        k: number of eigenvalues requested from ARPACK.
        eps: step size; ``None`` means the module-level ``EPS``.
        c: diagonal coefficient; ``None`` means the module-level ``C``.

    Returns:
        List of ``k`` complex eigenvalues, largest magnitude first.

    Raises:
        scipy.sparse.linalg.ArpackNoConvergence: if ARPACK does not converge
            within ``maxiter``.
    """
    eps = EPS if eps is None else eps
    c = C if c is None else c
    shape, n = z.shape, z.numel()
    diag = 1.0 - eps * (1.0 + c)  # coefficient on v in M v

    def matvec(x):
        # ARPACK hands over a float64 host vector; convert it to z's own dtype
        # and device (rather than a hard-coded float32 / global device) so the
        # JVP tangent always matches the linearization point.
        v = torch.tensor(np.asarray(x), dtype=z.dtype, device=z.device).reshape(shape)
        with torch.no_grad():
            Mv = diag * v + eps * jvp(blk.nc_force, z, v)[1]
        return Mv.reshape(-1).double().cpu().numpy()

    op = sla.LinearOperator((n, n), matvec=matvec, dtype=np.float64)
    vals = sla.eigs(op, k=k, which='LM', return_eigenvectors=False,
                    maxiter=2000, tol=1e-4)
    # ARPACK does not guarantee an ordering; sort by magnitude, largest first.
    return sorted(vals, key=abs, reverse=True)
# Validate on a trained checkpoint and on a random-init operator.
OPERATORS = (('s2000', 'runs/redx_traj/s2000.pt'), ('rand-init', None))
for name, path in OPERATORS:
    blk = load_op(path)

    # Re-seed so both operators are probed on the same batch.
    torch.manual_seed(42)
    idx, _ = L.get_batch('train', B, 256)
    xin = blk.embed(idx).detach()
    zs = L.relax(blk, xin.clone(), xin, T1, EPS)

    # Check 1: lead_rho's (rho, Re mu) against ARPACK's leading eigenvalue.
    _, rho, mu = lead_rho(blk, zs, EPS, C, {}, iters=40)
    top = arpack_map(blk, zs)
    lam0 = top[0]
    print(f"[{name}] lead_rho: rho={rho:.5f} Re_mu={mu:+.4f}", flush=True)

    arpack_txt = " ".join(
        f"|l|={abs(lam):.5f}({lam.real:+.4f}{lam.imag:+.4f}j)" for lam in top[:3]
    )
    print(f"[{name}] ARPACK : " + arpack_txt)

    d_rho = abs(rho - abs(lam0))
    d_remu = abs(mu - (lam0.real - 1) / EPS)
    print(f"[{name}] agree : d_rho={d_rho:.4f} d_Remu={d_remu:.4f}")

    # Check 3: report whether spec_penalty returned gradients and whether
    # all of them are finite (vacuously True when none were returned).
    g, r0, m0 = spec_penalty(blk, zs, EPS, C, 0.1, 0.995, {}, iters=40)
    fin = True if not g else all(torch.isfinite(t).all().item() for t in g.values())
    print(f"[{name}] penalty : rho={r0:.5f} fired={bool(g)} n_grads={len(g)} finite={fin}\n", flush=True)

print("EIG_V2_SMOKE_DONE", flush=True)
|