1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
|
"""Gradient-quality probe at a realistic operating point. Pretrain the thick block with BPTT
for 300 steps (this lands near where good optima live: residual ~1e-2 to 1e-3, with NO
contraction penalty), then measure cosine(EP gradient, long-horizon BPTT reference) per
parameter group as a function of the free-phase length T1 (which sets the residual level)
and the nudge length T2.
If the cosine is high at residual ~1e-2, there is NO estimator wall at the operating points
that matter; the EP-vs-BPTT gap is then a speed + regularization tax, not a convergence
mechanism."""
import torch
from lt_ep_train import EQBlock, get_batch, ep_step, bptt_step
# Target device string: CUDA when torch reports it available, CPU otherwise.
dev = 'cpu'
if torch.cuda.is_available():
    dev = 'cuda'

# Seed torch's global RNG with a fixed value.
torch.manual_seed(0)

# Size constants: B and T are handed to get_batch; C, H and T are handed to EQBlock.
# NOTE(review): presumably batch size, sequence length, channel width and head count --
# confirm against lt_ep_train.
B = 16
T = 64
C = 128
H = 4

# The literal 256 is the third positional EQBlock argument; its meaning is not visible
# here -- TODO confirm in lt_ep_train.
blk = EQBlock(C, H, 256, T, attn_mode='thick')  # c=1.0 default = thick-BPTT baseline setting
import os
# Load-or-pretrain: reuse cached weights when the cache file exists, otherwise run 300
# BPTT optimisation steps and write the resulting weights to the cache.
#
# NOTE(review): torch.load unpickles the cache file, and /tmp is writable by other local
# users. Consider torch.load(..., weights_only=True) if the installed torch supports it.
cache_path = '/tmp/lt_ep/probe_w.pt'
params = list(blk.allp)  # snapshot so the same sequence is reused for len/zip/optimizer

if os.path.exists(cache_path):
    cached = torch.load(cache_path)
    # zip() would silently stop at the shorter sequence and leave some parameters at
    # their fresh initialisation, so a count mismatch is reported instead.
    if len(cached) != len(params):
        raise ValueError(
            f"{cache_path} holds {len(cached)} tensors but the block has "
            f"{len(params)} parameters; delete the stale cache to re-pretrain"
        )
    with torch.no_grad():
        for p, w in zip(params, cached):
            p.copy_(w.to(dev))
    print("loaded cached pretrained weights", flush=True)
else:
    # Create the cache directory before training so a missing or unwritable directory
    # fails immediately rather than at torch.save after all 300 steps have run.
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    opt = torch.optim.AdamW(params, lr=1e-3, weight_decay=1e-4)
    for _ in range(300):
        idx, y = get_batch('train', B, T)
        # bptt_step returns a mapping keyed by id(param). The trailing 150 and 0.1 are
        # passed through unchanged -- presumably the BPTT horizon and a step/damping
        # factor; confirm in lt_ep_train.
        g = bptt_step(blk, idx, y, 150, 0.1)
        opt.zero_grad(set_to_none=True)
        for p in params:
            # Parameters missing from g get grad=None.
            p.grad = g.get(id(p))
        torch.nn.utils.clip_grad_norm_(params, 5.0)
        opt.step()
    torch.save([p.detach().cpu() for p in params], cache_path)
    print("pretrained 300 BPTT steps (thick, c=1) -- measuring at this operating point", flush=True)
# Named parameter groups; each value is a collection of parameters of blk. Insertion
# order is the column order wherever the dict is iterated.
# NOTE(review): 'all' maps to blk.block rather than blk.allp -- confirm that this is the
# intended parameter set for the "all" column.
groups = dict(
    all=blk.block,
    attn=[blk.WQ, blk.WK, blk.WV, blk.WO],
    ffn=[blk.fc, blk.fcb, blk.pj, blk.pjb],
    ln=[blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b],
    emb=[blk.tok, blk.pos],
    mem=[blk.Wm],
)
def cos(ga, gb, ps):
    """Return the cosine similarity of two gradient sets over a parameter group.

    Args:
        ga: Mapping from ``id(param)`` to a gradient tensor (or ``None``).
        gb: Mapping with the same keying as ``ga``.
        ps: Iterable of the parameters that make up the group.

    Returns:
        A Python float. Only parameters with a non-``None`` gradient in BOTH
        mappings contribute; their gradients are flattened and concatenated
        into one vector per mapping before the cosine is taken. If no parameter
        qualifies, a ``[debug]`` line is printed and ``nan`` is returned.
    """
    # Look each parameter up once in each mapping; .get() gives None for missing keys.
    pairs = [(ga.get(id(p)), gb.get(id(p))) for p in ps]
    keep = [(a, b) for a, b in pairs if a is not None and b is not None]
    if not keep:
        na = sum(1 for a, _ in pairs if a is None)
        nb = sum(1 for _, b in pairs if b is None)
        print(f"\n[debug] empty keep: |ps|={len(pairs)} ga_None={na} gb_None={nb} "
              f"ga_keys={len(ga)} gb_keys={len(gb)}", flush=True)
        return float('nan')
    va = torch.cat([a.reshape(-1) for a, _ in keep])
    vb = torch.cat([b.reshape(-1) for _, b in keep])
    # The 1e-12 term keeps the division finite when either vector has zero norm.
    return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item()
def _fmt_row(label, res_field, grads, ref):
    """Format one table row: label, residual column, then one cosine per entry of `groups`."""
    cosines = " ".join(f"{cos(grads, ref, ps):>6.3f}" for ps in groups.values())
    return f"{label:>16} {res_field:>9} " + cosines


# Column header: config label, residual, then one column per parameter group.
hdr = f"{'config':>16} {'res':>9} " + " ".join(f"{k:>6}" for k in groups)

# Three independent batches; each prints its own table, separated by a blank line.
for bi in range(3):
    idx, y = get_batch('train', B, T)
    ref = bptt_step(blk, idx, y, 400, 0.1)  # exact reference: long-horizon BPTT
    print(("\n" if bi else "") + hdr, flush=True)

    # Shorter BPTT run compared against the reference; it has no residual to report.
    g150 = bptt_step(blk, idx, y, 150, 0.1)
    print(_fmt_row('bptt T1=150', '--', g150, ref), flush=True)

    # Sweep T1 at T2=20. ep_step returns (gradient mapping, residual). The trailing
    # 0.1, 0.02, 0.0 are passed through unchanged -- their meaning is not visible here;
    # confirm in lt_ep_train.
    for T1 in (50, 150, 400):
        gep, res = ep_step(blk, idx, y, T1, 20, 0.1, 0.02, 0.0)
        print(_fmt_row(f'ep T1={T1:<3} T2=20', f"{res:.1e}", gep, ref), flush=True)

    # One extra point per batch with T2=60 at T1=150.
    gep, res = ep_step(blk, idx, y, 150, 60, 0.1, 0.02, 0.0)
    print(_fmt_row('ep T1=150 T2=60', f"{res:.1e}", gep, ref), flush=True)
|