diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/lt_ep_diag.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/lt_ep_diag.py')
| -rw-r--r-- | ep_run/lt_ep_diag.py | 57 |
1 files changed, 57 insertions, 0 deletions
# Source: ep_run/lt_ep_diag.py (new file, mode 100644, blob 1cec665),
# recovered from a flattened unified diff.
"""Diagnostic: WHY does EP training destabilize?

Test the hypothesis (Ernoult 2019): EP == BPTT iff the free phase has
CONVERGED (+ small beta). So log, per training step:

  - free-phase residual ||Pi(z*+eF)-z*||/||z*||
    (is the fixed point still there?)
  - cosine(EP-grad, BPTT-grad) over the block params
    (is EP still tracking the true grad?)

If the cosine starts ~1 and stays ~1 until the residual blows up, the failure
is loss of convergence, not beta.

Running this module executes the whole experiment at import time: it builds
the block and optimizer, then trains for 160 steps while printing one table
row on steps 1-5 and on every 10th step.
"""
import math
import torch
import torch.nn.functional as F
from lt_ep_train import EQBlock, ep_step, bptt_step, relax, get_batch, ce

# NOTE(review): `dev` is computed but never used below; nothing in this file
# moves `blk` or the batches to it. Confirm whether lt_ep_train handles device
# placement itself.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)

# T1 and eps are forwarded to relax(); T2 and beta only to ep_step(); B and T
# to get_batch(), and T also to EQBlock.
# Presumably: free-phase steps, nudged-phase steps, relaxation step size,
# nudging strength, batch size, sequence length -- confirm in lt_ep_train.
T1, T2, eps, beta, B, T = 80, 15, 0.1, 0.02, 32, 64

blk = EQBlock(128, 4, 256, T, s=1.0, c=1.0)
opt = torch.optim.AdamW(blk.allp, lr=1e-3, weight_decay=1e-4)

# Parameters over which the EP and BPTT gradients are compared in gcos().
BP = (blk.WQ, blk.WK, blk.WV, blk.WO, blk.Wm)


def resid(idx):
    """Return the relative free-phase residual for the batch `idx`.

    Relaxes from the (detached) embedding with `T1` as the step argument, then
    calls `relax` once more with step argument 1 and measures how far the
    state moved: ``||zn - zs|| / (||zs|| + 1e-9)``.

    Args:
        idx: batch of inputs accepted by ``blk.embed`` (as produced by
            ``get_batch``).

    Returns:
        The residual as a Python float.
    """
    xin = blk.embed(idx).detach()
    zs = relax(blk, xin.clone(), xin, T1, eps)
    # Pass a copy, mirroring the `xin.clone()` above: if relax() updates its
    # state argument in place, `zn` would alias `zs` and the residual would be
    # identically zero. The copy is a no-op if relax() is out-of-place.
    zn = relax(blk, zs.clone(), xin, 1, eps)
    # The 1e-9 keeps the ratio finite if the relaxed state has zero norm.
    return (zn - zs).norm().item() / (zs.norm().item() + 1e-9)


def gcos(idx, y):
    """Compare the EP gradient with the BPTT gradient on one batch.

    Both ``ep_step`` and ``bptt_step`` are expected to return a mapping keyed
    by ``id(param)`` (it is queried with ``.get(id(p))``). Only the parameters
    in ``BP`` for which both gradients exist and are entirely finite take part
    in the comparison; their gradients are flattened and concatenated into one
    vector per method.

    Args:
        idx: input batch.
        y: target batch.

    Returns:
        ``(cosine, gep)``: the cosine similarity as a float (``nan`` when no
        parameter qualified) and the EP gradient mapping, so the caller can
        reuse it for the optimizer step.
    """
    gep = ep_step(blk, idx, y, T1, T2, eps, beta)
    gbp = bptt_step(blk, idx, y, T1, eps)
    fa, fb = [], []
    for p in BP:
        a, b = gep.get(id(p)), gbp.get(id(p))
        if (
            a is not None
            and b is not None
            and torch.isfinite(a).all()
            and torch.isfinite(b).all()
        ):
            fa.append(a.flatten())
            fb.append(b.flatten())
    if not fa:
        return float('nan'), gep
    return F.cosine_similarity(torch.cat(fa), torch.cat(fb), dim=0).item(), gep


print(f"{'step':>4} {'free_resid':>11} {'cos(EP,BPTT)':>13} {'val_CE':>8}")
for step in range(1, 161):
    idx, y = get_batch('train', B, T)
    r = resid(idx)
    c, gep = gcos(idx, y)

    # Log densely for the first 5 steps, then every 10th step.
    if step % 10 == 0 or step <= 5:
        with torch.no_grad():
            vi, vy = get_batch('val', B, T)
            xin = blk.embed(vi).detach()
            v = ce(blk, relax(blk, xin.clone(), xin, T1, eps), vy).item()
        print(f"{step:>4} {r:>11.2e} {c:>13.3f} {v:>8.3f}", flush=True)

    # apply EP grads (the actual unstable training)
    if all((g is None) or torch.isfinite(g).all() for g in gep.values()):
        opt.zero_grad(set_to_none=True)
        for p in blk.allp:
            # Parameters missing from `gep` get grad None.
            p.grad = gep.get(id(p))
        torch.nn.utils.clip_grad_norm_(blk.allp, 5.0)
        opt.step()
    else:
        # Despite the "would skip" wording, the update really is skipped for
        # this step; the message text is kept unchanged.
        print(f"{step:>4} NON-FINITE EP grad -> would skip", flush=True)
