1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
|
"""EP lr theory, step 1: measure k = |g_EP|/|g_BPTT| per param group at a realistic operating point.
Native reference is BPTT (Ernoult: EP=BPTT as beta->0, converged) — NOT BP. lr_EP = lr_BPTT / k.
Report magnitude ratio AND cosine (direction) per group so we separate scale (k) from alignment."""
import torch
import lt_ep_train as M
from pathlib import Path
import pickle
# Configure lt_ep_train's module-level dataset directory and vocabulary size by hand.
# NOTE(review): presumably get_batch/EQBlock read M.DD and M.vocab at call time -- confirm in lt_ep_train.
M.DD = Path('/tmp/lt_ep/data/tinystories_bpe')
# The context manager closes the metadata file handle deterministically once it has been loaded.
with open(M.DD / 'meta.pkl', 'rb') as _meta_file:
    M.vocab = pickle.load(_meta_file)['vocab_size']
from lt_ep_train import EQBlock, get_batch, bptt_step, ep_step
# Seed torch's global RNG before the model is constructed.
torch.manual_seed(0)

# NOTE(review): C is the width quoted as "C=512" in the progress message; B and T are the two size
# arguments handed to get_batch -- presumably batch size and sequence length, H presumably heads. Confirm.
C, H, T, B = 512, 16, 256, 16

# Build the block under test, then set its runtime switches one by one.
# NOTE(review): the meaning of each switch is defined by lt_ep_train.EQBlock -- confirm there.
blk = EQBlock(C, H, 256, T, attn_mode='thick')
blk.qknorm = True
blk.track = False
blk.li_avg = 0
blk.navg = 1
blk.fnoise = 0
blk.nbrake = 0
blk._cstep = None

# Scale WO and pj in place to 10% of their initial values, outside of autograd.
with torch.no_grad():
    blk.WO.mul_(0.1)
    blk.pj.mul_(0.1)

# Optimizer used for the BPTT pretraining phase.
opt = torch.optim.AdamW(blk.allp, lr=5e-4, weight_decay=1e-4)
# Pretrain with BPTT gradients so the later k / cosine measurement is taken at a realistic
# operating point rather than at initialization.
for _ in range(300):
    idx, y = get_batch('train', B, T)
    # bptt_step's result is looked up by id(param) below.
    # NOTE(review): 150 and 0.1 are passed positionally; their meaning is defined in lt_ep_train.
    g = bptt_step(blk, idx, y, 150, 0.1)
    opt.zero_grad(set_to_none=True)
    # Install the gradients by hand; parameters absent from g end up with grad=None.
    for p in blk.allp:
        p.grad = g.get(id(p))
    torch.nn.utils.clip_grad_norm_(blk.allp, 5.0)
    opt.step()
print("pretrained 300 BPTT steps (C=512). k=|g_EP|/|g_BPTT|, cos=direction:", flush=True)
# Parameter groups to compare. 'all' reuses blk.block as-is; the others are hand-picked lists.
# Insertion order is the order in which rows are reported.
groups = {
    'all': blk.block,
    'attn': [blk.WQ, blk.WK, blk.WV, blk.WO],
    'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb],
    'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b],
    'emb': [blk.tok, blk.pos],
}
def cat(g, ps):
    """Flatten and concatenate the gradients stored in `g` for the parameters in `ps`.

    `g` maps ``id(param)`` to a gradient tensor (or None). Parameters whose entry is
    missing or None are skipped; the remaining gradients are flattened to 1-D and
    joined in the order of `ps`.

    Returns a 1-D tensor, or None when no parameter in `ps` has a gradient.
    """
    pieces = []
    for param in ps:
        grad = g.get(id(param))
        if grad is None:
            continue
        pieces.append(grad.reshape(-1))
    if not pieces:
        return None
    return torch.cat(pieces)
import numpy as np
# Per-group samples collected over the measurement batches:
#   acc  -> magnitude ratio k = |g_EP| / |g_BPTT|
#   accc -> cosine similarity between g_EP and g_BPTT
acc = {k: [] for k in groups}
accc = {k: [] for k in groups}
for _ in range(6):
    idx, y = get_batch('train', B, T)
    # Both gradient dicts are looked up by id(param) and come from the same batch.
    # NOTE(review): the positional/keyword settings below are defined by lt_ep_train.ep_step /
    # bptt_step; their meaning cannot be confirmed from this script.
    gE, _ = ep_step(blk, idx, y, 150, 20, 0.1, 0.02, 0.0, holo=2, hr=0.02, t1max=500, res_est=1e-4, t2sel=120)
    gB = bptt_step(blk, idx, y, 400, 0.1)
    for k, ps in groups.items():
        # Compare only parameters that have a gradient under BOTH methods. Filtering each side
        # independently could yield vectors of different length (a @ b would raise) or vectors
        # whose entries belong to different parameters (meaningless cosine and ratio).
        shared = [p for p in ps if gE.get(id(p)) is not None and gB.get(id(p)) is not None]
        a, b = cat(gE, shared), cat(gB, shared)
        if a is None or b is None:
            continue
        a_norm, b_norm = a.norm(), b.norm()
        acc[k].append((a_norm / (b_norm + 1e-12)).item())
        accc[k].append((a @ b / (a_norm * b_norm + 1e-12)).item())
print(f"{'group':>5} {'k=|gEP|/|gBPTT|':>16} {'cos':>6} -> lr_EP = lr_BPTT / k")
for k in groups:
    # A group with no samples prints nan without triggering numpy's empty-slice warning.
    k_mean = np.mean(acc[k]) if acc[k] else float('nan')
    cos_mean = np.mean(accc[k]) if accc[k] else float('nan')
    print(f"{k:>5} {k_mean:>16.3f} {cos_mean:>6.3f}", flush=True)
|