diff options
Diffstat (limited to 'ep_run/gcalib.py')
| -rw-r--r-- | ep_run/gcalib.py | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/ep_run/gcalib.py b/ep_run/gcalib.py new file mode 100644 index 0000000..85495c0 --- /dev/null +++ b/ep_run/gcalib.py @@ -0,0 +1,38 @@ +"""EP lr theory, step 1: measure k = |g_EP|/|g_BPTT| per param group at a realistic operating point. +Native reference is BPTT (Ernoult: EP=BPTT as beta->0, converged) — NOT BP. lr_EP = lr_BPTT / k. +Report magnitude ratio AND cosine (direction) per group so we separate scale (k) from alignment.""" +import torch +import lt_ep_train as M +from pathlib import Path +import pickle +M.DD = Path('/tmp/lt_ep/data/tinystories_bpe'); M.vocab = pickle.load(open(M.DD/'meta.pkl','rb'))['vocab_size'] +from lt_ep_train import EQBlock, get_batch, bptt_step, ep_step +torch.manual_seed(0) +C,H,T,B = 512, 16, 256, 16 +blk = EQBlock(C,H,256,T,attn_mode='thick'); blk.qknorm=True; blk.track=False; blk.li_avg=0; blk.navg=1; blk.fnoise=0; blk.nbrake=0; blk._cstep=None +with torch.no_grad(): blk.WO.mul_(0.1); blk.pj.mul_(0.1) +opt = torch.optim.AdamW(blk.allp, lr=5e-4, weight_decay=1e-4) +for _ in range(300): # pretrain to a realistic operating point (BPTT) + idx,y = get_batch('train',B,T); g = bptt_step(blk,idx,y,150,0.1) + opt.zero_grad(set_to_none=True) + for p in blk.allp: p.grad = g.get(id(p)) + torch.nn.utils.clip_grad_norm_(blk.allp,5.0); opt.step() +print("pretrained 300 BPTT steps (C=512). k=|g_EP|/|g_BPTT|, cos=direction:", flush=True) +groups = {'all':blk.block,'attn':[blk.WQ,blk.WK,blk.WV,blk.WO],'ffn':[blk.fc,blk.fcb,blk.pj,blk.pjb], + 'ln':[blk.ln1g,blk.ln1b,blk.ln2g,blk.ln2b],'emb':[blk.tok,blk.pos]} +def cat(g,ps): + v=[g[id(p)].reshape(-1) for p in ps if g.get(id(p)) is not None]; return torch.cat(v) if v else None +import numpy as np +acc={k:[] for k in groups}; accc={k:[] for k in groups} +for _ in range(6): + idx,y = get_batch('train',B,T) + gE,_ = ep_step(blk,idx,y,150,20,0.1,0.02,0.0,holo=2,hr=0.02,t1max=500,res_est=1e-4,t2sel=120) + gB = bptt_step(blk,idx,y,400,0.1) + for k,ps in groups.items(): + a,b = cat(gE,ps),cat(gB,ps) + if a is not None and b is not None: + acc[k].append((a.norm()/(b.norm()+1e-12)).item()) + accc[k].append((a@b/(a.norm()*b.norm()+1e-12)).item()) +print(f"{'group':>5} {'k=|gEP|/|gBPTT|':>16} {'cos':>6} -> lr_EP = lr_BPTT / k") +for k in groups: + print(f"{k:>5} {np.mean(acc[k]):>16.3f} {np.mean(accc[k]):>6.3f}", flush=True) |
