summaryrefslogtreecommitdiff
path: root/ep_run/grad_quality.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/grad_quality.py')
-rw-r--r--ep_run/grad_quality.py64
1 files changed, 64 insertions, 0 deletions
diff --git a/ep_run/grad_quality.py b/ep_run/grad_quality.py
new file mode 100644
index 0000000..095d764
--- /dev/null
+++ b/ep_run/grad_quality.py
@@ -0,0 +1,64 @@
+"""Gradient-quality probe at a realistic operating point. Pretrain the thick block with BPTT
+for 300 steps (lands near where good optima live, res ~1e-2-1e-3, NO contraction penalty),
+then measure cosine(EP gradient, long-horizon-BPTT reference) per parameter group as a
+function of free-phase length T1 (= residual level) and nudge length T2.
+If cosine is high at res ~1e-2 -> there is NO estimator wall at the operating points that
+matter; the EP-vs-BPTT gap is speed + regularization tax, not a convergence mechanism."""
+import torch
+from lt_ep_train import EQBlock, get_batch, ep_step, bptt_step
+dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+torch.manual_seed(0)
+B, T, C, H = 16, 64, 128, 4
+blk = EQBlock(C, H, 256, T, attn_mode='thick') # c=1.0 default = thick-BPTT baseline setting
+import os
+if os.path.exists('/tmp/lt_ep/probe_w.pt'):
+ for p, w in zip(blk.allp, torch.load('/tmp/lt_ep/probe_w.pt')):
+ with torch.no_grad():
+ p.copy_(w.to(dev))
+ print("loaded cached pretrained weights", flush=True)
+else:
+ opt = torch.optim.AdamW(blk.allp, lr=1e-3, weight_decay=1e-4)
+ for step in range(300):
+ idx, y = get_batch('train', B, T)
+ g = bptt_step(blk, idx, y, 150, 0.1)
+ opt.zero_grad(set_to_none=True)
+ for p in blk.allp:
+ p.grad = g.get(id(p))
+ torch.nn.utils.clip_grad_norm_(blk.allp, 5.0)
+ opt.step()
+ torch.save([p.detach().cpu() for p in blk.allp], '/tmp/lt_ep/probe_w.pt')
+print("pretrained 300 BPTT steps (thick, c=1) -- measuring at this operating point", flush=True)
+
+groups = {'all': blk.block,
+ 'attn': [blk.WQ, blk.WK, blk.WV, blk.WO],
+ 'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb],
+ 'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b],
+ 'emb': [blk.tok, blk.pos],
+ 'mem': [blk.Wm]}
+
+
+def cos(ga, gb, ps):
+ keep = [p for p in ps if ga.get(id(p)) is not None and gb.get(id(p)) is not None]
+ if not keep:
+ na = sum(1 for p in ps if ga.get(id(p)) is None)
+ nb = sum(1 for p in ps if gb.get(id(p)) is None)
+ print(f"\n[debug] empty keep: |ps|={len(ps)} ga_None={na} gb_None={nb} "
+ f"ga_keys={len(ga)} gb_keys={len(gb)}", flush=True)
+ return float('nan')
+ va = torch.cat([ga[id(p)].reshape(-1) for p in keep])
+ vb = torch.cat([gb[id(p)].reshape(-1) for p in keep])
+ return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item()
+
+
+hdr = f"{'config':>16} {'res':>9} " + " ".join(f"{k:>6}" for k in groups)
+for bi in range(3):
+ idx, y = get_batch('train', B, T)
+ ref = bptt_step(blk, idx, y, 400, 0.1) # exact reference: long-horizon BPTT
+ print(("\n" if bi else "") + hdr, flush=True)
+ g150 = bptt_step(blk, idx, y, 150, 0.1)
+ print(f"{'bptt T1=150':>16} {'--':>9} " + " ".join(f"{cos(g150, ref, ps):>6.3f}" for ps in groups.values()), flush=True)
+ for T1 in (50, 150, 400):
+ gep, res = ep_step(blk, idx, y, T1, 20, 0.1, 0.02, 0.0)
+ print(f"{f'ep T1={T1:<3} T2=20':>16} {res:>9.1e} " + " ".join(f"{cos(gep, ref, ps):>6.3f}" for ps in groups.values()), flush=True)
+ gep, res = ep_step(blk, idx, y, 150, 60, 0.1, 0.02, 0.0)
+ print(f"{'ep T1=150 T2=60':>16} {res:>9.1e} " + " ".join(f"{cos(gep, ref, ps):>6.3f}" for ps in groups.values()), flush=True)