summaryrefslogtreecommitdiff
path: root/ep_run/drift_diag.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
committerYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
commitb83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
treeb9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/drift_diag.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/drift_diag.py')
-rw-r--r--ep_run/drift_diag.py87
1 files changed, 87 insertions, 0 deletions
diff --git a/ep_run/drift_diag.py b/ep_run/drift_diag.py
new file mode 100644
index 0000000..4180104
--- /dev/null
+++ b/ep_run/drift_diag.py
@@ -0,0 +1,87 @@
+"""Late-drift diagnostic. Every stable EP/BPTT recipe peaks mid-run then val CE drifts up 0.1-0.3.
+Train the champion recipe on S0 (fast) and log, every 200 steps, quantities that SEPARATE the
+competing hypotheses:
+ - train_ce vs val_raw : train down + val up => OVERFIT ; both up => OPTIMIZATION INSTABILITY
+ - val_raw vs val_deep (T1=150 vs 400) : diverge => DYNAMICAL (fixed point degrading off train depth)
+ - res (free-phase) over time : climbing => contraction lost
+ - jr (lambda) over time : climbing => CONTROLLER FIGHT
+ - cos(EP-grad, BPTT-400) on a fixed probe batch : dropping => ESTIMATOR DEGRADATION
+ - |W|/|W_init| per group + cap-bind frac : group-specific => PARAMETRIC RUNAWAY
+ - ||ema - raw|| : how far the averaged weights sit from the wandering raw ones
+"""
+import math, time, torch
+import lt_ep_train as M
+from lt_ep_train import EQBlock, get_batch, ep_step, bptt_step, relax, evaluate, ce
+dev = 'cuda'
+torch.manual_seed(0)
+B, T, C, H = 32, 64, 128, 4
+blk = EQBlock(C, H, 256, T, attn_mode='thick', c=1.0)
+for w in blk.capw:
+ blk.caps[id(w)] = w.detach().norm().item() * 3.0
+opt = torch.optim.AdamW(blk.allp, lr=1e-3, weight_decay=1e-4)
+STEPS = 9000
+sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, STEPS, eta_min=5e-5)
+pema = [p.detach().clone() for p in blk.allp]
+W0 = {id(p): p.detach().norm().item() for p in blk.allp}
+groups = {'emb': [blk.tok, blk.pos], 'attn': [blk.WQ, blk.WK, blk.WV, blk.WO],
+ 'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb], 'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b]}
+blk.track = False; blk.li_avg = 0; blk.navg = 1; blk.fnoise = 0.0; blk.nbrake = 0.0
+
+# fixed probe batch for gradient-cosine-over-training
+pidx, py = get_batch('train', 16, T)
+def grad_cos():
+ ref = bptt_step(blk, pidx, py, 400, 0.1)
+ g, _ = ep_step(blk, pidx, py, 150, 20, 0.1, 0.02, 0.0, holo=2, hr=0.02, t1max=500, res_est=1e-4, t2sel=120)
+ keep = [p for p in blk.block if g.get(id(p)) is not None and ref.get(id(p)) is not None]
+ va = torch.cat([g[id(p)].reshape(-1) for p in keep]); vb = torch.cat([ref[id(p)].reshape(-1) for p in keep])
+ return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item()
+
+@torch.no_grad()
+def deep_val(T1):
+ tot = 0.0
+ for _ in range(8):
+ ix, yy = get_batch('val', 32, T)
+ xin = blk.embed(ix).detach()
+ z = relax(blk, xin.clone(), xin, T1, 0.1)
+ tot += ce(blk, z, yy).item()
+ return tot / 8
+
+jr, rs = 1.0, None
+print(f"{'step':>5} {'train':>6} {'val150':>6} {'val400':>6} {'ema150':>6} {'res':>8} {'jr':>5} "
+ f"{'cos':>5} {'|emb|':>5} {'|attn|':>6} {'|ffn|':>5} {'|ln|':>5} {'emaΔ':>6}", flush=True)
+for step in range(1, STEPS + 1):
+ idx, y = get_batch('train', B, T)
+ grads, res = ep_step(blk, idx, y, 150, 20, 0.1, 0.02, jr, holo=2, hr=0.02, t1max=500, res_est=1e-4, t2sel=120, res_gate=5e-3)
+ flo = 0.1
+ rs = res if rs is None else 0.9 * rs + 0.1 * res
+ jr = min(16.0, max(flo, jr * math.exp(0.3 * math.log((rs + 1e-9) / 1.5e-3))))
+ if all((g is None) or torch.isfinite(g).all() for g in grads.values()):
+ opt.zero_grad(set_to_none=True)
+ for p in blk.allp:
+ p.grad = grads.get(id(p))
+ torch.nn.utils.clip_grad_norm_(blk.allp, 5.0)
+ opt.step(); sched.step()
+ with torch.no_grad():
+ for p in blk.capw:
+ pn = p.norm(); cap = blk.caps[id(p)]
+ if pn > cap:
+ p.mul_(cap / pn)
+ for s, p in zip(pema, blk.allp):
+ s.mul_(0.999).add_(p.detach(), alpha=1e-3)
+ if step % 200 == 0:
+ with torch.no_grad():
+ tc = ce(blk, relax(blk, blk.embed(idx).detach().clone(), blk.embed(idx).detach(), 150, 0.1), y).item()
+ v150, v400 = deep_val(150), deep_val(400)
+ raw = [p.detach().clone() for p in blk.allp]
+ with torch.no_grad():
+ for p, s in zip(blk.allp, pema):
+ p.copy_(s)
+ ve = deep_val(150)
+ emad = math.sqrt(sum((r - s).pow(2).sum().item() for r, s in zip(raw, pema)))
+ for p, r in zip(blk.allp, raw):
+ p.copy_(r)
+ gn = {k: math.sqrt(sum(p.detach().norm().item()**2 for p in ps)) /
+ math.sqrt(sum(W0[id(p)]**2 for p in ps)) for k, ps in groups.items()}
+ cs = grad_cos()
+ print(f"{step:>5} {tc:>6.3f} {v150:>6.3f} {v400:>6.3f} {ve:>6.3f} {res:>8.1e} {jr:>5.1f} "
+ f"{cs:>5.2f} {gn['emb']:>5.2f} {gn['attn']:>6.2f} {gn['ffn']:>5.2f} {gn['ln']:>5.2f} {emad:>6.2f}", flush=True)