diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/track_probe.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/track_probe.py')
| -rw-r--r-- | ep_run/track_probe.py | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/ep_run/track_probe.py b/ep_run/track_probe.py new file mode 100644 index 0000000..2846745 --- /dev/null +++ b/ep_run/track_probe.py @@ -0,0 +1,77 @@ +import torch, math +import lt_ep_train as M +from pathlib import Path +import pickle +M.DD = Path('/tmp/lt_ep/data/tinystories') +M.vocab = pickle.load(open(M.DD/'meta.pkl','rb'))['vocab_size'] +from lt_ep_train import EQBlock, get_batch, bptt_step, relax +from holo_ep import holo_a_select2, rforce, rgrad_ce +import torch.func as tf + +def holo_a_track(blk, zs, xin, y, r, T2max, eps, K=10, exit_mult=5.0): + """Common-mode-tracking AEP: linearize the antisymmetric correction at the instantaneous + common mode of the two phases — exact transposed differential dynamics, loose-tolerant, + no compounding linearization error.""" + B = zs.size(0) + Z = torch.cat([zs, zs], 0) + X2 = torch.cat([xin, xin], 0) + y2 = torch.cat([y, y], 0) + sg = torch.cat([torch.full((B,1,1), r, device=zs.device), torch.full((B,1,1), -r, device=zs.device)], 0) + fnc = lambda zz: blk.nc_force(zz) + a_prev = a_best = None + inc_min, t_best = float('inf'), 0 + for t in range(1, T2max + 1): + with torch.no_grad(): + zbar = 0.5 * (Z[:B] + Z[B:]) + zb2 = torch.cat([zbar, zbar], 0) + f = rforce(blk, Z, X2) - sg * rgrad_ce(blk, Z, y2, denom=y.numel()) + v = (Z - zb2).contiguous() + _, Jv = tf.jvp(fnc, (zb2,), (v,)) + JTv = tf.vjp(fnc, zb2)[1](v)[0] + Z = Z + eps * (f - (Jv - JTv)) + if t % K == 0 or t == T2max: + a_t = (Z[B:] - Z[:B]) / (2 * r) + if not torch.isfinite(a_t).all(): + break + if a_prev is not None: + inc = (a_t - a_prev).norm().item() + if inc < inc_min: + inc_min, a_best, t_best = inc, a_t, t + elif inc > exit_mult * inc_min and t >= 3 * K: + break + a_prev = a_t + if a_best is None: + a_best = a_prev if a_prev is not None else (Z[B:] - Z[:B]) / (2 * r) + t_best = T2max + return a_best.detach(), t_best + +if __name__ == '__main__': + torch.manual_seed(0) + B, T, C, H = 8, 256, 256, 8 + blk = EQBlock(C, H, 256, T, attn_mode='thick') + ck = torch.load('/tmp/lt_ep/ts_s1_ep_v4b.pt') + for p, w in zip(blk.allp, ck['allp']): + with torch.no_grad(): + p.copy_(w.to('cuda')) + idx, y = get_batch('train', B, T) + xin = blk.embed(idx).detach() + ref150 = bptt_step(blk, idx, y, 150, 0.1) + def flat(g): + keep = [p for p in blk.block if g.get(id(p)) is not None] + return torch.cat([g[id(p)].reshape(-1) for p in keep]) + v150 = flat(ref150) + def gfrom(zs, a_): + with torch.enable_grad(): + x2 = blk.embed(idx) + f = blk.force(zs.detach(), x2, cg=True) + return {id(p): g for p, g in zip(blk.block, torch.autograd.grad((a_*f).sum(), blk.block, allow_unused=True))} + z = xin.clone(); prev = 0 + for T1 in (75, 150, 600): + z = relax(blk, z, xin, T1 - prev, 0.1); prev = T1 + res = (relax(blk, z, xin, 1, 0.1) - z).norm().item() / z.norm().item() + for name, fn in (('frozen', holo_a_select2), ('track', holo_a_track)): + for T2m in (120, 300): + a, tb = fn(blk, z, xin, y, 0.02, T2m, 0.1) + va = flat(gfrom(z, a)) + c = (va @ v150 / (va.norm() * v150.norm() + 1e-12)).item() + print(f"T1={T1:>4} res={res:.1e} {name:>6} T2max={T2m:>3}: t_best={tb:>3} cos_vs150={c:.3f}", flush=True) |
