summaryrefslogtreecommitdiff
path: root/ep_run/eig_recheck.py
blob: 317a40d6cef44020e1c2dbf12243c440a1360227 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""eig_recheck.py — gold-standard recheck of the warm/scratch abscissa story after the num_abscissa
shift fix. Plain power iteration (the pre-fix estimator) converges to the largest-|lambda| end of the
indefinite Sym(J_nc) — so the reported 's2000 = -10.14 vs scratch = +1.11' may have compared lambda_min
against lambda_max. For each operator this measures, at z* on the SAME seeded batches:
  lam_max / lam_min  — both ENDS of Sym(J_nc), Lanczos (scipy eigsh LA/SA, matvec-only, gold standard)
  oldPI              — what the pre-fix plain power iteration returned (bug reproduction)
  newPI              — the fixed shifted-PI estimate (validates the training-time estimator vs Lanczos)
  res                — T1-residual  eps*||F(z*)||  (the lt_ep_train line-455 convention)
  val                — val CE (identifies the checkpoint on the training curve)
Verdict logic: oldPI ~= lam_min (when |lam_min| > lam_max) ==> the bug was live and the -10 story is
about the WRONG end; the warm/scratch contrast must be re-read off the lam_max column.
"""
import numpy as np, torch, scipy.sparse.linalg as sla
from pathlib import Path
from torch.autograd.functional import jvp, vjp
import lt_ep_train as L
from eig_control import num_abscissa

# Relaxation steps passed to L.relax, relaxation step size, batch size passed to L.get_batch,
# and number of batches averaged per operator.
T1 = 150
EPS = 0.1
B = 6
NB = 2

# Operators to recheck, as (label, checkpoint path); a path of None means the seeded random init.
CKPTS = [
    ('rand-init', None),
    ('s2000-warmsrc', 'runs/redx_traj/s2000.pt'),
    ('resreg-scratch', 'runs/ep_resreg_scratch.pt'),
    ('fast-adaptive', 'runs/ep_fast_adaptive.pt'),
    ('warm-fast-live', 'runs/ep_warm_fast.pt'),
    ('self-restart-live', 'runs/ep_self_restart.pt'),
]


def load_op(path):
    """Build the canonical C512 EQBlock and, if `path` is given, load checkpoint weights into it.

    Args:
        path: checkpoint file to load, or None to keep the seeded random initialisation.

    Returns:
        (blk, info): the block and a short description -- 'random init', or the checkpoint's
        'step' / 'best' fields.

    Raises:
        ValueError: if the checkpoint's 'allp' tensors do not match blk.allp one-to-one in count
            and shape. Without this check zip() would truncate silently and copy_() would broadcast
            a mismatched-but-compatible shape, leaving a half-loaded operator.
        Whatever torch.load raises, if the second read attempt also fails.
    """
    torch.manual_seed(0)                                      # identical random init on every call
    blk = L.EQBlock(512, 16, 256, 256, c=1.0, attn_mode='thick')
    blk.qknorm = True                                         # canonical C512 recipe flag (not stored in ckpt)
    if path is None:
        return blk, 'random init'
    for attempt in range(2):                                  # live ckpts: retry once in case of a mid-save read
        try:
            ck = torch.load(path, map_location=L.dev)
            break
        except Exception:
            if attempt:
                raise
            import time
            time.sleep(5)
    params, saved = list(blk.allp), list(ck['allp'])
    if len(params) != len(saved):
        raise ValueError(f"{path}: checkpoint holds {len(saved)} parameter tensors, model has {len(params)}")
    with torch.no_grad():
        for i, (p, w) in enumerate(zip(params, saved)):
            if tuple(p.shape) != tuple(w.shape):
                raise ValueError(f"{path}: allp[{i}] has shape {tuple(w.shape)}, model expects {tuple(p.shape)}")
            p.copy_(w.to(L.dev))
    best = ck.get('best')
    if best is None:                                          # absent or stored as None -> report nan
        best = float('nan')
    return blk, f"step {ck.get('step')} best {best:.4f}"


def sym_ends(blk, z):                                         # both ends of Sym(J_nc) via Lanczos
    """Return (lam_max, lam_min) of Sym(J) = (J + J^T) / 2, where J is the Jacobian of blk.nc_force at z.

    J is never materialised. Each matvec costs one jvp plus one vjp, and ARPACK Lanczos
    (scipy eigsh, which='LA' then which='SA') extracts the two extreme eigenvalues.

    The probe vector is built in z's own dtype and on z's own device. Both Lanczos runs start from
    the same fixed, seeded vector, so the result depends only on (blk, z) and not on how many eigsh
    calls happened earlier in the process.

    Raises:
        scipy.sparse.linalg.ArpackNoConvergence: if either end fails to converge within maxiter.
    """
    sh, n = z.shape, z.numel()

    def mv(x):
        # ARPACK passes a float64 vector of length n (possibly shaped (n, 1)); copy it into a tensor
        # shaped like z so jvp/vjp see a tangent/cotangent of the right dtype, device and shape.
        v = torch.tensor(np.asarray(x).reshape(-1), dtype=z.dtype, device=z.device).view(sh)
        with torch.no_grad():
            Sv = 0.5 * (jvp(blk.nc_force, z, v)[1] + vjp(blk.nc_force, z, v)[1])
        return Sv.reshape(-1).double().cpu().numpy()

    A = sla.LinearOperator((n, n), matvec=mv, dtype=np.float64)
    v0 = np.random.default_rng(0).standard_normal(n)          # fixed start vector -> reproducible per call
    kw = dict(k=1, tol=1e-3, return_eigenvectors=False, maxiter=600, v0=v0)
    return float(sla.eigsh(A, which='LA', **kw)[0]), float(sla.eigsh(A, which='SA', **kw)[0])


def old_pi(blk, z, iters=3):                                  # the PRE-FIX estimator, reproduced verbatim
    """Reproduce the pre-fix abscissa estimate: plain (unshifted) power iteration on Sym(J_nc).

    Runs `iters` steps of v <- S v / ||S v||, with S v = 0.5 * (J v + J^T v) and J the Jacobian of
    blk.nc_force at z. It then returns the Rayleigh quotient v.(J v) / v.v, which equals v.(S v) / v.v
    because the skew part of J contributes nothing to a quadratic form.

    With no shift, power iteration heads for the largest-|lambda| eigenvalue of S. For an indefinite
    S that can be the most negative one. That is exactly the behaviour being reproduced, so keep this
    body identical to the pre-fix code; do not "fix" it.

    Side effect: reseeds the global torch RNG with seed 1, which changes the RNG state seen by any
    later caller.

    Args:
        blk: object exposing the `nc_force` callable that is differentiated.
        z: point at which the Jacobian is taken.
        iters: number of power-iteration steps (the pre-fix default was 3).

    Returns:
        The Rayleigh-quotient estimate as a Python float.
    """
    torch.manual_seed(1)                                      # fixed start vector (global RNG side effect)
    v = torch.randn_like(z); v = v / (v.norm() + 1e-12)
    with torch.no_grad():
        for _ in range(iters):
            Sv = 0.5 * (jvp(blk.nc_force, z, v)[1] + vjp(blk.nc_force, z, v)[1])
            v = Sv / (Sv.norm() + 1e-12)                      # normalise; epsilon guards a zero vector
        return float((v * jvp(blk.nc_force, z, v)[1]).sum() / (v * v).sum())


def _draw_batches(n_batches, batch_size, seq_len=256, seed=42):
    """Draw all evaluation token batches once, up front, under a fixed torch seed.

    The returned index tensors are reused for every operator, so every operator sees identical
    batches by construction. That no longer depends on how much RNG state old_pi (which reseeds
    the global torch RNG) or num_abscissa consume between batches. Assumes L.get_batch draws
    from the global torch RNG for `seed` to pin which batches are chosen -- TODO confirm.
    """
    torch.manual_seed(seed)
    return [L.get_batch('train', batch_size, seq_len)[0] for _ in range(n_batches)]


def _measure(name, blk, batches):
    """Relax each batch to z*, run every estimator there, print per-batch rows, return the means.

    Returns a dict averaged over `batches` with keys:
        'la' lam_max and 'sa' lam_min of Sym(J_nc) (Lanczos), 'op' old power iteration,
        'np' new shifted estimate (num_abscissa), and 'res' the norm of one further relaxation
        step taken from z*.
    """
    keys = ('la', 'sa', 'op', 'np', 'res')
    acc = {k: [] for k in keys}
    for b, idx in enumerate(batches):
        xin = blk.embed(idx).detach()
        zs = L.relax(blk, xin.clone(), xin, T1, EPS)
        res = (L.relax(blk, zs, xin, 1, EPS) - zs).norm().item()
        la, sa = sym_ends(blk, zs)
        opi = old_pi(blk, zs)
        _, npi = num_abscissa(blk, zs, {}, iters=40)
        for k, x in zip(keys, (la, sa, opi, npi, res)):
            acc[k].append(x)
        print(f"  [{name} b{b}] lam_max={la:+8.3f} lam_min={sa:+8.3f} oldPI={opi:+8.3f} "
              f"newPI={npi:+8.3f} res={res:.2e}", flush=True)
    return {k: sum(v) / len(v) for k, v in acc.items()}


def _print_summary(rows):
    """Print the aligned one-line-per-operator summary table for (name, info, means, val) rows."""
    print(f"{'operator':<20}{'lam_max':>9}{'lam_min':>9}{'oldPI':>9}{'newPI':>9}{'res':>10}{'val':>8}")
    for name, info, m, val in rows:
        print(f"{name:<20}{m['la']:>+9.3f}{m['sa']:>+9.3f}{m['op']:>+9.3f}{m['np']:>+9.3f}"
              f"{m['res']:>10.2e}{val:>8.4f}   # {info}")


def main():
    """Measure every available operator in CKPTS on shared batches and print per-operator verdicts."""
    batches = _draw_batches(NB, B)                            # SAME batches for every operator
    rows = []
    for name, path in CKPTS:
        if path and not Path(path).exists():
            print(f"[skip] {name}: {path} missing", flush=True)
            continue
        blk, info = load_op(path)
        m = _measure(name, blk, batches)
        val = L.evaluate(blk, T1, EPS)
        rows.append((name, info, m, val))
        # oldPI sitting closer to lam_min than to lam_max => the unshifted PI locked onto the wrong end
        wrong_end = abs(m['op'] - m['sa']) < abs(m['op'] - m['la'])
        print(f"[{name}] ({info}) lam_max={m['la']:+.3f} lam_min={m['sa']:+.3f} oldPI={m['op']:+.3f}"
              f"{' <-- oldPI tracked lam_MIN (bug live)' if wrong_end else ' (oldPI ~ lam_max, bug latent)'}"
              f" newPI={m['np']:+.3f} res={m['res']:.2e} val={val:.4f}\n", flush=True)

    _print_summary(rows)
    print("EIG_RECHECK_DONE", flush=True)


if __name__ == '__main__':                                    # run the recheck only when executed as a script
    main()