diff options
Diffstat (limited to 'ep_run/diag_cos.py')
| -rw-r--r-- | ep_run/diag_cos.py | 45 |
1 files changed, 45 insertions, 0 deletions
# Source: ep_run/diag_cos.py (new file in the diff, blob 37e7257).
"""#1 — from-scratch-plateau diagnostic: cos(EP gradient, exact BPTT gradient) over training,
plus an operator FINGERPRINT for comparing checkpoints.

Hypothesis (from the resreg_probe): the scratch run spends its formative high-lr phase at
free-phase residual ~1e-2, where the EP estimate is only ~0.72-aligned with the exact BPTT
gradient, so it descends on a mediocre gradient and plateaus above BPTT's floor; a warm start
from a conditioned operator (res~1e-4, cos~0.98) skips that phase.

The helpers here return the quantities needed to log (step, resT1, cos, val), so scratch vs warm
trajectories can be laid side by side, and to fingerprint any checkpoint (res, cos, numerical
abscissa, val) so we can see WHAT distinguishes s2000 from other 2000-step checkpoints
(conditioning? alignment? abscissa?).
"""
import math

import torch

from lt_ep_train import ep_step, bptt_step, relax, evaluate, get_batch
from eig_control import num_abscissa


def _cos(ge, gb, params):
    """Return the cosine similarity of two gradient sets over the parameters they share.

    Args:
        ge: Mapping from ``id(param)`` to a gradient (the EP gradients when called from
            ``cos_ep_bptt``). Only ``.get`` is used on it.
        gb: Mapping with the same keying (the BPTT gradients when called from ``cos_ep_bptt``).
        params: Iterable of the parameter objects to compare over.

    Returns:
        ``dot / (sqrt(|ge|^2 * |gb|^2) + 1e-20)`` as a Python float. Parameters missing from
        either mapping are skipped; if none are shared the result is 0.0.
    """
    dot = ne = nb = 0.0
    for p in params:
        a, b = ge.get(id(p)), gb.get(id(p))
        if a is None or b is None:
            # Only parameters present in BOTH gradient sets contribute.
            continue
        dot += float((a * b).sum())
        ne += float((a * a).sum())
        nb += float((b * b).sum())
    # The 1e-20 keeps the division finite when either gradient set is all-zero or empty.
    return dot / (math.sqrt(ne * nb) + 1e-20)


def cos_ep_bptt(blk, idx, y, T1, T2, eps, beta, holo=0, hr=0.02, t1max=0, res_est=1e-4,
                t2sel=0, bsub=6):
    """Return ``(cos(EP grad, exact BPTT grad), res)`` on a SMALL sub-batch.

    The exact-BPTT unroll graph is memory-heavy at C512/T1=150 (per the original author), so
    ``idx`` and ``y`` are sliced to their first ``bsub`` rows before either gradient is computed.
    Both gradients are computed jacreg-free (``jacreg=0.0``) for a clean comparison; the EP call
    additionally uses ``corr_every=1``, ``res_gate=0.0`` and ``resreg=0.0``.

    Args:
        blk: Model object; its ``block`` attribute is iterated as the parameters to compare.
        idx, y: Input batch and targets; only the first ``bsub`` rows are used.
        T1, T2, eps, beta, holo, hr, t1max, res_est, t2sel: Forwarded unchanged to ``ep_step``
            (``T1`` and ``eps`` also to ``bptt_step``).
        bsub: Maximum number of batch rows to keep.

    Returns:
        Tuple ``(cos, res)`` where ``cos`` is the float from ``_cos`` and ``res`` is the second
        value returned by ``ep_step``.
    """
    idx, y = idx[:bsub], y[:bsub]
    ge, res = ep_step(blk, idx, y, T1, T2, eps, beta, jacreg=0.0, holo=holo, hr=hr, t1max=t1max,
                      res_est=res_est, t2sel=t2sel, corr_every=1, res_gate=0.0, resreg=0.0)
    gb = bptt_step(blk, idx, y, T1, eps, jacreg=0.0)
    return _cos(ge, gb, blk.block), res


def _upper_median(values):
    """Return the element at index ``len(values) // 2`` of the sorted values.

    For an odd count this is the median; for an even count it is the UPPER of the two middle
    elements (no averaging), which keeps the result one of the observed samples.
    """
    return sorted(values)[len(values) // 2]


def fingerprint(blk, T1, T2, eps, beta, holo=0, hr=0.02, t1max=0, res_est=1e-4, t2sel=0, nb=4,
                B=6):
    """Median (res, cos-to-BPTT, numerical abscissa) over ``nb`` small batches + val.

    This is the operator's 4-D fingerprint. ``B`` is kept small because the exact-BPTT reference
    gradient unrolls T1 steps and is memory-heavy at C512 (per the original author).

    Args:
        blk: Model object; uses ``blk.T``, ``blk.embed`` and ``blk.block``.
        T1, T2, eps, beta, holo, hr, t1max, res_est, t2sel: Forwarded to ``cos_ep_bptt``;
            ``T1`` and ``eps`` are also used for ``relax`` and ``evaluate``.
        nb: Number of training batches to sample; must be at least 1.
        B: Rows per sampled batch. The same ``B`` rows are used for the cosine/residual and for
            the abscissa.

    Returns:
        Dict with keys ``res``, ``cos``, ``num_abscissa`` (each the upper median over the ``nb``
        batches, see ``_upper_median``) and ``val`` (the result of ``evaluate(blk, T1, eps)``).

    Raises:
        ValueError: If ``nb`` is less than 1 (there would be nothing to take a median of).
    """
    if nb < 1:
        raise ValueError(f"fingerprint needs nb >= 1 batches, got nb={nb}")

    cache = {}  # shared across batches and handed to num_abscissa on every call
    res_l, cos_l, om_l = [], [], []
    for _ in range(nb):
        idx, y = get_batch('train', B, blk.T)
        # bsub=B: without it cos_ep_bptt would silently cap the cosine/residual at its default
        # 6 rows whenever B > 6, while the abscissa below would still see all B rows.
        c, r = cos_ep_bptt(blk, idx, y, T1, T2, eps, beta, holo, hr, t1max, res_est, t2sel,
                           bsub=B)
        xin = blk.embed(idx).detach()
        zs = relax(blk, xin.clone(), xin, T1, eps)
        _, om = num_abscissa(blk, zs, cache)
        res_l.append(r)
        cos_l.append(c)
        om_l.append(om)
    return dict(res=_upper_median(res_l), cos=_upper_median(cos_l),
                num_abscissa=_upper_median(om_l), val=evaluate(blk, T1, eps))
