summaryrefslogtreecommitdiff
path: root/ep_run/diag_cos.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/diag_cos.py')
-rw-r--r--ep_run/diag_cos.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/ep_run/diag_cos.py b/ep_run/diag_cos.py
index 37e7257..fb8a68f 100644
--- a/ep_run/diag_cos.py
+++ b/ep_run/diag_cos.py
@@ -9,7 +9,7 @@ can be laid side by side, and fingerprints any checkpoint (res, cos, numerical a
see WHAT distinguishes s2000 from other 2000-step checkpoints (conditioning? alignment? abscissa?)."""
import math, torch
from lt_ep_train import ep_step, bptt_step, relax, evaluate, get_batch
-from eig_control import num_abscissa
+from eig_control import lead_rho
def _cos(ge, gb, params): # cosine over the shared block params (where EP != BPTT)
@@ -32,14 +32,16 @@ def cos_ep_bptt(blk, idx, y, T1, T2, eps, beta, holo=0, hr=0.02, t1max=0, res_es
def fingerprint(blk, T1, T2, eps, beta, holo=0, hr=0.02, t1max=0, res_est=1e-4, t2sel=0, nb=4, B=6):
- """Median (res, cos-to-BPTT, numerical abscissa) over nb small batches + val CE — the operator's 4-D fingerprint.
- B kept small: the exact-BPTT reference gradient unrolls T1 steps and is memory-heavy at C512."""
- cache = {}; res_l, cos_l, om_l = [], [], []
+ """Median (res, cos-to-BPTT, TRUE leading map-eigenvalue rho & Re mu) over nb small batches + val CE.
+ B kept small: the exact-BPTT reference gradient unrolls T1 steps and is memory-heavy at C512.
+ (The numerical-abscissa column was retracted 2026-07-03 — omega is ~10 above the true abscissa on
+ this non-normal operator and anti-correlates with stability; see eig_control docstring / eig_recheck.)"""
+ cache = {}; res_l, cos_l, rho_l, mu_l = [], [], [], []
for _ in range(nb):
idx, y = get_batch('train', B, blk.T)
c, r = cos_ep_bptt(blk, idx, y, T1, T2, eps, beta, holo, hr, t1max, res_est, t2sel)
xin = blk.embed(idx).detach(); zs = relax(blk, xin.clone(), xin, T1, eps)
- _, om = num_abscissa(blk, zs, cache)
- res_l.append(r); cos_l.append(c); om_l.append(om)
+ _, rho, mu = lead_rho(blk, zs, eps, blk.c, cache, iters=25) # one-shot: more iters than the warm-started training path
+ res_l.append(r); cos_l.append(c); rho_l.append(rho); mu_l.append(mu)
md = lambda a: sorted(a)[len(a) // 2]
- return dict(res=md(res_l), cos=md(cos_l), num_abscissa=md(om_l), val=evaluate(blk, T1, eps))
+ return dict(res=md(res_l), cos=md(cos_l), rho=md(rho_l), mu_re=md(mu_l), val=evaluate(blk, T1, eps))