summaryrefslogtreecommitdiff
path: root/ep_run/lt_ep_diag.py
blob: 1cec6659396e1ec97d516a99b1b61d8afe44ad75 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""Diagnostic: WHY does EP training destabilize? Test the hypothesis (Ernoult 2019):
EP matches BPTT when the free phase has CONVERGED (and beta is small). So log, on steps 1-5 and every 10th step:
  - free-phase residual ||Pi(z*+eF)-z*||/||z*||   (is the fixed point still there?)
  - cosine(EP-grad, BPTT-grad) over the block params (is EP still tracking the true grad?)
If cosine starts ~1 and stays ~1 until the residual blows up -> it's loss of convergence, not beta.
"""
import math, torch, torch.nn.functional as F
from lt_ep_train import EQBlock, ep_step, bptt_step, relax, get_batch, ce

if torch.cuda.is_available():
    dev = 'cuda'
else:
    dev = 'cpu'
# Seed before the block is constructed so its initial parameters are reproducible.
torch.manual_seed(0)

T1 = 80      # free-phase relaxation steps (passed to relax / ep_step / bptt_step)
T2 = 15      # extra step count for ep_step -- presumably the nudged phase; confirm
eps = 0.1    # relaxation step size passed to relax (assumed meaning)
beta = 0.02  # EP nudging strength
B = 32       # batch size passed to get_batch
T = 64       # sequence length passed to get_batch and EQBlock

# NOTE(review): positional args presumably (d_model, n_heads, hidden, context_len); confirm
blk = EQBlock(128, 4, 256, T, s=1.0, c=1.0)
opt = torch.optim.AdamW(blk.allp, lr=1e-3, weight_decay=1e-4)
# Parameters over which cos(EP, BPTT) is measured in gcos.
BP = (blk.WQ, blk.WK, blk.WV, blk.WO, blk.Wm)


def resid(idx):
    """Return the relative free-phase fixed-point residual for token batch ``idx``.

    Relaxes the free phase for T1 steps starting from the embedding, then
    applies one more relaxation step and returns
    ||z_{T1+1} - z_{T1}|| / (||z_{T1}|| + 1e-9) as a Python float. A small
    value means the free phase is (approximately) at a fixed point.
    """
    # Purely diagnostic: no gradients are needed. relax/embed are already run
    # under no_grad for the validation CE, so this just avoids building a
    # T1+1-step autograd graph through the block parameters on every call.
    with torch.no_grad():
        xin = blk.embed(idx).detach()
        zs = relax(blk, xin.clone(), xin, T1, eps)
        # Clone like the other call sites do: if relax updates its state
        # in place, passing zs directly would make zn alias zs and the
        # residual would always read 0.
        zn = relax(blk, zs.clone(), xin, 1, eps)
        diff_norm = (zn - zs).norm().item()
        ref_norm = zs.norm().item()
    return diff_norm / (ref_norm + 1e-9)


def gcos(idx, y):
    """Compare EP and BPTT gradients on the parameters in BP.

    Both gradient dicts are keyed by ``id(param)``. Only parameters for which
    both methods returned a finite gradient take part in the comparison.

    Returns:
        (cosine, gep): the cosine similarity of the concatenated, flattened
        gradients (NaN if no parameter is usable), and the EP gradient dict so
        the caller can apply it without recomputing.
    """
    gep = ep_step(blk, idx, y, T1, T2, eps, beta)
    gbp = bptt_step(blk, idx, y, T1, eps)

    def _usable(g):
        return g is not None and torch.isfinite(g).all()

    candidates = [(gep.get(id(p)), gbp.get(id(p))) for p in BP]
    kept = [(g_ep, g_bp) for g_ep, g_bp in candidates if _usable(g_ep) and _usable(g_bp)]
    if not kept:
        return float('nan'), gep

    ep_vec = torch.cat([g_ep.flatten() for g_ep, _ in kept])
    bp_vec = torch.cat([g_bp.flatten() for _, g_bp in kept])
    cosine = F.cosine_similarity(ep_vec, bp_vec, dim=0).item()
    return cosine, gep


# Training loop: apply EP gradients every step; on logging steps (1-5 and every
# 10th) also measure the free-phase residual, cos(EP, BPTT) and validation CE.
print(f"{'step':>4} {'free_resid':>11} {'cos(EP,BPTT)':>13} {'val_CE':>8}")
for step in range(1, 161):
    idx, y = get_batch('train', B, T)
    if step % 10 == 0 or step <= 5:
        # The diagnostics are only printed on these steps, so the extra
        # residual relaxation and the full BPTT gradient are computed only
        # here; on other steps their values would just be discarded.
        # NOTE(review): assumes resid/bptt_step do not consume the global RNG;
        # if they do, skipping them shifts which batches get_batch draws.
        r = resid(idx)
        c, gep = gcos(idx, y)
        with torch.no_grad():
            vi, vy = get_batch('val', B, T)
            xin = blk.embed(vi).detach()
            v = ce(blk, relax(blk, xin.clone(), xin, T1, eps), vy).item()
        print(f"{step:>4} {r:>11.2e} {c:>13.3f} {v:>8.3f}", flush=True)
    else:
        # Same EP gradient gcos would return, without the BPTT comparison.
        gep = ep_step(blk, idx, y, T1, T2, eps, beta)
    # apply EP grads (the actual unstable training)
    if all((g is None) or torch.isfinite(g).all() for g in gep.values()):
        opt.zero_grad(set_to_none=True)
        for p in blk.allp:
            p.grad = gep.get(id(p))
        # Clip the raw EP update to global L2 norm 5.0 before the AdamW step.
        torch.nn.utils.clip_grad_norm_(blk.allp, 5.0)
        opt.step()
    else:
        print(f"{step:>4}  NON-FINITE EP grad -> skipped update", flush=True)