"""EP lr theory, step 1: measure k = |g_EP| / |g_BPTT| per parameter group at a
realistic operating point.

The native reference is BPTT, not BP (Ernoult: EP == BPTT as beta -> 0 at
convergence), so lr_EP = lr_BPTT / k.  Both the magnitude ratio (k) and the
cosine (direction) are reported per group so that scale (k) can be separated
from alignment.
"""
import pickle
from pathlib import Path

import numpy as np
import torch

import lt_ep_train as M

# Point the training module at the dataset before using its helpers.
# NOTE(review): presumably get_batch reads M.DD / M.vocab at call time -- confirm.
M.DD = Path('/tmp/lt_ep/data/tinystories_bpe')
# meta.pkl is a local, trusted artifact; never unpickle files from untrusted sources.
with open(M.DD / 'meta.pkl', 'rb') as fh:
    M.vocab = pickle.load(fh)['vocab_size']

from lt_ep_train import EQBlock, get_batch, bptt_step, ep_step  # noqa: E402

torch.manual_seed(0)
C, H, T, B = 512, 16, 256, 16
N_PRETRAIN = 300  # BPTT steps used to reach a realistic operating point
N_MEASURE = 6     # batches averaged when measuring k and cos

blk = EQBlock(C, H, 256, T, attn_mode='thick')
blk.qknorm = True
blk.track = False
blk.li_avg = 0
blk.navg = 1
blk.fnoise = 0
blk.nbrake = 0
blk._cstep = None
with torch.no_grad():
    # Scale down blk.WO and blk.pj by 10x before training.
    blk.WO.mul_(0.1)
    blk.pj.mul_(0.1)

# --- Pretrain with BPTT to a realistic operating point ----------------------
opt = torch.optim.AdamW(blk.allp, lr=5e-4, weight_decay=1e-4)
for _ in range(N_PRETRAIN):
    idx, y = get_batch('train', B, T)
    g = bptt_step(blk, idx, y, 150, 0.1)  # maps id(param) -> gradient tensor
    opt.zero_grad(set_to_none=True)
    for p in blk.allp:
        p.grad = g.get(id(p))  # params absent from g keep grad=None (skipped by AdamW)
    torch.nn.utils.clip_grad_norm_(blk.allp, 5.0)
    opt.step()
print(f"pretrained {N_PRETRAIN} BPTT steps (C={C}). "
      "k=|g_EP|/|g_BPTT|, cos=direction:", flush=True)

# Parameter groups to report.
# NOTE(review): 'all' uses blk.block, presumably the full list of block
# parameters -- confirm it matches what ep_step/bptt_step produce grads for.
groups = {
    'all': blk.block,
    'attn': [blk.WQ, blk.WK, blk.WV, blk.WO],
    'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb],
    'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b],
    'emb': [blk.tok, blk.pos],
}


def cat(g, ps):
    """Flatten and concatenate the gradients in ``g`` for the params ``ps``.

    ``g`` maps ``id(param)`` to a gradient tensor; params missing from ``g``
    (or mapped to None) are skipped.  Returns None if no param has a gradient.
    """
    v = [g[id(p)].reshape(-1) for p in ps if g.get(id(p)) is not None]
    return torch.cat(v) if v else None


def cat_pair(ga, gb, ps):
    """Flatten two gradient dicts over the params of ``ps`` present in BOTH.

    Restricting to the shared params keeps the two vectors element-aligned,
    so the norm ratio and the dot product compare the same parameters.
    Calling ``cat`` on each dict separately silently misaligns the vectors
    (or fails with a shape error) whenever one method yields a gradient for a
    param the other does not.  Returns ``(None, None)`` if nothing is shared.
    """
    shared = [p for p in ps
              if ga.get(id(p)) is not None and gb.get(id(p)) is not None]
    return cat(ga, shared), cat(gb, shared)


def mean_or_nan(xs):
    """Mean of ``xs``, or NaN (without numpy's empty-slice warning) if empty."""
    return float(np.mean(xs)) if xs else float('nan')


# --- Measure EP against a BPTT reference on the same batches ----------------
acc = {k: [] for k in groups}   # per-batch magnitude ratios |gEP| / |gBPTT|
accc = {k: [] for k in groups}  # per-batch cosines between gEP and gBPTT
for _ in range(N_MEASURE):
    idx, y = get_batch('train', B, T)
    # ep_step arguments are passed through unchanged; see lt_ep_train.ep_step
    # for their meaning.  Its second return value is unused here.
    gE, _ = ep_step(blk, idx, y, 150, 20, 0.1, 0.02, 0.0, holo=2, hr=0.02,
                    t1max=500, res_est=1e-4, t2sel=120)
    # Reference uses 400 where pretraining used 150 -- presumably more
    # relaxation steps so the BPTT reference is closer to converged.
    gB = bptt_step(blk, idx, y, 400, 0.1)
    for k, ps in groups.items():
        a, b = cat_pair(gE, gB, ps)
        if a is None:  # no param in this group has a grad from both methods
            continue
        na, nb = a.norm(), b.norm()
        acc[k].append((na / (nb + 1e-12)).item())
        accc[k].append((a @ b / (na * nb + 1e-12)).item())

print(f"{'group':>5} {'k=|gEP|/|gBPTT|':>16} {'cos':>6} -> lr_EP = lr_BPTT / k")
for k in groups:
    print(f"{k:>5} {mean_or_nan(acc[k]):>16.3f} {mean_or_nan(accc[k]):>6.3f}",
          flush=True)