summaryrefslogtreecommitdiff
path: root/ep_run/adaptive_eps_calib.py
blob: 475e296fba9da4522b357d3dbf138a21d1bc7929 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch, pickle, math
from pathlib import Path
import lt_ep_train as L
from lt_ep_train import EQBlock
# Point the training module at the TinyStories BPE data directory and read the
# vocab size from its meta.pkl. The file handle is closed via `with` rather than
# leaked. NOTE(review): pickle.load can execute arbitrary code — only use with
# trusted local data.
L.DD = Path('data/tinystories_bpe')
with open(L.DD / 'meta.pkl', 'rb') as meta_fh:
    L.vocab = pickle.load(meta_fh)['vocab_size']
dev = 'cuda'; B = 8; T = 256
# Seed before drawing the batch; presumably get_batch samples via torch's RNG,
# making the batch reproducible — confirm in lt_ep_train. The single `idx`
# drawn here is reused by every run() call below.
torch.manual_seed(1234)
idx, y = L.get_batch('val', B, T)
idx = idx.to(dev) if hasattr(idx, 'to') else idx
# Checkpoint whose 'allp' entry supplies the block parameters loaded in mkblk().
ck = torch.load('runs/redx_traj/s3200.pt', map_location=dev)
def mkblk():
    """Build a fresh EQBlock and load the checkpoint's 'allp' parameters into it.

    A new block is constructed on every call, so each run() starts from the
    same checkpoint weights.

    Raises:
        ValueError: if the checkpoint holds a different number of parameter
            tensors than the block exposes. A plain zip would silently load only
            the common prefix and leave the rest at their initial values.
    """
    # NOTE(review): constructor arguments must match the architecture that
    # produced the s3200 checkpoint — confirm against the training config.
    blk = EQBlock(512, 16, 256, 256, s=1.0, c=1.0, attn_mode='thick'); blk.qknorm = True
    params = list(blk.allp)
    saved = list(ck['allp'])
    if len(params) != len(saved):
        raise ValueError(
            f"checkpoint has {len(saved)} parameter tensors but block has {len(params)}"
        )
    with torch.no_grad():
        for p, w in zip(params, saved):
            p.copy_(w.to(dev))  # in-place copy keeps the block's own tensor objects
    return blk
def force_g(blk, z, xin):
    """Evaluate blk.force at (z, xin) and measure its size relative to z.

    Returns:
        (f, g) where f is the detached force tensor and g is the Python float
        ||f|| / (||z|| + 1e-9), with norms taken over all elements. The 1e-9
        keeps the ratio finite when z is all zeros.
    """
    force = blk.force(z, xin).detach()
    force_norm = force.norm().item()
    state_norm = z.norm().item()
    return force, force_norm / (state_norm + 1e-9)
def run(adaptive, e0=0.1, emin=0.005, emax=0.1, down=0.5, up=1.1, theta=0.98, N=10000,
        grow_below=0.9, tail_len=500):
    """Iterate z <- z + eps * force(z, xin) from z = embed(idx) and summarise it.

    g is the relative force size returned by force_g at the current iterate,
    measured before the step is taken.

    With adaptive=False, eps stays at e0 for all N steps. With adaptive=True,
    from the second step on, eps is adjusted before stepping by comparing g
    with the previous iterate's value prev:
      * g > theta * prev       -> eps = max(emin, eps * down)
      * g < grow_below * prev  -> eps = min(emax, eps * up)
      * otherwise               eps is unchanged.

    Args:
        adaptive: enable the eps controller.
        e0: initial step size.
        emin, emax: clamps applied when the controller shrinks/grows eps.
        down, up: multiplicative shrink/grow factors.
        theta: shrink when g fails to drop below theta * prev.
        N: number of steps (must be >= 1).
        grow_below: grow when g drops below grow_below * prev (was fixed at 0.9).
        tail_len: number of trailing g values summarised (should be positive).

    Returns:
        dict with gmin/gmean over the last tail_len g values, avg_eps (mean eps
        actually applied over the N steps) and final_eps (eps applied on the
        last step).

    Raises:
        ValueError: if N < 1.
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")
    blk = mkblk()
    with torch.no_grad():
        xin = blk.embed(idx).detach()
        z = xin.clone()
        eps = e0
        prev = None
        gs = []  # g at each iterate, measured before that iterate's step
        eh = []  # eps actually used for each step
        for _ in range(N):
            f, g = force_g(blk, z, xin)
            gs.append(g)
            if adaptive and prev is not None:
                if g > theta * prev:
                    eps = max(emin, eps * down)
                elif g < grow_below * prev:
                    eps = min(emax, eps * up)
            # Record eps only after adaptation so eh[t] is the value used in
            # step t. The original recorded it before the update, which
            # double-counted e0 and missed the final step's eps.
            eh.append(eps)
            prev = g
            z = z + eps * f
    tail = gs[-tail_len:]
    return dict(gmin=min(tail), gmean=sum(tail) / len(tail),
                avg_eps=sum(eh) / len(eh), final_eps=eh[-1])
# Calibration driver: the two fixed-eps reference runs first, then a sweep over
# adaptive-controller settings, each reported as one summary line.
print("=== adaptive-eps controller CALIBRATION on s3200 (cycling op) ===")
print("ground truth: fixed eps=0.1 cycles (g~0.23); fixed eps=0.01 converges (g~0.09)")

print("-- benchmarks (fixed eps) --")
for fixed_eps in (0.1, 0.01):
    # adaptive=False keeps eps at e0; emin/emax are pinned to it for clarity.
    res = run(False, e0=fixed_eps, emin=fixed_eps, emax=fixed_eps)
    tail_min, tail_mean = res['gmin'], res['gmean']
    print(f"  fixed eps={fixed_eps}: g_tail[min={tail_min:.4f} mean={tail_mean:.4f}]")

print("-- adaptive configs (want: g <= 0.01-benchmark, avg_eps as HIGH as possible = fewer effective steps) --")
adaptive_sweep = (
    ("C1 cons", dict(down=0.5, up=1.1, theta=0.98)),
    ("C2 mod", dict(down=0.7, up=1.2, theta=0.98)),
    ("C3 caut", dict(down=0.5, up=1.05, theta=0.99)),
    ("C4 aggr", dict(down=0.6, up=1.3, theta=0.95)),
)
for label, cfg in adaptive_sweep:
    res = run(True, **cfg)
    summary = (f"g_tail[min={res['gmin']:.4f} mean={res['gmean']:.4f}] "
               f"avg_eps={res['avg_eps']:.4f} final_eps={res['final_eps']:.4f}")
    print(f"  {label} {cfg}: {summary}")
print("=== DONE ===")