diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/drift_diag.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/drift_diag.py')
| -rw-r--r-- | ep_run/drift_diag.py | 87 |
1 files changed, 87 insertions, 0 deletions
# NOTE(review): this listing was recovered from a whitespace-collapsed rendering of the commit
# diff whose header is preserved below. Line breaks and indentation were reconstructed from the
# syntax and should be confirmed against the repository copy of ep_run/drift_diag.py.
#   diff --git a/ep_run/drift_diag.py b/ep_run/drift_diag.py
#   new file mode 100644
#   index 0000000..4180104
#   --- /dev/null
#   +++ b/ep_run/drift_diag.py
#   @@ -0,0 +1,87 @@
"""Late-drift diagnostic. Every stable EP/BPTT recipe peaks mid-run then val CE drifts up 0.1-0.3.
Train the champion recipe on S0 (fast) and log, every 200 steps, quantities that SEPARATE the
competing hypotheses:
  - train_ce vs val_raw : train down + val up => OVERFIT ; both up => OPTIMIZATION INSTABILITY
  - val_raw vs val_deep (T1=150 vs 400) : diverge => DYNAMICAL (fixed point degrading off train depth)
  - res (free-phase) over time : climbing => contraction lost
  - jr (lambda) over time : climbing => CONTROLLER FIGHT
  - cos(EP-grad, BPTT-400) on a fixed probe batch : dropping => ESTIMATOR DEGRADATION
  - |W|/|W_init| per group + cap-bind frac : group-specific => PARAMETRIC RUNAWAY
  - ||ema - raw|| : how far the averaged weights sit from the wandering raw ones

Columns printed every 200 steps, in order: step, train, val150, val400, ema150, res, jr, cos,
|emb|, |attn|, |ffn|, |ln| (current group norm divided by initial group norm), emaΔ.

NOTE(review): the "cap-bind frac" listed above is not computed or printed by this script.
"""
# NOTE(review): `time`, `M` and `evaluate` are imported but not referenced anywhere below.
import math, time, torch
import lt_ep_train as M
from lt_ep_train import EQBlock, get_batch, ep_step, bptt_step, relax, evaluate, ce
# NOTE(review): `dev` is assigned here but never read in this script.
dev = 'cuda'
# Fixed seed; everything below that consumes torch RNG depends on statement order.
torch.manual_seed(0)
# B and T are passed to get_batch for training batches; T, C and H are passed to EQBlock.
# Presumably batch size, sequence length, model width and head count -- confirm in lt_ep_train.
B, T, C, H = 32, 64, 128, 4
blk = EQBlock(C, H, 256, T, attn_mode='thick', c=1.0)
# For every tensor in blk.capw, record a norm cap of 3x its norm at construction time.
for w in blk.capw:
    blk.caps[id(w)] = w.detach().norm().item() * 3.0
opt = torch.optim.AdamW(blk.allp, lr=1e-3, weight_decay=1e-4)
STEPS = 9000
# Cosine learning-rate decay over the whole run, ending at 5e-5.
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, STEPS, eta_min=5e-5)
# Running average of the parameters (updated in the training loop), seeded with the initial values.
pema = [p.detach().clone() for p in blk.allp]
# Initial norm of every parameter, keyed by id(p); denominator for the per-group norm ratios.
W0 = {id(p): p.detach().norm().item() for p in blk.allp}
# Parameter groups whose norm ratio is reported in the log line.
groups = {'emb': [blk.tok, blk.pos], 'attn': [blk.WQ, blk.WK, blk.WV, blk.WO],
          'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb], 'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b]}
# Attributes set on the block; their meaning is defined in lt_ep_train and is not visible here.
blk.track = False; blk.li_avg = 0; blk.navg = 1; blk.fnoise = 0.0; blk.nbrake = 0.0

# fixed probe batch for gradient-cosine-over-training
pidx, py = get_batch('train', 16, T)
def grad_cos() -> float:
    """Return the cosine similarity between the ep_step and bptt_step gradients on the probe batch.

    Both helpers are used as mappings from id(param) to a gradient tensor (or None). Only
    parameters in blk.block for which both mappings hold a gradient enter the comparison.
    """
    # Reference gradient; the module docstring calls it "BPTT-400", so 400 is presumably the
    # unroll depth -- confirm against bptt_step.
    ref = bptt_step(blk, pidx, py, 400, 0.1)
    # Same ep_step call as the training loop, except the 8th positional argument is 0.0
    # (the loop passes jr there) and res_gate is left at its default.
    g, _ = ep_step(blk, pidx, py, 150, 20, 0.1, 0.02, 0.0, holo=2, hr=0.02, t1max=500, res_est=1e-4, t2sel=120)
    keep = [p for p in blk.block if g.get(id(p)) is not None and ref.get(id(p)) is not None]
    # Flatten and concatenate both gradient sets in the same parameter order.
    va = torch.cat([g[id(p)].reshape(-1) for p in keep]); vb = torch.cat([ref[id(p)].reshape(-1) for p in keep])
    # 1e-12 keeps the division defined when either vector has zero norm.
    return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item()

