1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
|
"""Diagnostic: WHY does EP training destabilize? Test the hypothesis (Ernoult 2019):
EP == BPTT iff the free phase has CONVERGED (+ small beta). So log, per training step:
- free-phase residual ||Pi(z*+eF)-z*||/||z*|| (is the fixed point still there?)
- cosine(EP-grad, BPTT-grad) over the block params (is EP still tracking the true grad?)
If cosine starts ~1 and stays ~1 until the residual blows up -> it's loss of convergence, not beta.
"""
import math, torch, torch.nn.functional as F
from lt_ep_train import EQBlock, ep_step, bptt_step, relax, get_batch, ce
# Device string: CUDA when a GPU is visible, CPU otherwise.
# NOTE(review): `dev` is not referenced again in the visible part of this
# script -- confirm whether the block / batches are placed on it elsewhere.
if torch.cuda.is_available():
    dev = 'cuda'
else:
    dev = 'cpu'

# Seed the torch RNG before the block is built so runs are repeatable.
torch.manual_seed(0)

# Hyper-parameters shared by every diagnostic call below.
T1 = 80      # relax step count handed to relax / ep_step / bptt_step
T2 = 15      # second step count, only handed to ep_step
eps = 0.1    # step-size argument of relax / ep_step / bptt_step
beta = 0.02  # last argument of ep_step (the "beta" of the module docstring)
B = 32       # first size argument of get_batch
T = 64       # second size argument of get_batch; also given to EQBlock

# The block under test and the optimizer that applies the EP gradients.
blk = EQBlock(128, 4, 256, T, s=1.0, c=1.0)
opt = torch.optim.AdamW(params=blk.allp, lr=1e-3, weight_decay=1e-4)

# Parameters over which the EP and BPTT gradients are compared.
BP = (
    blk.WQ,
    blk.WK,
    blk.WV,
    blk.WO,
    blk.Wm,
)
@torch.no_grad()
def resid(idx):
    """Return the relative one-step free-phase residual for batch ``idx``.

    The state is relaxed for ``T1`` steps starting from the detached
    embedding of ``idx`` to obtain ``zs``; one further relax step gives
    ``zn``.  The result is ``||zn - zs|| / (||zs|| + 1e-9)``, i.e. how far a
    single extra step still moves the relaxed state (the module docstring's
    "free-phase residual").

    Args:
        idx: batch handed to ``blk.embed`` (presumably token indices as
            produced by ``get_batch`` -- confirm against ``lt_ep_train``).

    Returns:
        The residual as a Python float.
    """
    # Only a float leaves this function, so autograd recording is disabled
    # (relax is already called under torch.no_grad() for the validation CE).
    xin = blk.embed(idx).detach()
    zs = relax(blk, xin.clone(), xin, T1, eps)
    # Pass a copy, as every other relax call in this script does: if relax
    # updates its state argument in place (TODO confirm), handing it `zs`
    # directly would make `zn` alias `zs` and the residual trivially zero.
    zn = relax(blk, zs.clone(), xin, 1, eps)
    # The 1e-9 keeps the ratio finite when the relaxed state has zero norm.
    return (zn - zs).norm().item() / (zs.norm().item() + 1e-9)
def gcos(idx, y):
    """Cosine similarity between the EP and BPTT gradients on one batch.

    ``ep_step`` and ``bptt_step`` are each run on ``(idx, y)``; their results
    are looked up by ``id(param)`` for every parameter in ``BP``.  Parameters
    for which either gradient is missing or contains a non-finite value are
    left out of the comparison.  The remaining gradients are flattened and
    concatenated into one vector per method before the cosine is taken.

    Args:
        idx: input batch forwarded to ``ep_step`` / ``bptt_step``.
        y: target batch forwarded to ``ep_step`` / ``bptt_step``.

    Returns:
        ``(cosine, ep_grads)`` where ``cosine`` is a float (``nan`` when no
        parameter had a usable gradient pair) and ``ep_grads`` is the
        unmodified mapping returned by ``ep_step``.
    """
    ep_grads = ep_step(blk, idx, y, T1, T2, eps, beta)
    bp_grads = bptt_step(blk, idx, y, T1, eps)

    # One (EP, BPTT) gradient pair per compared parameter, in BP order.
    candidates = ((ep_grads.get(id(p)), bp_grads.get(id(p))) for p in BP)
    usable = [
        (g_ep.flatten(), g_bp.flatten())
        for g_ep, g_bp in candidates
        if g_ep is not None
        and g_bp is not None
        and torch.isfinite(g_ep).all()
        and torch.isfinite(g_bp).all()
    ]

    # Nothing comparable: report nan but still hand back the EP gradients.
    if not usable:
        return float('nan'), ep_grads

    ep_parts, bp_parts = zip(*usable)
    similarity = F.cosine_similarity(torch.cat(ep_parts), torch.cat(bp_parts), dim=0)
    return similarity.item(), ep_grads
# Column header; the widths match the per-step rows printed below.
header = f"{'step':>4} {'free_resid':>11} {'cos(EP,BPTT)':>13} {'val_CE':>8}"
print(header)

for step in range(1, 161):
    # Diagnostics are measured on the training batch *before* the update.
    idx, y = get_batch('train', B, T)
    free_resid = resid(idx)
    cos_sim, gep = gcos(idx, y)

    # Report the first five steps, then every tenth step.
    if step <= 5 or step % 10 == 0:
        with torch.no_grad():
            vi, vy = get_batch('val', B, T)
            xin = blk.embed(vi).detach()
            z_val = relax(blk, xin.clone(), xin, T1, eps)
            val_ce = ce(blk, z_val, vy).item()
        print(f"{step:>4} {free_resid:>11.2e} {cos_sim:>13.3f} {val_ce:>8.3f}", flush=True)

    # Apply the EP grads (the actual unstable training) -- but only when
    # every gradient that is present is finite; otherwise skip this update.
    has_bad_grad = any(
        g is not None and not torch.isfinite(g).all() for g in gep.values()
    )
    if has_bad_grad:
        print(f"{step:>4} NON-FINITE EP grad -> would skip", flush=True)
        continue

    opt.zero_grad(set_to_none=True)
    for p in blk.allp:
        # Parameters without an entry in gep keep grad=None.
        p.grad = gep.get(id(p))
    torch.nn.utils.clip_grad_norm_(blk.allp, 5.0)
    opt.step()
|