1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
|
"""Decisive 'why is EP far at S1' diagnostic: separate estimator BIAS from VARIANCE at the
converged v4b checkpoint. Over N batches compute EP grad, BPTT-400 grad, BPTT-150 control.
mean-cos  = mean_b cos(g_EP^b, g_BPTT^b)   -> per-step quality (noisy)
cos-means = cos(sum_b g_EP, sum_b g_BPTT)  -> if >> mean-cos: errors AVERAGE OUT = VARIANCE
                                           -> if ~ mean-cos: errors are systematic = BIAS (the real wall)
BPTT-150-vs-400 gives the same two metrics as the slow-mixing horizon baseline."""
import torch
import lt_ep_train as M
from pathlib import Path
import pickle
# Point the imported training module (M) at the tinystories data directory and
# set its vocabulary size from the dataset metadata. Both are written as module
# attributes on lt_ep_train; presumably get_batch / EQBlock read them — confirm
# in lt_ep_train.
M.DD = Path('/tmp/lt_ep/data/tinystories')
# Use a context manager so the metadata file handle is closed as soon as the
# pickle has been read instead of being left to the garbage collector.
# NOTE(review): unpickling can execute arbitrary code and /tmp is typically
# world-writable; only run this against a meta.pkl you produced yourself.
with open(M.DD / 'meta.pkl', 'rb') as meta_file:
    M.vocab = pickle.load(meta_file)['vocab_size']
from lt_ep_train import EQBlock, get_batch, bptt_step, ep_step
# Device the checkpoint tensors are moved to before being copied into the block.
dev = 'cuda'
# Seed torch's RNG before the block is constructed so the run is repeatable.
torch.manual_seed(0)
# B and T are passed to get_batch below (presumably batch size and sequence
# length — confirm against get_batch); C, H and T are passed to EQBlock.
B, T, C, H = 8, 256, 256, 8
blk = EQBlock(C, H, 256, T, attn_mode='thick')

# Override run-time switches on the block for this diagnostic. Their exact
# semantics are defined in lt_ep_train and are not visible from this script.
blk.qknorm = False
blk.track = False
blk.li_avg = 0
blk.navg = 1
blk.fnoise = 0.0
blk.nbrake = 0.0
blk._cstep = None

# Load the checkpoint and copy its tensors into the block's parameters
# in place. The copy runs under no_grad so autograd does not record it.
# NOTE(review): zip() stops at the shorter sequence, so a checkpoint with a
# different number of tensors than blk.allp would load only partially and
# silently — verify the two lengths match for this checkpoint.
ck = torch.load('/tmp/lt_ep/ts_s1_ep_v4b.pt')
with torch.no_grad():
    for p, w in zip(blk.allp, ck['allp']):
        p.copy_(w.to(dev))
print(f"v4b ckpt best {ck['best']:.4f}", flush=True)
# Named subsets of the block's parameters; every metric below is reported once
# per entry. 'all' is the block's full parameter list (blk.block).
groups = {
    'all': blk.block,
    'attn': [blk.WQ, blk.WK, blk.WV, blk.WO],
    'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb],
    'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b],
    'emb': [blk.tok, blk.pos],
}

# Number of batches the gradients are measured on.
N = 16

# Running per-parameter gradient sums over batches, keyed by id(param):
# EP, BPTT-400 and BPTT-150 respectively.
sEP = {}
s400 = {}
s150 = {}

# Per-batch cosines for each group: EP vs BPTT-400, and BPTT-150 vs BPTT-400.
cos_b = {name: [] for name in groups}
bctl_b = {name: [] for name in groups}
def flat(g, ps):
    """Concatenate the gradients of the parameters ``ps`` into one 1-D tensor.

    Args:
        g: Mapping from ``id(param)`` to that parameter's gradient tensor.
            Entries may be absent or ``None``.
        ps: Iterable of parameters that selects, and orders, the gradients.

    Returns:
        The selected gradients, each flattened with ``reshape(-1)`` and joined
        with ``torch.cat`` in ``ps`` order, or ``None`` when no parameter in
        ``ps`` has a gradient in ``g``.
    """
    # Look each gradient up once. Parameters without a gradient are skipped,
    # so two mappings that cover different parameters yield vectors that do
    # not line up element-for-element.
    grads = (g.get(id(p)) for p in ps)
    pieces = [grad.reshape(-1) for grad in grads if grad is not None]
    return torch.cat(pieces) if pieces else None
