summaryrefslogtreecommitdiff
path: root/ep_run/eig_recheck.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@illinois.edu>2026-07-03 07:57:22 -0500
committerYuren Hao <yurenh2@illinois.edu>2026-07-03 07:57:22 -0500
commitbcec9560cf5c9b113e9381a52d1a941daa8865f2 (patch)
treebae3baf6d742b816d90e642d70b9744a86a4d189 /ep_run/eig_recheck.py
parentc0b507fb1760be291e1e1ed33f33fb18f16d8c2d (diff)
omega/norm-family refuted as stability signal; fingerprint story retracted; eigreg v2 = true map-eigenvalue (spec_penalty)HEADmaster
- eig_control: fix plain-PI bug (shifted PI for lambda_max of indefinite Sym); add lead_rho + spec_penalty (soft one-sided cap on |lam|(I+eps*J_F), 2-D Rayleigh-Ritz, matvec-only) — aep 'spectral' ported. eig_penalty demoted to diagnostic. - eig_recheck.py (Lanczos audit): omega=+5..+13 on ALL operators incl the stablest (s2000 +12.8 while true alpha=-0.02); gap omega-alpha~10; old 'warm -10.14 vs scratch +1.11' numbers were PI-mixture artifacts. RETRACTED. - eig_v2_smoke/depth: v2 mechanics validated vs ARPACK; z_T1 readings >1 are unconverged-state contamination (150: 1.009 -> 400/800: 0.997-0.999, mu=-0.02..-0.006 matching eig_probe); fixed-point top = BAND of slow modes. - lt_ep_train: --eigreg now spec_penalty (--eig_margin 0.995 = rho target); --fingerprint reports rho/Re_mu instead of num_abscissa. - ONBOARDING §4-7 + FINDINGS 2026-07-03: retraction + verdict (fundamental quantity = finite-horizon path LE / resreg axis; de-cliff via floss-ept; spec_penalty = measure-mode scalpel for a detaching Hopf pair). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/eig_recheck.py')
-rw-r--r--ep_run/eig_recheck.py109
1 files changed, 109 insertions, 0 deletions
diff --git a/ep_run/eig_recheck.py b/ep_run/eig_recheck.py
new file mode 100644
index 0000000..317a40d
--- /dev/null
+++ b/ep_run/eig_recheck.py
@@ -0,0 +1,109 @@
+"""eig_recheck.py — gold-standard recheck of the warm/scratch abscissa story after the num_abscissa
+shift fix. Plain power iteration (the pre-fix estimator) converges to the largest-|lambda| end of the
+indefinite Sym(J_nc) — so the reported 's2000 = -10.14 vs scratch = +1.11' may have compared lambda_min
+against lambda_max. For each operator this measures, at z* on the SAME seeded batches:
+ lam_max / lam_min — both ENDS of Sym(J_nc), Lanczos (scipy eigsh LA/SA, matvec-only, gold standard)
+ oldPI — what the pre-fix plain power iteration returned (bug reproduction)
+ newPI — the fixed shifted-PI estimate (validates the training-time estimator vs Lanczos)
+ res — T1-residual eps*||F(z*)|| (the lt_ep_train line-455 convention)
+ val — val CE (identifies the checkpoint on the training curve)
+Verdict logic: oldPI ~= lam_min (when |lam_min| > lam_max) ==> the bug was live and the -10 story is
+about the WRONG end; the warm/scratch contrast must be re-read off the lam_max column.
+"""
+import numpy as np, torch, scipy.sparse.linalg as sla
+from pathlib import Path
+from torch.autograd.functional import jvp, vjp
+import lt_ep_train as L
+from eig_control import num_abscissa
+
+T1, EPS, B, NB = 150, 0.1, 6, 2
+
+CKPTS = [ # (label, path or None=random init)
+ ('rand-init', None),
+ ('s2000-warmsrc', 'runs/redx_traj/s2000.pt'),
+ ('resreg-scratch', 'runs/ep_resreg_scratch.pt'),
+ ('fast-adaptive', 'runs/ep_fast_adaptive.pt'),
+ ('warm-fast-live', 'runs/ep_warm_fast.pt'),
+ ('self-restart-live', 'runs/ep_self_restart.pt'),
+]
+
+
+def load_op(path):
+ torch.manual_seed(0)
+ blk = L.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
+ blk.qknorm = True # canonical C512 recipe flag (not stored in ckpt)
+ if path is None:
+ return blk, 'random init'
+ for attempt in range(2): # live ckpts: retry once in case of a mid-save read
+ try:
+ ck = torch.load(path, map_location=L.dev); break
+ except Exception:
+ if attempt: raise
+ import time; time.sleep(5)
+ with torch.no_grad():
+ for p, w in zip(blk.allp, ck['allp']):
+ p.copy_(w.to(L.dev))
+ return blk, f"step {ck.get('step')} best {ck.get('best', float('nan')):.4f}"
+
+
+def sym_ends(blk, z): # both ends of Sym(J_nc) via Lanczos
+ sh, n = z.shape, z.numel()
+
+ def mv(x):
+ v = torch.from_numpy(np.asarray(x, dtype=np.float32)).to(L.dev).view(sh)
+ with torch.no_grad():
+ Sv = 0.5 * (jvp(blk.nc_force, z, v)[1] + vjp(blk.nc_force, z, v)[1])
+ return Sv.reshape(-1).double().cpu().numpy()
+
+ A = sla.LinearOperator((n, n), matvec=mv, dtype=np.float64)
+ kw = dict(k=1, tol=1e-3, return_eigenvectors=False, maxiter=600)
+ return float(sla.eigsh(A, which='LA', **kw)[0]), float(sla.eigsh(A, which='SA', **kw)[0])
+
+
+def old_pi(blk, z, iters=3): # the PRE-FIX estimator, reproduced verbatim
+ torch.manual_seed(1)
+ v = torch.randn_like(z); v = v / (v.norm() + 1e-12)
+ with torch.no_grad():
+ for _ in range(iters):
+ Sv = 0.5 * (jvp(blk.nc_force, z, v)[1] + vjp(blk.nc_force, z, v)[1])
+ v = Sv / (Sv.norm() + 1e-12)
+ return float((v * jvp(blk.nc_force, z, v)[1]).sum() / (v * v).sum())
+
+
+def main():
+ rows = []
+ for name, path in CKPTS:
+ if path and not Path(path).exists():
+ print(f"[skip] {name}: {path} missing", flush=True); continue
+ blk, info = load_op(path)
+ torch.manual_seed(42) # SAME batches for every operator
+ acc = {k: [] for k in ('la', 'sa', 'op', 'np', 'res')}
+ for b in range(NB):
+ idx, _ = L.get_batch('train', B, 256)
+ xin = blk.embed(idx).detach()
+ zs = L.relax(blk, xin.clone(), xin, T1, EPS)
+ res = (L.relax(blk, zs, xin, 1, EPS) - zs).norm().item()
+ la, sa = sym_ends(blk, zs)
+ opi = old_pi(blk, zs)
+ _, npi = num_abscissa(blk, zs, {}, iters=40)
+ for k, x in zip(('la', 'sa', 'op', 'np', 'res'), (la, sa, opi, npi, res)):
+ acc[k].append(x)
+ print(f" [{name} b{b}] lam_max={la:+8.3f} lam_min={sa:+8.3f} oldPI={opi:+8.3f} "
+ f"newPI={npi:+8.3f} res={res:.2e}", flush=True)
+ val = L.evaluate(blk, T1, EPS)
+ m = {k: sum(v) / len(v) for k, v in acc.items()}
+ rows.append((name, info, m, val))
+ wrong_end = abs(m['op'] - m['sa']) < abs(m['op'] - m['la'])
+ print(f"[{name}] ({info}) lam_max={m['la']:+.3f} lam_min={m['sa']:+.3f} oldPI={m['op']:+.3f}"
+ f"{' <-- oldPI tracked lam_MIN (bug live)' if wrong_end else ' (oldPI ~ lam_max, bug latent)'}"
+ f" newPI={m['np']:+.3f} res={m['res']:.2e} val={val:.4f}\n", flush=True)
+
+ print(f"{'operator':<20}{'lam_max':>9}{'lam_min':>9}{'oldPI':>9}{'newPI':>9}{'res':>10}{'val':>8}")
+ for name, info, m, val in rows:
+ print(f"{name:<20}{m['la']:>+9.3f}{m['sa']:>+9.3f}{m['op']:>+9.3f}{m['np']:>+9.3f}"
+ f"{m['res']:>10.2e}{val:>8.4f} # {info}")
+ print("EIG_RECHECK_DONE", flush=True)
+
+
+if __name__ == '__main__':
+ main()