summaryrefslogtreecommitdiff
path: root/ep_run/lt_ep_diag.py
diff options
context:
space:
mode:
author Yuren Hao <yurenh2@illinois.edu> 2026-07-03 05:56:50 -0500
committer Yuren Hao <yurenh2@illinois.edu> 2026-07-03 05:56:50 -0500
commit b83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
tree b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/lt_ep_diag.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/lt_ep_diag.py')
-rw-r--r-- ep_run/lt_ep_diag.py 57
1 files changed, 57 insertions, 0 deletions
diff --git a/ep_run/lt_ep_diag.py b/ep_run/lt_ep_diag.py
new file mode 100644
index 0000000..1cec665
--- /dev/null
+++ b/ep_run/lt_ep_diag.py
@@ -0,0 +1,57 @@
+"""Diagnostic: WHY does EP training destabilize? Test the hypothesis (Ernoult 2019):
+EP == BPTT iff the free phase has CONVERGED (+ small beta). So log, per training step:
+ - free-phase residual ||Pi(z*+eF)-z*||/||z*|| (is the fixed point still there?)
+ - cosine(EP-grad, BPTT-grad) over the block params (is EP still tracking the true grad?)
+If cosine starts ~1 and stays ~1 until the residual blows up -> it's loss of convergence, not beta.
+"""
+import math, torch, torch.nn.functional as F
+from lt_ep_train import EQBlock, ep_step, bptt_step, relax, get_batch, ce
+
+# Prefer the GPU when torch can see one, otherwise stay on the CPU.
+# NOTE(review): `dev` is not referenced again in the visible part of this script.
+dev = 'cpu'
+if torch.cuda.is_available():
+    dev = 'cuda'
+# Seed torch before the block is constructed.
+torch.manual_seed(0)
+# Hyperparameters shared by the diagnostics and the training loop below.
+T1 = 80      # passed to relax / ep_step / bptt_step (presumably free-phase step count -- confirm)
+T2 = 15      # passed only to ep_step (presumably nudged-phase step count -- confirm)
+eps = 0.1    # passed to relax / ep_step / bptt_step (presumably relaxation step size -- confirm)
+beta = 0.02  # passed only to ep_step (presumably the EP nudging strength -- confirm)
+B = 32       # second argument of get_batch (presumably batch size)
+T = 64       # third argument of get_batch, fourth of EQBlock (presumably sequence length)
+# Meaning of the positional EQBlock arguments is defined in lt_ep_train -- confirm there.
+blk = EQBlock(128, 4, 256, T, s=1.0, c=1.0)
+opt = torch.optim.AdamW(blk.allp, lr=1e-3, weight_decay=1e-4)
+# Parameters whose EP and BPTT gradients are compared in gcos().
+BP = (
+    blk.WQ,
+    blk.WK,
+    blk.WV,
+    blk.WO,
+    blk.Wm,
+)
+
+
+def resid(idx):
+    """Return the relative free-phase residual for the batch ``idx``.
+
+    The embedded input is relaxed with ``relax(..., T1, eps)`` and the result
+    is relaxed once more with a count of 1 (the fourth ``relax`` argument is
+    presumably the number of relaxation steps -- confirm in lt_ep_train).
+    The value returned is ``||z_next - z_star|| / (||z_star|| + 1e-9)``.
+
+    Args:
+        idx: Batch handed to ``blk.embed`` (presumably token ids produced by
+            ``get_batch`` -- confirm against the caller).
+
+    Returns:
+        The residual as a Python float.
+    """
+    x_emb = blk.embed(idx).detach()
+    z_star = relax(blk, x_emb.clone(), x_emb, T1, eps)
+    z_next = relax(blk, z_star, x_emb, 1, eps)
+    delta_norm = (z_next - z_star).norm().item()
+    base_norm = z_star.norm().item()
+    # The 1e-9 term guards against division by a zero-norm state.
+    return delta_norm / (base_norm + 1e-9)
+
+
+def gcos(idx, y):
+    """Compare the EP and BPTT gradients of the ``BP`` parameters.
+
+    Calls ``ep_step`` and then ``bptt_step`` on the same batch. Both results
+    are looked up by ``id(param)``. For every parameter in ``BP`` that has a
+    fully finite gradient in both results, the two gradients are flattened
+    and concatenated; the cosine similarity of the two concatenated vectors
+    is returned.
+
+    Args:
+        idx: Input batch forwarded to ``ep_step`` / ``bptt_step``.
+        y: Targets forwarded to ``ep_step`` / ``bptt_step``.
+
+    Returns:
+        A tuple ``(cosine, ep_grads)``. ``cosine`` is a float, or NaN when no
+        parameter had usable gradients in both results. ``ep_grads`` is the
+        unmodified object returned by ``ep_step``.
+    """
+    ep_grads = ep_step(blk, idx, y, T1, T2, eps, beta)
+    bptt_grads = bptt_step(blk, idx, y, T1, eps)
+    usable = []
+    for param in BP:
+        g_ep, g_bp = ep_grads.get(id(param)), bptt_grads.get(id(param))
+        # Skip parameters missing from either gradient mapping.
+        if g_ep is None or g_bp is None:
+            continue
+        # Skip parameters whose gradient contains inf/NaN on either side.
+        if not (torch.isfinite(g_ep).all() and torch.isfinite(g_bp).all()):
+            continue
+        usable.append((g_ep.flatten(), g_bp.flatten()))
+    if not usable:
+        return float('nan'), ep_grads
+    ep_vec = torch.cat([g_ep for g_ep, _ in usable])
+    bptt_vec = torch.cat([g_bp for _, g_bp in usable])
+    return F.cosine_similarity(ep_vec, bptt_vec, dim=0).item(), ep_grads
+
+
+# Column header for the diagnostic table printed by the loop below.
+print(f"{'step':>4} {'free_resid':>11} {'cos(EP,BPTT)':>13} {'val_CE':>8}")
+# Runs 160 steps: compute both diagnostics on a training batch, optionally
+# evaluate on a validation batch, then apply the EP gradients through `opt`.
+# NOTE(review): block nesting is flattened in this view (every body line sits
+# at one space) -- confirm the intended indentation against the original file.
+for step in range(1, 161):
+ idx, y = get_batch('train', B, T)
+ # Both diagnostics use the same training batch and run before this
+ # iteration's optimizer step.
+ r = resid(idx)
+ c, gep = gcos(idx, y)
+ # Validation trigger: the first five steps and every tenth step.
+ if step % 10 == 0 or step <= 5:
+ with torch.no_grad():
+ vi, vy = get_batch('val', B, T)
+ xin = blk.embed(vi).detach()
+ # `ce` on the relaxed validation state (presumably cross-entropy, per
+ # the 'val_CE' column header -- confirm in lt_ep_train).
+ v = ce(blk, relax(blk, xin.clone(), xin, T1, eps), vy).item()
+ # NOTE(review): presumably this print belongs to the validation branch,
+ # since `v` is only assigned there -- confirm.
+ print(f"{step:>4} {r:>11.2e} {c:>13.3f} {v:>8.3f}", flush=True)
+ # apply EP grads (the actual unstable training)
+ # The update is taken only if every value in gep is None or fully finite.
+ if all((g is None) or torch.isfinite(g).all() for g in gep.values()):
+ opt.zero_grad(set_to_none=True)
+ for p in blk.allp:
+ # Parameters without an entry in gep get grad None.
+ p.grad = gep.get(id(p))
+ # Clip the gradients of blk.allp to a total norm of 5.0 before stepping.
+ torch.nn.utils.clip_grad_norm_(blk.allp, 5.0)
+ opt.step()
+ else:
+ # No optimizer step is taken on this path; only the message is printed.
+ print(f"{step:>4} NON-FINITE EP grad -> would skip", flush=True)