"""Diagnostic: WHY does EP training destabilize? Test the hypothesis (Ernoult 2019): EP == BPTT iff the free phase has CONVERGED (+ small beta). So log, per training step: - free-phase residual ||Pi(z*+eF)-z*||/||z*|| (is the fixed point still there?) - cosine(EP-grad, BPTT-grad) over the block params (is EP still tracking the true grad?) If cosine starts ~1 and stays ~1 until the residual blows up -> it's loss of convergence, not beta. """ import math, torch, torch.nn.functional as F from lt_ep_train import EQBlock, ep_step, bptt_step, relax, get_batch, ce dev = 'cuda' if torch.cuda.is_available() else 'cpu' torch.manual_seed(0) T1, T2, eps, beta, B, T = 80, 15, 0.1, 0.02, 32, 64 blk = EQBlock(128, 4, 256, T, s=1.0, c=1.0) opt = torch.optim.AdamW(blk.allp, lr=1e-3, weight_decay=1e-4) BP = (blk.WQ, blk.WK, blk.WV, blk.WO, blk.Wm) def resid(idx): xin = blk.embed(idx).detach() zs = relax(blk, xin.clone(), xin, T1, eps) zn = relax(blk, zs, xin, 1, eps) return (zn - zs).norm().item() / (zs.norm().item() + 1e-9) def gcos(idx, y): gep = ep_step(blk, idx, y, T1, T2, eps, beta) gbp = bptt_step(blk, idx, y, T1, eps) fa, fb = [], [] for p in BP: a, b = gep.get(id(p)), gbp.get(id(p)) if a is not None and b is not None and torch.isfinite(a).all() and torch.isfinite(b).all(): fa.append(a.flatten()); fb.append(b.flatten()) if not fa: return float('nan'), gep return F.cosine_similarity(torch.cat(fa), torch.cat(fb), dim=0).item(), gep print(f"{'step':>4} {'free_resid':>11} {'cos(EP,BPTT)':>13} {'val_CE':>8}") for step in range(1, 161): idx, y = get_batch('train', B, T) r = resid(idx) c, gep = gcos(idx, y) if step % 10 == 0 or step <= 5: with torch.no_grad(): vi, vy = get_batch('val', B, T) xin = blk.embed(vi).detach() v = ce(blk, relax(blk, xin.clone(), xin, T1, eps), vy).item() print(f"{step:>4} {r:>11.2e} {c:>13.3f} {v:>8.3f}", flush=True) # apply EP grads (the actual unstable training) if all((g is None) or torch.isfinite(g).all() for g in gep.values()): opt.zero_grad(set_to_none=True) for p in blk.allp: p.grad = gep.get(id(p)) torch.nn.utils.clip_grad_norm_(blk.allp, 5.0) opt.step() else: print(f"{step:>4} NON-FINITE EP grad -> would skip", flush=True)