summaryrefslogtreecommitdiff
path: root/ep_run/diag_cos.py
blob: 37e725730a07af071313015de6bc436159572eec (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""#1 — from-scratch-plateau diagnostic: cos(EP gradient, exact BPTT gradient) over training, plus an
operator FINGERPRINT for comparing checkpoints.

Hypothesis (from the resreg_probe): the scratch run spends its formative high-lr phase at free-phase
residual ~1e-2, where the EP estimate is only ~0.72-aligned with the exact BPTT gradient -> it descends
on a mediocre gradient and plateaus above BPTT's floor; a warm start from a conditioned operator
(res~1e-4, cos~0.98) skips that phase. This logs (step, resT1, cos, val) so scratch vs warm trajectories
can be laid side by side, and fingerprints any checkpoint (res, cos, numerical abscissa, val) so we can
see WHAT distinguishes s2000 from other 2000-step checkpoints (conditioning? alignment? abscissa?)."""
import math, torch
from lt_ep_train import ep_step, bptt_step, relax, evaluate, get_batch
from eig_control import num_abscissa


def _cos(ge, gb, params):                                # cosine over the shared block params (where EP != BPTT)
    dot = ne = nb = 0.0
    for p in params:
        a, b = ge.get(id(p)), gb.get(id(p))
        if a is None or b is None: continue
        dot += float((a * b).sum()); ne += float((a * a).sum()); nb += float((b * b).sum())
    return dot / (math.sqrt(ne * nb) + 1e-20)


def cos_ep_bptt(blk, idx, y, T1, T2, eps, beta, holo=0, hr=0.02, t1max=0, res_est=1e-4, t2sel=0, bsub=6):
    """Cosine between the EP gradient and the exact BPTT gradient, plus the `res` value from ep_step.

    Only the first `bsub` rows of (idx, y) are used: the exact-BPTT unroll graph is memory-heavy
    at C512/T1=150. Both gradients are requested with jacreg=0.0 so the comparison is jacreg-free.
    The cosine is taken over `blk.block` (via _cos).

    Returns (cos, res). `res` is whatever ep_step returns as its second value — presumably the
    free-phase residual referred to in the module docstring; confirm against lt_ep_train.ep_step.
    """
    sub_idx = idx[:bsub]
    sub_y = y[:bsub]
    # EP settings: regularizers/gates off, correction every step.
    ep_kwargs = dict(jacreg=0.0, holo=holo, hr=hr, t1max=t1max, res_est=res_est,
                     t2sel=t2sel, corr_every=1, res_gate=0.0, resreg=0.0)
    g_ep, res = ep_step(blk, sub_idx, sub_y, T1, T2, eps, beta, **ep_kwargs)
    g_bptt = bptt_step(blk, sub_idx, sub_y, T1, eps, jacreg=0.0)
    return _cos(g_ep, g_bptt, blk.block), res


def fingerprint(blk, T1, T2, eps, beta, holo=0, hr=0.02, t1max=0, res_est=1e-4, t2sel=0, nb=4, B=6):
    """Median (res, cos-to-BPTT, numerical abscissa) over nb small batches + val CE — the operator's 4-D fingerprint.

    Each of the `nb` rounds does the following on one fresh 'train' batch of B rows (from get_batch):
      * cos_ep_bptt on that batch, giving the EP-vs-BPTT cosine and the `res` value from ep_step.
        bsub=B is forwarded, so B really is the exact-BPTT batch size. Previously the cosine was
        silently capped at cos_ep_bptt's default of 6 rows.
      * relax the embedded batch for T1 steps and take num_abscissa of the relaxed state. A single
        `cache` dict is reused across rounds.
    B kept small: the exact-BPTT reference gradient unrolls T1 steps and is memory-heavy at C512.

    Returns dict(res, cos, num_abscissa, val):
      * res, cos, num_abscissa are the median over the nb rounds, i.e. the mean of the two middle
        values when nb is even. The old `sorted(a)[len(a)//2]` returned the upper-middle value,
        which is biased upward for the default nb=4.
      * val is evaluate(blk, T1, eps).

    Raises ValueError if nb < 1 (no samples to take a median of).
    """
    import statistics                                    # local import: only this diagnostic needs it
    if nb < 1:
        raise ValueError(f"fingerprint needs nb >= 1 batches, got {nb}")
    cache = {}                                           # reused by num_abscissa across all rounds
    res_l, cos_l, om_l = [], [], []
    for _ in range(nb):
        idx, y = get_batch('train', B, blk.T)
        c, r = cos_ep_bptt(blk, idx, y, T1, T2, eps, beta, holo=holo, hr=hr, t1max=t1max,
                           res_est=res_est, t2sel=t2sel, bsub=B)
        xin = blk.embed(idx).detach()
        zs = relax(blk, xin.clone(), xin, T1, eps)
        _, om = num_abscissa(blk, zs, cache)
        res_l.append(r)
        cos_l.append(c)
        om_l.append(om)
    md = statistics.median
    return dict(res=md(res_l), cos=md(cos_l), num_abscissa=md(om_l), val=evaluate(blk, T1, eps))