summary | refs | log | tree | commit | diff
path: root/ep_run/diag_cos.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/diag_cos.py')
-rw-r--r-- ep_run/diag_cos.py 45
1 files changed, 45 insertions, 0 deletions
diff --git a/ep_run/diag_cos.py b/ep_run/diag_cos.py
new file mode 100644
index 0000000..37e7257
--- /dev/null
+++ b/ep_run/diag_cos.py
@@ -0,0 +1,45 @@
+"""#1 — from-scratch-plateau diagnostic: cos(EP gradient, exact BPTT gradient) over training, plus an
+operator FINGERPRINT for comparing checkpoints.
+
+Hypothesis (from the resreg_probe): the scratch run spends its formative high-lr phase at free-phase
+residual ~1e-2, where the EP estimate is only ~0.72-aligned with the exact BPTT gradient -> it descends
+on a mediocre gradient and plateaus above BPTT's floor; a warm start from a conditioned operator
+(res~1e-4, cos~0.98) skips that phase. This logs (step, resT1, cos, val) so scratch vs warm trajectories
+can be laid side by side, and fingerprints any checkpoint (res, cos, numerical abscissa, val) so we can
+see WHAT distinguishes s2000 from other 2000-step checkpoints (conditioning? alignment? abscissa?)."""
+import math, torch
+from lt_ep_train import ep_step, bptt_step, relax, evaluate, get_batch
+from eig_control import num_abscissa
+
+
+def _cos(ge, gb, params):
+    """Return the cosine similarity of two gradient sets over the shared block params (where EP != BPTT).
+
+    ``ge`` and ``gb`` are mappings keyed by ``id(p)``. Only parameters that have an entry in BOTH
+    mappings contribute; the rest are skipped. Dot product and squared norms are summed over every
+    element of every contributing gradient and accumulated as Python floats.
+
+    Args:
+        ge: mapping ``id(param) -> gradient`` (first gradient set).
+        gb: mapping ``id(param) -> gradient`` (second gradient set).
+        params: iterable of the parameter objects whose ``id`` is used as the lookup key.
+
+    Returns:
+        ``<ge, gb> / (|ge| * |gb| + 1e-20)`` as a float. The epsilon keeps the division finite, so the
+        result is ``0.0`` when no parameter is shared or when either side is all zeros.
+    """
+    dot = ne = nb = 0.0
+    for p in params:
+        a, b = ge.get(id(p)), gb.get(id(p))
+        if a is None or b is None:
+            continue  # present on one side only: nothing to compare
+        dot += float((a * b).sum())
+        ne += float((a * a).sum())
+        nb += float((b * b).sum())
+    # Take the two square roots separately: the product ne * nb can overflow to inf for large-norm
+    # gradients, and dot / inf would silently report an alignment of 0.
+    return dot / (math.sqrt(ne) * math.sqrt(nb) + 1e-20)
+
+
+def cos_ep_bptt(blk, idx, y, T1, T2, eps, beta, holo=0, hr=0.02, t1max=0, res_est=1e-4, t2sel=0, bsub=6):
+    """Measure how well the EP gradient lines up with the exact BPTT gradient on a small sub-batch.
+
+    Only the first ``bsub`` rows of ``idx`` / ``y`` are used (the exact-BPTT unroll graph is
+    memory-heavy at C512/T1=150). Both gradients are requested with ``jacreg=0.0``, and the EP step
+    additionally with ``corr_every=1``, ``res_gate=0.0`` and ``resreg=0.0``, for a clean comparison.
+
+    Args:
+        blk: model object; ``blk.block`` supplies the parameters the cosine is taken over.
+        idx, y: input batch and targets; sliced to ``bsub`` rows before use.
+        T1, T2, eps, beta: forwarded to ``ep_step`` (``T1`` and ``eps`` also to ``bptt_step``).
+        holo, hr, t1max, res_est, t2sel: forwarded unchanged to ``ep_step``.
+        bsub: maximum number of rows kept from the batch.
+
+    Returns:
+        ``(cos, res)``: the cosine from ``_cos`` and the second value returned by ``ep_step``
+        (presumably the free-phase residual; confirm in ``lt_ep_train``).
+    """
+    sub_idx = idx[:bsub]
+    sub_y = y[:bsub]
+    ep_options = dict(
+        jacreg=0.0, holo=holo, hr=hr, t1max=t1max, res_est=res_est, t2sel=t2sel,
+        corr_every=1, res_gate=0.0, resreg=0.0,
+    )
+    ep_grads, residual = ep_step(blk, sub_idx, sub_y, T1, T2, eps, beta, **ep_options)
+    bptt_grads = bptt_step(blk, sub_idx, sub_y, T1, eps, jacreg=0.0)
+    alignment = _cos(ep_grads, bptt_grads, blk.block)
+    return alignment, residual
+
+
+def fingerprint(blk, T1, T2, eps, beta, holo=0, hr=0.02, t1max=0, res_est=1e-4, t2sel=0, nb=4, B=6):
+    """Return the operator's 4-D fingerprint: median (res, cos-to-BPTT, numerical abscissa) + val CE.
+
+    Draws ``nb`` batches with ``get_batch('train', B, blk.T)``. For each batch it records the
+    ``(cos, res)`` pair from ``cos_ep_bptt`` and the second value returned by ``num_abscissa``,
+    evaluated at the state ``relax`` produces from the detached embedding of the batch.
+
+    Args:
+        blk: model object; must provide ``embed``, ``T`` and ``block``.
+        T1, T2, eps, beta, holo, hr, t1max, res_est, t2sel: forwarded to ``cos_ep_bptt``
+            (``T1`` and ``eps`` are also used for ``relax`` and ``evaluate``).
+        nb: number of batches to sample; must be >= 1 (an empty sample has no median).
+        B: rows per batch. Keep it small: the exact-BPTT reference gradient unrolls T1 steps and is
+            memory-heavy at C512.
+
+    Returns:
+        dict with keys ``res``, ``cos``, ``num_abscissa`` and ``val``. The first three are the
+        upper-middle element of the sorted per-batch values (the high median when ``nb`` is even);
+        ``val`` is ``evaluate(blk, T1, eps)``.
+    """
+    def upper_median(values):
+        return sorted(values)[len(values) // 2]
+
+    cache = {}  # same dict handed to every num_abscissa call (presumably a memo; confirm in eig_control)
+    res_l, cos_l, om_l = [], [], []
+    for _ in range(nb):
+        idx, y = get_batch('train', B, blk.T)
+        # bsub=B: otherwise cos_ep_bptt applies its own default cap of 6 rows, so a caller-chosen
+        # B > 6 would be silently ignored for res/cos while the abscissa below still saw all B rows.
+        c, r = cos_ep_bptt(blk, idx, y, T1, T2, eps, beta, holo, hr, t1max, res_est, t2sel, bsub=B)
+        xin = blk.embed(idx).detach()
+        zs = relax(blk, xin.clone(), xin, T1, eps)
+        _, om = num_abscissa(blk, zs, cache)
+        res_l.append(r)
+        cos_l.append(c)
+        om_l.append(om)
+    return dict(
+        res=upper_median(res_l),
+        cos=upper_median(cos_l),
+        num_abscissa=upper_median(om_l),
+        val=evaluate(blk, T1, eps),
+    )