def cos(a, b):
    """Return the cosine similarity of two 1-D tensors as a Python float.

    Args:
        a: 1-D tensor.
        b: 1-D tensor with the same number of elements as ``a``.

    Returns:
        ``(a . b) / (|a| * |b| + 1e-12)`` converted with ``.item()``. The small
        constant keeps the division finite, so an all-zero input gives 0.0
        rather than NaN.
    """
    dot = a @ b
    scale = a.norm() * b.norm() + 1e-12
    return (dot / scale).item()
# Measurement loop. For each of N training batches, compute three gradient
# estimates on the same batch: EP, BPTT with 400 steps and BPTT with 150 steps.
# Then (1) record the per-batch cosines against BPTT-400 for every parameter
# group and (2) add each estimate into its running per-parameter sum.
for i in range(N):
    idx, y = get_batch('train', B, T)
    # ep_step returns a pair; only the gradient mapping (first element) is used.
    gE, _ = ep_step(blk, idx, y, 150, 20, 0.1, 0.02, 0.0, holo=2, hr=0.02, t1max=500, res_est=1e-4, t2sel=120)
    g4 = bptt_step(blk, idx, y, 400, 0.1)
    g1 = bptt_step(blk, idx, y, 150, 0.1)

    # Per-batch cosines: EP vs BPTT-400, and the BPTT-150 vs BPTT-400 control.
    # A pair is skipped when either side has no gradient for the group.
    # NOTE(review): flat() drops parameters independently per estimator, so if
    # the estimators cover different parameters within a group the vectors
    # will not align — confirm both return gradients for the same parameters.
    for k, ps in groups.items():
        a, b, c = flat(gE, ps), flat(g4, ps), flat(g1, ps)
        if a is not None and b is not None:
            cos_b[k].append(cos(a, b))
        if c is not None and b is not None:
            bctl_b[k].append(cos(c, b))

    # Accumulate sum_b g for each estimator, keyed by id(param). The first
    # contribution is cloned so the running sum never aliases a tensor
    # returned by the step function.
    for src, acc in ((gE, sEP), (g4, s400), (g1, s150)):
        for p in blk.block:
            grad = src.get(id(p))
            if grad is None:
                continue
            key = id(p)
            if key in acc:
                acc[key] = acc[key] + grad.detach()
            else:
                acc[key] = grad.detach().clone()

    print(f" batch {i+1}/{N} done", flush=True)
# Final table, one row per parameter group:
#   EP mean-cos    : mean over batches of cos(g_EP, g_BPTT400)
#   EP cos-means   : cos(sum_b g_EP, sum_b g_BPTT400)
#   BPTT mean-cos  : mean over batches of cos(g_BPTT150, g_BPTT400)
#   BPTT cos-means : cos(sum_b g_BPTT150, sum_b g_BPTT400)
print(f"\n{'group':>5} {'EP mean-cos':>12} {'EP cos-means':>13} {'BPTT mean-cos':>14} {'BPTT cos-means':>15}")
# The measurement loop skips a group whenever an estimator produced no gradient
# for it, so a group can reach this point with an empty cosine list or with no
# accumulated sum. Report NaN for that cell instead of raising, so one missing
# group cannot discard the results of the whole run.
missing = float('nan')
for k, ps in groups.items():
    mc = sum(cos_b[k]) / len(cos_b[k]) if cos_b[k] else missing
    bmc = sum(bctl_b[k]) / len(bctl_b[k]) if bctl_b[k] else missing
    aE, a4, a1 = flat(sEP, ps), flat(s400, ps), flat(s150, ps)
    cm = cos(aE, a4) if aE is not None and a4 is not None else missing
    bcm = cos(a1, a4) if a1 is not None and a4 is not None else missing
    print(f"{k:>5} {mc:>12.3f} {cm:>13.3f} {bmc:>14.3f} {bcm:>15.3f}", flush=True)
|