"""Gradient-quality probe at a realistic operating point.

Pretrain the thick block with BPTT for 300 steps (intended to land near where
good optima live: res ~1e-2 to 1e-3, NO contraction penalty), then measure
cosine(EP gradient, long-horizon-BPTT reference) per parameter group as a
function of the free-phase length T1 (which sets the residual level) and the
nudge length T2.

If the cosine is high at res ~1e-2, there is NO estimator wall at the operating
points that matter: the EP-vs-BPTT gap is a speed + regularization tax, not a
convergence mechanism.

Pretrained weights are cached at CACHE_PATH. The cache is validated only by
tensor count and shapes, so delete the file after changing the pretraining
hyperparameters.
"""
import os

import torch

from lt_ep_train import EQBlock, get_batch, ep_step, bptt_step

dev = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)
B, T, C, H = 16, 64, 128, 4
blk = EQBlock(C, H, 256, T, attn_mode='thick')  # c=1.0 default = thick-BPTT baseline setting

CACHE_PATH = '/tmp/lt_ep/probe_w.pt'
PRETRAIN_STEPS = 300
# Torch is re-seeded with this right before measurement, so the probe batches
# do not depend on whether pretraining ran in this process (pretraining calls
# get_batch PRETRAIN_STEPS times). NOTE(review): assumes get_batch draws from
# torch's global RNG -- TODO confirm; seed numpy/random too if it uses those.
MEASURE_SEED = 1


def _load_cached_weights(path, params):
    """Copy cached weights from `path` into `params`; return True on success.

    Returns False without touching `params` when the file is absent, or when
    its tensor count or shapes do not match `params`. A cache left over from a
    different configuration is therefore never partially loaded (a bare zip()
    would silently truncate).
    """
    if not os.path.exists(path):
        return False
    # NOTE(review): `path` is under /tmp, which may be writable by other users,
    # and torch < 2.6 unpickles arbitrary objects in torch.load by default.
    # Consider torch.load(path, weights_only=True) if the installed torch
    # supports it.
    saved = torch.load(path)
    if len(saved) != len(params) or any(w.shape != p.shape for p, w in zip(params, saved)):
        print(f"cached weights at {path} do not match this EQBlock -- re-pretraining", flush=True)
        return False
    with torch.no_grad():
        for p, w in zip(params, saved):
            p.copy_(w.to(dev))
    return True


def _pretrain(params, steps):
    """Run `steps` AdamW updates on `params` using bptt_step(..., 150, 0.1) gradients."""
    opt = torch.optim.AdamW(params, lr=1e-3, weight_decay=1e-4)
    for _ in range(steps):
        idx, y = get_batch('train', B, T)
        g = bptt_step(blk, idx, y, 150, 0.1)  # gradients keyed by id(param); may omit params
        opt.zero_grad(set_to_none=True)
        for p in params:
            p.grad = g.get(id(p))
        torch.nn.utils.clip_grad_norm_(params, 5.0)
        opt.step()


# Materialize once so len()/repeated iteration work whatever type blk.allp is.
_params = list(blk.allp)
if _load_cached_weights(CACHE_PATH, _params):
    print("loaded cached pretrained weights", flush=True)
else:
    _pretrain(_params, PRETRAIN_STEPS)
    # Create the cache directory first; otherwise torch.save raises only after
    # all the pretraining work is done.
    os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)
    torch.save([p.detach().cpu() for p in _params], CACHE_PATH)
    print(f"pretrained {PRETRAIN_STEPS} BPTT steps (thick, c=1) -- measuring at this operating point",
          flush=True)

groups = {'all': blk.block,
          'attn': [blk.WQ, blk.WK, blk.WV, blk.WO],
          'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb],
          'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b],
          'emb': [blk.tok, blk.pos],
          'mem': [blk.Wm]}


def cos(ga, gb, ps):
    """Cosine similarity between two gradient dicts, restricted to params `ps`.

    `ga` and `gb` map id(param) -> gradient tensor. Only params in `ps` that
    have a non-None gradient in BOTH dicts contribute; their gradients are
    flattened and concatenated before the cosine is taken. Returns a Python
    float. If no param qualifies, prints a debug line and returns nan.
    """
    keep = [p for p in ps if ga.get(id(p)) is not None and gb.get(id(p)) is not None]
    if not keep:
        na = sum(1 for p in ps if ga.get(id(p)) is None)
        nb = sum(1 for p in ps if gb.get(id(p)) is None)
        print(f"\n[debug] empty keep: |ps|={len(ps)} ga_None={na} gb_None={nb} "
              f"ga_keys={len(ga)} gb_keys={len(gb)}", flush=True)
        return float('nan')
    va = torch.cat([ga[id(p)].reshape(-1) for p in keep])
    vb = torch.cat([gb[id(p)].reshape(-1) for p in keep])
    # 1e-12 guards against division by zero when a gradient is all zeros.
    return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item()


def _row(label, res_txt, g, ref):
    """Print one table row: label, residual column, then cos(g, ref) for each group."""
    cells = " ".join(f"{cos(g, ref, ps):>6.3f}" for ps in groups.values())
    print(f"{label:>16} {res_txt:>9} " + cells, flush=True)


torch.manual_seed(MEASURE_SEED)
hdr = f"{'config':>16} {'res':>9} " + " ".join(f"{k:>6}" for k in groups)
for bi in range(3):
    idx, y = get_batch('train', B, T)
    ref = bptt_step(blk, idx, y, 400, 0.1)  # reference gradient: long-horizon (400) BPTT
    print(("\n" if bi else "") + hdr, flush=True)
    g150 = bptt_step(blk, idx, y, 150, 0.1)
    _row('bptt T1=150', '--', g150, ref)
    for T1 in (50, 150, 400):
        gep, res = ep_step(blk, idx, y, T1, 20, 0.1, 0.02, 0.0)
        _row(f'ep T1={T1:<3} T2=20', f'{res:.1e}', gep, ref)
    gep, res = ep_step(blk, idx, y, 150, 60, 0.1, 0.02, 0.0)
    _row('ep T1=150 T2=60', f'{res:.1e}', gep, ref)