diff options
Diffstat (limited to 'ep_run/eig_recheck.py')
| -rw-r--r-- | ep_run/eig_recheck.py | 109 |
1 files changed, 109 insertions, 0 deletions
diff --git a/ep_run/eig_recheck.py b/ep_run/eig_recheck.py new file mode 100644 index 0000000..317a40d --- /dev/null +++ b/ep_run/eig_recheck.py @@ -0,0 +1,109 @@ +"""eig_recheck.py — gold-standard recheck of the warm/scratch abscissa story after the num_abscissa +shift fix. Plain power iteration (the pre-fix estimator) converges to the largest-|lambda| end of the +indefinite Sym(J_nc) — so the reported 's2000 = -10.14 vs scratch = +1.11' may have compared lambda_min +against lambda_max. For each operator this measures, at z* on the SAME seeded batches: + lam_max / lam_min — both ENDS of Sym(J_nc), Lanczos (scipy eigsh LA/SA, matvec-only, gold standard) + oldPI — what the pre-fix plain power iteration returned (bug reproduction) + newPI — the fixed shifted-PI estimate (validates the training-time estimator vs Lanczos) + res — T1-residual eps*||F(z*)|| (the lt_ep_train line-455 convention) + val — val CE (identifies the checkpoint on the training curve) +Verdict logic: oldPI ~= lam_min (when |lam_min| > lam_max) ==> the bug was live and the -10 story is +about the WRONG end; the warm/scratch contrast must be re-read off the lam_max column. +""" +import numpy as np, torch, scipy.sparse.linalg as sla +from pathlib import Path +from torch.autograd.functional import jvp, vjp +import lt_ep_train as L +from eig_control import num_abscissa + +T1, EPS, B, NB = 150, 0.1, 6, 2 + +CKPTS = [ # (label, path or None=random init) + ('rand-init', None), + ('s2000-warmsrc', 'runs/redx_traj/s2000.pt'), + ('resreg-scratch', 'runs/ep_resreg_scratch.pt'), + ('fast-adaptive', 'runs/ep_fast_adaptive.pt'), + ('warm-fast-live', 'runs/ep_warm_fast.pt'), + ('self-restart-live', 'runs/ep_self_restart.pt'), +] + + +def load_op(path): + torch.manual_seed(0) + blk = L.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick') + blk.qknorm = True # canonical C512 recipe flag (not stored in ckpt) + if path is None: + return blk, 'random init' + for attempt in range(2): # live ckpts: retry once in case of a mid-save read + try: + ck = torch.load(path, map_location=L.dev); break + except Exception: + if attempt: raise + import time; time.sleep(5) + with torch.no_grad(): + for p, w in zip(blk.allp, ck['allp']): + p.copy_(w.to(L.dev)) + return blk, f"step {ck.get('step')} best {ck.get('best', float('nan')):.4f}" + + +def sym_ends(blk, z): # both ends of Sym(J_nc) via Lanczos + sh, n = z.shape, z.numel() + + def mv(x): + v = torch.from_numpy(np.asarray(x, dtype=np.float32)).to(L.dev).view(sh) + with torch.no_grad(): + Sv = 0.5 * (jvp(blk.nc_force, z, v)[1] + vjp(blk.nc_force, z, v)[1]) + return Sv.reshape(-1).double().cpu().numpy() + + A = sla.LinearOperator((n, n), matvec=mv, dtype=np.float64) + kw = dict(k=1, tol=1e-3, return_eigenvectors=False, maxiter=600) + return float(sla.eigsh(A, which='LA', **kw)[0]), float(sla.eigsh(A, which='SA', **kw)[0]) + + +def old_pi(blk, z, iters=3): # the PRE-FIX estimator, reproduced verbatim + torch.manual_seed(1) + v = torch.randn_like(z); v = v / (v.norm() + 1e-12) + with torch.no_grad(): + for _ in range(iters): + Sv = 0.5 * (jvp(blk.nc_force, z, v)[1] + vjp(blk.nc_force, z, v)[1]) + v = Sv / (Sv.norm() + 1e-12) + return float((v * jvp(blk.nc_force, z, v)[1]).sum() / (v * v).sum()) + + +def main(): + rows = [] + for name, path in CKPTS: + if path and not Path(path).exists(): + print(f"[skip] {name}: {path} missing", flush=True); continue + blk, info = load_op(path) + torch.manual_seed(42) # SAME batches for every operator + acc = {k: [] for k in ('la', 'sa', 'op', 'np', 'res')} + for b in range(NB): + idx, _ = L.get_batch('train', B, 256) + xin = blk.embed(idx).detach() + zs = L.relax(blk, xin.clone(), xin, T1, EPS) + res = (L.relax(blk, zs, xin, 1, EPS) - zs).norm().item() + la, sa = sym_ends(blk, zs) + opi = old_pi(blk, zs) + _, npi = num_abscissa(blk, zs, {}, iters=40) + for k, x in zip(('la', 'sa', 'op', 'np', 'res'), (la, sa, opi, npi, res)): + acc[k].append(x) + print(f" [{name} b{b}] lam_max={la:+8.3f} lam_min={sa:+8.3f} oldPI={opi:+8.3f} " + f"newPI={npi:+8.3f} res={res:.2e}", flush=True) + val = L.evaluate(blk, T1, EPS) + m = {k: sum(v) / len(v) for k, v in acc.items()} + rows.append((name, info, m, val)) + wrong_end = abs(m['op'] - m['sa']) < abs(m['op'] - m['la']) + print(f"[{name}] ({info}) lam_max={m['la']:+.3f} lam_min={m['sa']:+.3f} oldPI={m['op']:+.3f}" + f"{' <-- oldPI tracked lam_MIN (bug live)' if wrong_end else ' (oldPI ~ lam_max, bug latent)'}" + f" newPI={m['np']:+.3f} res={m['res']:.2e} val={val:.4f}\n", flush=True) + + print(f"{'operator':<20}{'lam_max':>9}{'lam_min':>9}{'oldPI':>9}{'newPI':>9}{'res':>10}{'val':>8}") + for name, info, m, val in rows: + print(f"{name:<20}{m['la']:>+9.3f}{m['sa']:>+9.3f}{m['op']:>+9.3f}{m['np']:>+9.3f}" + f"{m['res']:>10.2e}{val:>8.4f} # {info}") + print("EIG_RECHECK_DONE", flush=True) + + +if __name__ == '__main__': + main() |
