diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/grad_quality.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/grad_quality.py')
| -rw-r--r-- | ep_run/grad_quality.py | 64 |
1 files changed, 64 insertions, 0 deletions
"""Gradient-quality probe at a realistic operating point.

Pretrain the thick block with BPTT for 300 steps (lands near where good optima
live, res ~1e-2 to 1e-3, NO contraction penalty), then measure
cosine(EP gradient, long-horizon-BPTT reference) per parameter group as a
function of free-phase length T1 (= residual level) and nudge length T2.

If cosine is high at res ~1e-2 -> there is NO estimator wall at the operating
points that matter; the EP-vs-BPTT gap is speed + regularization tax, not a
convergence mechanism.

The pretrained weights are cached at ``CACHE_PATH`` so repeated runs skip the
pretraining phase.
"""
import os

import torch

from lt_ep_train import EQBlock, get_batch, ep_step, bptt_step

# Location of the cached pretrained weights (a list of CPU tensors, one per
# entry of ``blk.allp``, written by ``torch.save`` below).
CACHE_PATH = '/tmp/lt_ep/probe_w.pt'

dev = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)
B, T, C, H = 16, 64, 128, 4
blk = EQBlock(C, H, 256, T, attn_mode='thick')  # c=1.0 default = thick-BPTT baseline setting

if os.path.exists(CACHE_PATH):
    # Reuse weights from a previous run; copy in place so that ``blk`` keeps
    # the same parameter objects (gradients are looked up by ``id(p)``).
    with torch.no_grad():
        for p, w in zip(blk.allp, torch.load(CACHE_PATH)):
            p.copy_(w.to(dev))
    print("loaded cached pretrained weights", flush=True)
else:
    # Create the cache directory up front: torch.save does not create missing
    # parent directories, and failing only after 300 pretraining steps would
    # throw the work away.
    os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)
    opt = torch.optim.AdamW(blk.allp, lr=1e-3, weight_decay=1e-4)
    for _ in range(300):
        idx, y = get_batch('train', B, T)
        # bptt_step returns a mapping keyed by id(param); the trailing
        # arguments (150, 0.1) are presumably the unroll horizon and a
        # step/damping constant -- confirm against lt_ep_train.
        g = bptt_step(blk, idx, y, 150, 0.1)
        opt.zero_grad(set_to_none=True)
        for p in blk.allp:
            p.grad = g.get(id(p))
        torch.nn.utils.clip_grad_norm_(blk.allp, 5.0)
        opt.step()
    torch.save([p.detach().cpu() for p in blk.allp], CACHE_PATH)
print("pretrained 300 BPTT steps (thick, c=1) -- measuring at this operating point", flush=True)

# Parameter groups over which the cosine is reported, in column order.
# NOTE(review): 'all' uses blk.block, presumably the block's parameter list --
# confirm against EQBlock.
groups = {'all': blk.block,
          'attn': [blk.WQ, blk.WK, blk.WV, blk.WO],
          'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb],
          'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b],
          'emb': [blk.tok, blk.pos],
          'mem': [blk.Wm]}


def cos(ga, gb, ps):
    """Return the cosine similarity between two gradient estimates.

    Args:
        ga: Mapping from ``id(param)`` to a gradient tensor (or ``None``).
        gb: Mapping of the same form, compared against ``ga``.
        ps: Sized collection of parameters to restrict the comparison to.

    Returns:
        The cosine of the flattened, concatenated gradients over every
        parameter in ``ps`` that has a non-``None`` gradient in both
        mappings. ``nan`` (after printing a debug line) when no parameter
        qualifies.
    """
    keep = [p for p in ps
            if ga.get(id(p)) is not None and gb.get(id(p)) is not None]
    if not keep:
        na = sum(1 for p in ps if ga.get(id(p)) is None)
        nb = sum(1 for p in ps if gb.get(id(p)) is None)
        print(f"\n[debug] empty keep: |ps|={len(ps)} ga_None={na} gb_None={nb} "
              f"ga_keys={len(ga)} gb_keys={len(gb)}", flush=True)
        return float('nan')
    va = torch.cat([ga[id(p)].reshape(-1) for p in keep])
    vb = torch.cat([gb[id(p)].reshape(-1) for p in keep])
    # The epsilon guards against a zero-norm gradient.
    return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item()


def _print_row(label, res_field, grads, ref):
    """Print one table row: label, residual field, per-group cosine to ``ref``."""
    cosines = " ".join(f"{cos(grads, ref, ps):>6.3f}" for ps in groups.values())
    print(f"{label:>16} {res_field:>9} " + cosines, flush=True)


hdr = f"{'config':>16} {'res':>9} " + " ".join(f"{k:>6}" for k in groups)
for bi in range(3):
    idx, y = get_batch('train', B, T)
    ref = bptt_step(blk, idx, y, 400, 0.1)  # exact reference: long-horizon BPTT
    print(("\n" if bi else "") + hdr, flush=True)

    # Shorter-horizon BPTT against the reference, as a baseline row.
    g150 = bptt_step(blk, idx, y, 150, 0.1)
    _print_row('bptt T1=150', '--', g150, ref)

    # Sweep the free-phase length T1 at fixed nudge length T2=20. ep_step
    # returns (gradient mapping, residual); the remaining positional
    # arguments are passed through unchanged -- see lt_ep_train for their
    # meaning.
    for T1 in (50, 150, 400):
        gep, res = ep_step(blk, idx, y, T1, 20, 0.1, 0.02, 0.0)
        _print_row(f'ep T1={T1:<3} T2=20', f"{res:.1e}", gep, ref)

    # Longer nudge phase at the middle T1.
    gep, res = ep_step(blk, idx, y, 150, 60, 0.1, 0.02, 0.0)
    _print_row('ep T1=150 T2=60', f"{res:.1e}", gep, ref)
