diff options
Diffstat (limited to 'ep_run/adaptive_eps_calib2.py')
| -rw-r--r-- | ep_run/adaptive_eps_calib2.py | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/ep_run/adaptive_eps_calib2.py b/ep_run/adaptive_eps_calib2.py new file mode 100644 index 0000000..ab78f7e --- /dev/null +++ b/ep_run/adaptive_eps_calib2.py @@ -0,0 +1,39 @@ +import torch, pickle +from pathlib import Path +import lt_ep_train as L +from lt_ep_train import EQBlock +L.DD=Path('data/tinystories_bpe'); L.vocab=pickle.load(open(L.DD/'meta.pkl','rb'))['vocab_size'] +dev='cuda'; B=8; T=256 +torch.manual_seed(1234); idx,y=L.get_batch('val',B,T); idx=idx.to(dev) if hasattr(idx,'to') else idx +CK={'s3200':torch.load('runs/redx_traj/s3200.pt',map_location=dev), + 's2000':torch.load('runs/redx_traj/s2000.pt',map_location=dev)} +def mkblk(name): + blk=EQBlock(512,16,256,256,s=1.0,c=1.0,attn_mode='thick'); blk.qknorm=True + with torch.no_grad(): + for p,w in zip(blk.allp,CK[name]['allp']): p.copy_(w.to(dev)) + return blk +def fg(blk,z,xin): + f=blk.force(z,xin).detach(); return f, f.norm().item()/(z.norm().item()+1e-9) +# corrected controller: shrink on OVERSHOOT (g rose), grow otherwise +def run(name, e0=0.05, emin=0.003, emax=0.1, up=1.05, down=0.7, tol=1.0, N=8000): + blk=mkblk(name) + with torch.no_grad(): + xin=blk.embed(idx).detach(); z=xin.clone(); eps=e0; prev=None; gs=[]; eh=[] + for t in range(N): + f,g=fg(blk,z,xin); gs.append(g); eh.append(eps) + if prev is not None: + if g > prev*tol: eps=max(emin, eps*down) # residual climbed -> eps too big + else: eps=min(emax, eps*up) # contracting -> grow for speed + prev=g; z=z+eps*f + tail=gs[-500:] + return dict(gmin=min(tail), gmean=sum(tail)/len(tail), avg_eps=sum(eh)/len(eh), final_eps=eh[-1]) +print("=== corrected adaptive-eps (shrink on OVERSHOOT) — calibrate on stiff + smooth ===") +print("target: s3200 converges (g~0.09) at avg_eps>0.005 (faster than naive); s2000 stays eps~0.1") +for name in ('s3200','s2000'): + print(f"-- {name} --") + for tag,kw in [("A up1.05 dn0.7", dict(up=1.05,down=0.7)), + ("B up1.1 dn0.5", dict(up=1.1,down=0.5)), + ("C up1.03 dn0.8 tol1.02", dict(up=1.03,down=0.8,tol=1.02))]: + r=run(name,**kw) + print(f" {tag}: g_tail[min={r['gmin']:.4f} mean={r['gmean']:.4f}] avg_eps={r['avg_eps']:.4f} final_eps={r['final_eps']:.4f}") +print("=== DONE ===") |
