"""#1 — from-scratch-plateau diagnostic: cos(EP gradient, exact BPTT gradient) over training, plus an operator FINGERPRINT for comparing checkpoints. Hypothesis (from the resreg_probe): the scratch run spends its formative high-lr phase at free-phase residual ~1e-2, where the EP estimate is only ~0.72-aligned with the exact BPTT gradient -> it descends on a mediocre gradient and plateaus above BPTT's floor; a warm start from a conditioned operator (res~1e-4, cos~0.98) skips that phase. This logs (step, resT1, cos, val) so scratch vs warm trajectories can be laid side by side, and fingerprints any checkpoint (res, cos, numerical abscissa, val) so we can see WHAT distinguishes s2000 from other 2000-step checkpoints (conditioning? alignment? abscissa?).""" import math, torch from lt_ep_train import ep_step, bptt_step, relax, evaluate, get_batch from eig_control import num_abscissa def _cos(ge, gb, params): # cosine over the shared block params (where EP != BPTT) dot = ne = nb = 0.0 for p in params: a, b = ge.get(id(p)), gb.get(id(p)) if a is None or b is None: continue dot += float((a * b).sum()); ne += float((a * a).sum()); nb += float((b * b).sum()) return dot / (math.sqrt(ne * nb) + 1e-20) def cos_ep_bptt(blk, idx, y, T1, T2, eps, beta, holo=0, hr=0.02, t1max=0, res_est=1e-4, t2sel=0, bsub=6): """cos(EP grad, exact BPTT grad) on a SMALL sub-batch (the exact-BPTT unroll graph is memory-heavy at C512/T1=150, so we slice to `bsub` rows). Both computed jacreg-free for a clean comparison.""" idx, y = idx[:bsub], y[:bsub] ge, res = ep_step(blk, idx, y, T1, T2, eps, beta, jacreg=0.0, holo=holo, hr=hr, t1max=t1max, res_est=res_est, t2sel=t2sel, corr_every=1, res_gate=0.0, resreg=0.0) gb = bptt_step(blk, idx, y, T1, eps, jacreg=0.0) return _cos(ge, gb, blk.block), res def fingerprint(blk, T1, T2, eps, beta, holo=0, hr=0.02, t1max=0, res_est=1e-4, t2sel=0, nb=4, B=6): """Median (res, cos-to-BPTT, numerical abscissa) over nb small batches + val CE — the operator's 4-D fingerprint. B kept small: the exact-BPTT reference gradient unrolls T1 steps and is memory-heavy at C512.""" cache = {}; res_l, cos_l, om_l = [], [], [] for _ in range(nb): idx, y = get_batch('train', B, blk.T) c, r = cos_ep_bptt(blk, idx, y, T1, T2, eps, beta, holo, hr, t1max, res_est, t2sel) xin = blk.embed(idx).detach(); zs = relax(blk, xin.clone(), xin, T1, eps) _, om = num_abscissa(blk, zs, cache) res_l.append(r); cos_l.append(c); om_l.append(om) md = lambda a: sorted(a)[len(a) // 2] return dict(res=md(res_l), cos=md(cos_l), num_abscissa=md(om_l), val=evaluate(blk, T1, eps))