"""Decisive 'why is EP far at S1' diagnostic: separate estimator BIAS from VARIANCE at the converged v4b checkpoint. Over N batches compute EP grad, BPTT-400 grad, BPTT-150 control. mean-cos = mean_b cos(g_EP^b, g_BPTT^b) -> per-step quality (noisy) cos-means = cos(sum_b g_EP, sum_b g_BPTT) -> if >> mean-cos: errors AVERAGE OUT = VARIANCE if ~ mean-cos: systematic = BIAS (the real wall) BPTT-150-vs-400 gives the same two metrics as the slow-mixing horizon baseline.""" import torch import lt_ep_train as M from pathlib import Path import pickle M.DD = Path('/tmp/lt_ep/data/tinystories') M.vocab = pickle.load(open(M.DD / 'meta.pkl', 'rb'))['vocab_size'] from lt_ep_train import EQBlock, get_batch, bptt_step, ep_step dev = 'cuda' torch.manual_seed(0) B, T, C, H = 8, 256, 256, 8 blk = EQBlock(C, H, 256, T, attn_mode='thick') blk.qknorm = False; blk.track = False; blk.li_avg = 0; blk.navg = 1; blk.fnoise = 0.0; blk.nbrake = 0.0; blk._cstep = None ck = torch.load('/tmp/lt_ep/ts_s1_ep_v4b.pt') for p, w in zip(blk.allp, ck['allp']): with torch.no_grad(): p.copy_(w.to(dev)) print(f"v4b ckpt best {ck['best']:.4f}", flush=True) groups = {'all': blk.block, 'attn': [blk.WQ, blk.WK, blk.WV, blk.WO], 'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb], 'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b], 'emb': [blk.tok, blk.pos]} N = 16 sEP, s400, s150 = {}, {}, {} cos_b = {k: [] for k in groups} bctl_b = {k: [] for k in groups} def flat(g, ps): v = [g[id(p)].reshape(-1) for p in ps if g.get(id(p)) is not None] return torch.cat(v) if v else None def cos(a, b): return (a @ b / (a.norm() * b.norm() + 1e-12)).item() for i in range(N): idx, y = get_batch('train', B, T) gE, _ = ep_step(blk, idx, y, 150, 20, 0.1, 0.02, 0.0, holo=2, hr=0.02, t1max=500, res_est=1e-4, t2sel=120) g4 = bptt_step(blk, idx, y, 400, 0.1) g1 = bptt_step(blk, idx, y, 150, 0.1) for k, ps in groups.items(): a, b, c = flat(gE, ps), flat(g4, ps), flat(g1, ps) if a is not None and b is not None: cos_b[k].append(cos(a, b)) if c is not None and b is not None: bctl_b[k].append(cos(c, b)) for src, acc in ((gE, sEP), (g4, s400), (g1, s150)): for p in blk.block: if src.get(id(p)) is not None: acc[id(p)] = src[id(p)].detach().clone() if id(p) not in acc else acc[id(p)] + src[id(p)].detach() print(f" batch {i+1}/{N} done", flush=True) print(f"\n{'group':>5} {'EP mean-cos':>12} {'EP cos-means':>13} {'BPTT mean-cos':>14} {'BPTT cos-means':>15}") for k, ps in groups.items(): mc = sum(cos_b[k]) / len(cos_b[k]) bmc = sum(bctl_b[k]) / len(bctl_b[k]) aE, a4, a1 = flat(sEP, ps), flat(s400, ps), flat(s150, ps) cm = cos(aE, a4) bcm = cos(a1, a4) print(f"{k:>5} {mc:>12.3f} {cm:>13.3f} {bmc:>14.3f} {bcm:>15.3f}", flush=True)