summaryrefslogtreecommitdiff
path: root/ep_run/resreg_probe.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/resreg_probe.py')
-rw-r--r--ep_run/resreg_probe.py62
1 files changed, 62 insertions, 0 deletions
diff --git a/ep_run/resreg_probe.py b/ep_run/resreg_probe.py
new file mode 100644
index 0000000..8a5ba0a
--- /dev/null
+++ b/ep_run/resreg_probe.py
@@ -0,0 +1,62 @@
+"""Does resreg CONTAMINATE the gradient or add back BPTT's missing residual-defense?
+At a ckpt compute the true grad g_BPTT, the pure EP estimate g_VF, and the resreg grad g_R (at the
+training scale lam). If cos(g_VF + lam*g_R, g_BPTT) >= cos(g_VF, g_BPTT), resreg moves EP TOWARD the
+true gradient (correction). If it drops, resreg is contaminating."""
+import argparse, pickle, math, torch
+from pathlib import Path
+import lt_ep_train as L
+from lt_ep_train import EQBlock, ep_step, bptt_step, relax
+
# Command-line configuration for the probe.
# --ckpt is the checkpoint to analyse and --data the dataset directory (it must
# contain meta.pkl).  C/H/Mm/T are forwarded to EQBlock, B/T to get_batch,
# T1/eps to bptt_step and relax, and the remaining numeric flags are passed
# positionally to ep_step or used to scale the resreg term further down.
ap = argparse.ArgumentParser()
ap.add_argument('--ckpt', required=True)
ap.add_argument('--data', default='data/tinystories_bpe')
ap.add_argument('--gelu', default='erf')
ap.add_argument('--C', type=int, default=512)
ap.add_argument('--H', type=int, default=16)
ap.add_argument('--Mm', type=int, default=256)
ap.add_argument('--T', type=int, default=256)
ap.add_argument('--B', type=int, default=8)
ap.add_argument('--T1', type=int, default=150)
ap.add_argument('--T2', type=int, default=20)
ap.add_argument('--eps', type=float, default=0.1)
ap.add_argument('--beta', type=float, default=0.02)
ap.add_argument('--resreg', type=float, default=0.2)
ap.add_argument('--t1max', type=int, default=300)
ap.add_argument('--res_est', type=float, default=1e-4)
ap.add_argument('--t2sel', type=int, default=40)
ap.add_argument('--hr', type=float, default=0.02)
# The device used to be hard-coded to 'cuda'.  It is still the default whenever
# CUDA is present, but the probe can now run (slowly) on a CPU-only host.
# NOTE(review): lt_ep_train itself may still assume CUDA -- confirm before
# relying on the CPU path.
ap.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu')
cfg = ap.parse_args()
dev = cfg.device
# Point the training module at the dataset and publish the vocabulary size on
# it; both are set as module attributes of lt_ep_train before any of its
# helpers are called below.
L.DD = Path(cfg.data)
# NOTE(review): pickle executes arbitrary code while loading -- only use a
# meta.pkl that comes from a trusted source.
with open(L.DD / 'meta.pkl', 'rb') as meta_file:  # closed deterministically
    L.vocab = pickle.load(meta_file)['vocab_size']

# Fixed seed so that the model construction and the sampled batch are
# reproducible from run to run.
torch.manual_seed(0)

blk = EQBlock(cfg.C, cfg.H, cfg.Mm, cfg.T, s=1.0, c=1.0, attn_mode='thick')
# Runtime switches read by the lt_ep_train helpers.
# NOTE(review): presumably these mirror the training configuration of the
# checkpoint being probed -- confirm against the training script.
blk.qknorm = True
blk.fnoise = 0.0
blk._cstep = None
blk.navg = 1
blk.li_avg = 0
blk.track = True
blk.nbrake = 0.0
blk.gelu = cfg.gelu

# NOTE(review): torch.load unpickles the file; load trusted checkpoints only.
ck = torch.load(cfg.ckpt, map_location=dev)
# Overwrite the freshly initialised parameters in place with the stored ones.
# zip() stops at the shorter sequence, so a length mismatch is not reported.
with torch.no_grad():
    for p, w in zip(blk.allp, ck['allp']):
        p.copy_(w.to(dev))

# One training batch on which all gradients below are evaluated.
idx, y = L.get_batch('train', cfg.B, cfg.T)
+
def flat(gd, ps):
    """Concatenate per-parameter gradients into a single 1-D tensor.

    Args:
        gd: mapping from ``id(param)`` to that parameter's gradient tensor.
            A missing key or a ``None`` value means "no gradient".
        ps: iterable of parameters; it fixes the order of the output.

    Returns:
        A 1-D tensor holding, for each parameter in ``ps`` order, its
        flattened gradient, or zeros of length ``p.numel()`` when it has none.
    """
    pieces = []
    for p in ps:
        g = gd.get(id(p))
        if g is None:
            # Allocate the filler on the parameter's own device and dtype
            # rather than on a global device with the default dtype.
            pieces.append(p.new_zeros(p.numel()))
        else:
            pieces.append(g.reshape(-1))
    return torch.cat(pieces)