@torch.no_grad()
def deep_val(T1: int) -> float:
    """Return the mean of ce() over 8 'val' batches after calling relax() with T1.

    T1 is forwarded to relax(); per the module docstring it is presumably the relaxation
    depth -- confirm against relax. Called below with 150 and 400.

    NOTE(review): every call fetches its own batches through get_batch, so val150, val400 and
    ema150 are not necessarily measured on the same data -- depends on get_batch; confirm.
    """
    tot = 0.0
    for _ in range(8):
        ix, yy = get_batch('val', 32, T)
        xin = blk.embed(ix).detach()
        z = relax(blk, xin.clone(), xin, T1, 0.1)
        tot += ce(blk, z, yy).item()
    return tot / 8

# jr: controller value fed to ep_step (called "lambda" in the module docstring).
# rs: exponentially smoothed residual returned by ep_step; None until the first step.
jr, rs = 1.0, None
print(f"{'step':>5} {'train':>6} {'val150':>6} {'val400':>6} {'ema150':>6} {'res':>8} {'jr':>5} "
      f"{'cos':>5} {'|emb|':>5} {'|attn|':>6} {'|ffn|':>5} {'|ln|':>5} {'emaΔ':>6}", flush=True)
for step in range(1, STEPS + 1):
    idx, y = get_batch('train', B, T)
    grads, res = ep_step(blk, idx, y, 150, 20, 0.1, 0.02, jr, holo=2, hr=0.02, t1max=500, res_est=1e-4, t2sel=120, res_gate=5e-3)
    flo = 0.1  # lower clamp for jr
    # Smooth the residual: 0.9 * previous + 0.1 * current.
    rs = res if rs is None else 0.9 * rs + 0.1 * res
    # Multiplicative update jr <- jr * (rs / 1.5e-3) ** 0.3, clamped to [flo, 16.0]:
    # jr grows while the smoothed residual is above 1.5e-3 and shrinks while it is below.
    jr = min(16.0, max(flo, jr * math.exp(0.3 * math.log((rs + 1e-9) / 1.5e-3))))
    # Skip the whole update when any returned gradient contains a non-finite value.
    # NOTE(review): the nesting of the cap/EMA block under this guard is reconstructed -- confirm.
    if all((g is None) or torch.isfinite(g).all() for g in grads.values()):
        opt.zero_grad(set_to_none=True)
        # Install the externally computed gradients; parameters missing from grads get None.
        for p in blk.allp:
            p.grad = grads.get(id(p))
        torch.nn.utils.clip_grad_norm_(blk.allp, 5.0)
        opt.step(); sched.step()
        with torch.no_grad():
            # Rescale any capped tensor whose norm exceeds its recorded cap back onto the cap.
            for p in blk.capw:
                pn = p.norm(); cap = blk.caps[id(p)]
                if pn > cap:
                    p.mul_(cap / pn)
            # Parameter running average: s <- 0.999 * s + 0.001 * p.
            for s, p in zip(pema, blk.allp):
                s.mul_(0.999).add_(p.detach(), alpha=1e-3)
    if step % 200 == 0:
        with torch.no_grad():
            # "train" column: ce on the batch just used for this step, evaluated after the update.
            tc = ce(blk, relax(blk, blk.embed(idx).detach().clone(), blk.embed(idx).detach(), 150, 0.1), y).item()
        v150, v400 = deep_val(150), deep_val(400)
        # Snapshot the live weights so they can be restored after the averaged-weights evaluation.
        raw = [p.detach().clone() for p in blk.allp]
        with torch.no_grad():
            # Temporarily load the averaged weights into the block.
            for p, s in zip(blk.allp, pema):
                p.copy_(s)
            ve = deep_val(150)
            # Euclidean distance between the live and averaged weights over all parameters.
            emad = math.sqrt(sum((r - s).pow(2).sum().item() for r, s in zip(raw, pema)))
            # Put the live weights back before training continues.
            for p, r in zip(blk.allp, raw):
                p.copy_(r)
        # Per group: joint norm of the group's parameters divided by the same quantity at init.
        gn = {k: math.sqrt(sum(p.detach().norm().item()**2 for p in ps)) /
                 math.sqrt(sum(W0[id(p)]**2 for p in ps)) for k, ps in groups.items()}
        cs = grad_cos()
        print(f"{step:>5} {tc:>6.3f} {v150:>6.3f} {v400:>6.3f} {ve:>6.3f} {res:>8.1e} {jr:>5.1f} "
              f"{cs:>5.2f} {gn['emb']:>5.2f} {gn['attn']:>6.2f} {gn['ffn']:>5.2f} {gn['ln']:>5.2f} {emad:>6.2f}", flush=True)
