"""#1 — from-scratch-plateau diagnostic: cos(EP gradient, exact BPTT gradient) over training,
plus an operator FINGERPRINT for comparing checkpoints.

Hypothesis (from the resreg_probe): the scratch run spends its formative high-lr phase at
free-phase residual ~1e-2, where the EP estimate is only ~0.72-aligned with the exact BPTT
gradient -> it descends on a mediocre gradient and plateaus above BPTT's floor; a warm start
from a conditioned operator (res~1e-4, cos~0.98) skips that phase.

This logs (step, resT1, cos, val) so scratch vs warm trajectories can be laid side by side,
and fingerprints any checkpoint (res, cos, numerical abscissa, val) so we can see WHAT
distinguishes s2000 from other 2000-step checkpoints (conditioning? alignment? abscissa?).
"""
import math

import torch

from lt_ep_train import ep_step, bptt_step, relax, evaluate, get_batch
from eig_control import lead_rho


def _cos(ge, gb, params):
    """Cosine similarity between two gradient mappings, restricted to `params`.

    Args:
        ge: mapping looked up as ``ge.get(id(p))`` -> gradient tensor (the EP side).
        gb: mapping looked up as ``gb.get(id(p))`` -> gradient tensor (the BPTT side).
        params: iterable of parameter objects; the original author's note is that these
            are the shared block params (where EP != BPTT).

    Returns:
        A Python float. Parameters missing from either mapping are skipped; if nothing
        overlaps (or a norm is zero) the 1e-20 in the denominator makes the result 0.0
        instead of raising ZeroDivisionError.
    """
    dot = 0.0
    norm_e_sq = 0.0
    norm_b_sq = 0.0
    for p in params:
        g_e = ge.get(id(p))
        g_b = gb.get(id(p))
        if g_e is None or g_b is None:
            continue
        dot += float((g_e * g_b).sum())
        norm_e_sq += float((g_e * g_e).sum())
        norm_b_sq += float((g_b * g_b).sum())
    return dot / (math.sqrt(norm_e_sq * norm_b_sq) + 1e-20)


def cos_ep_bptt(blk, idx, y, T1, T2, eps, beta, holo=0, hr=0.02, t1max=0, res_est=1e-4,
                t2sel=0, bsub=6):
    """cos(EP grad, exact BPTT grad) on a SMALL sub-batch.

    The exact-BPTT unroll graph is memory-heavy at C512/T1=150, so `idx` and `y` are sliced
    to their first `bsub` rows before either gradient is computed. Both gradients are
    computed jacreg-free (``jacreg=0.0``) for a clean comparison; the EP call additionally
    passes ``corr_every=1, res_gate=0.0, resreg=0.0``.

    Returns:
        ``(cos, res)`` where ``cos`` is `_cos` of the two gradient mappings over
        ``blk.block`` and ``res`` is the second value returned by `ep_step`.
    """
    idx = idx[:bsub]
    y = y[:bsub]
    grads_ep, res = ep_step(
        blk, idx, y, T1, T2, eps, beta,
        jacreg=0.0, holo=holo, hr=hr, t1max=t1max, res_est=res_est, t2sel=t2sel,
        corr_every=1, res_gate=0.0, resreg=0.0,
    )
    grads_bptt = bptt_step(blk, idx, y, T1, eps, jacreg=0.0)
    return _cos(grads_ep, grads_bptt, blk.block), res


def _upper_median(values):
    """Return the element at index ``len(values) // 2`` of the sorted values.

    For an odd count this is the median; for an even count it is the UPPER of the two
    middle elements (no averaging). Raises IndexError on an empty sequence.
    """
    return sorted(values)[len(values) // 2]


def fingerprint(blk, T1, T2, eps, beta, holo=0, hr=0.02, t1max=0, res_est=1e-4, t2sel=0,
                nb=4, B=6):
    """Median (res, cos-to-BPTT, TRUE leading map-eigenvalue rho & Re mu) over nb small
    batches + val CE.

    B kept small: the exact-BPTT reference gradient unrolls T1 steps and is memory-heavy at
    C512. `B` is forwarded to `cos_ep_bptt` as its sub-batch size, so res/cos are measured on
    the same B rows as rho/mu (previously the helper's default of 6 rows silently capped
    res/cos whenever B > 6; results for B <= 6 are unchanged).

    (The numerical-abscissa column was retracted 2026-07-03 — omega is ~10 above the true
    abscissa on this non-normal operator and anti-correlates with stability; see eig_control
    docstring / eig_recheck.)

    Args:
        nb: number of training batches drawn via ``get_batch('train', B, blk.T)``; must be
            >= 1 (an empty sample raises IndexError when the median is taken).
        B: rows per batch.
        Remaining arguments are passed through to `cos_ep_bptt` / `relax` / `evaluate`.

    Returns:
        dict with keys ``res``, ``cos``, ``rho``, ``mu_re`` (each the upper-middle element
        of the sorted per-batch values — for even `nb` this is not an averaged median) and
        ``val`` (the result of ``evaluate(blk, T1, eps)``).
    """
    cache = {}  # one dict shared by every lead_rho call in this fingerprint
    res_l, cos_l, rho_l, mu_l = [], [], [], []
    for _ in range(nb):
        idx, y = get_batch('train', B, blk.T)
        c, r = cos_ep_bptt(blk, idx, y, T1, T2, eps, beta, holo, hr, t1max, res_est, t2sel,
                           bsub=B)
        xin = blk.embed(idx).detach()
        zs = relax(blk, xin.clone(), xin, T1, eps)
        # one-shot: more iters than the warm-started training path
        _, rho, mu = lead_rho(blk, zs, eps, blk.c, cache, iters=25)
        res_l.append(r)
        cos_l.append(c)
        rho_l.append(rho)
        mu_l.append(mu)
    return dict(
        res=_upper_median(res_l),
        cos=_upper_median(cos_l),
        rho=_upper_median(rho_l),
        mu_re=_upper_median(mu_l),
        val=evaluate(blk, T1, eps),
    )