def cos(a, b):
    """Return the cosine similarity of two 1-D tensors as a Python float.

    The computation is carried out in float64 so that the ``1e-20`` guard in
    the denominator does not underflow to zero for half-precision inputs
    (which would turn the zero-vector case into NaN instead of 0.0).
    """
    a64 = a.double()
    b64 = b.double()
    denom = a64.norm() * b64.norm() + 1e-20
    return (a64 @ b64 / denom).item()
+
# Reference ("true") gradient from bptt_step, followed by the EP estimate from
# ep_step.  Both results are consumed by flat() as id(param) -> grad mappings.
# The call order is kept exactly as written because the helpers are opaque.
gB = bptt_step(blk, idx, y, cfg.T1, cfg.eps, 0.0)
gVF, _ = ep_step(
    blk, idx, y, cfg.T1, cfg.T2, cfg.eps, cfg.beta, 0.0, 2, cfg.hr,
    cfg.t1max, cfg.res_est, cfg.t2sel, 1, 0.0,
)

# Relax the state for T1 steps starting from the detached embedding, then take
# one further step to measure the relative residual of the state reached.
xin0 = blk.embed(idx).detach()
zT = relax(blk, xin0.clone(), xin0, cfg.T1, cfg.eps)
resT1 = (
    (relax(blk, zT, xin0, 1, cfg.eps) - zT).norm().item()
    / (zT.norm().item() + 1e-9)
)

# Residual regulariser: squared norm of eps * tforce(z), normalised by the
# squared norm of z, differentiated with respect to the block parameters.
with torch.enable_grad():
    Fz = blk.tforce(zT, xin0)
    Rr = (cfg.eps * Fz).pow(2).sum() / (zT.pow(2).sum() + 1e-9)
    grr = torch.autograd.grad(Rr, blk.block, allow_unused=True)

# Parameters that do not influence Rr come back as None; store zeros instead.
gR = {}
for p, g in zip(blk.block, grr):
    gR[id(p)] = torch.zeros_like(p) if g is None else g

# Flatten all three gradients over the same parameter list.
B = flat(gB, blk.block)
VF = flat(gVF, blk.block)
R = flat(gR, blk.block)

# The resreg weight is ramped down linearly once the residual is below 2e-2;
# lam rescales R so that |lam * R| equals `ratio` times |VF|.
ratio = cfg.resreg * min(1.0, resT1 / 2e-2)
lam = ratio * VF.norm() / (R.norm() + 1e-20)
TOT = VF + lam * R
# Report.  Each cosine is computed once and reused, and the realized ratio is
# guarded with the same 1e-20 term as every other ratio in this script so that
# a zero EP gradient cannot produce a division by zero.
vf_norm = VF.norm()
reg_norm = lam * R.norm()
cos_vf = cos(VF, B)
cos_tot = cos(TOT, B)
print(f"# ckpt step {ck.get('step')} best {ck.get('best')} resT1={resT1:.2e} ratio={ratio:.3f}")
print(f"|VF|={vf_norm:.2e} |lam*R|={reg_norm.item():.2e} realized ratio={(reg_norm / (vf_norm + 1e-20)).item():.3f}")
print(f"cos(VF, BPTT) = {cos_vf:+.4f} <- EP estimate, NO resreg")
print(f"cos(VF+lam*R, BPTT) = {cos_tot:+.4f} <- WITH resreg (training grad)")
d = cos_tot - cos_vf
# A drop of up to 1e-3 is still reported as "adds alignment" (noise tolerance).
if d >= -1e-3:
    verdict = 'resreg ADDS alignment (correction, not contamination)'
else:
    verdict = 'resreg HURTS alignment (CONTAMINATION)'
print(f" delta = {d:+.4f} => {verdict}")
print(f"cos(R, BPTT) = {cos(R, B):+.4f} <- resreg dir vs true grad (aligned? >0 means resreg points toward BPTT)")
print(f"cos(R, VF) = {cos(R, VF):+.4f}")
# --- M = g_BPTT - g_EP : the finite-horizon stabilizer BPTT HAS and EP LACKS ---
M = B - VF
m_norm = M.norm()
print("--- M = g_BPTT - g_EP : what EP is missing vs BPTT ---")
print(f"|M|/|BPTT| = {(m_norm / (B.norm() + 1e-20)).item():.3f} |M|/|EP| = {(m_norm / (vf_norm + 1e-20)).item():.3f}")
print(f"cos(M, resreg R) = {cos(M, R):+.4f} <- does resreg point where the MISSING term is? (>0 = resreg's intent correct)")
print(f"cos(M, BPTT) = {cos(M, B):+.4f} cos(M, EP) = {cos(M, VF):+.4f}")