# Source: ep_run/grad_quality.py (blob 095d764e25b6fc1dbef86232e8647fedcb6e5f82)
"""Gradient-quality probe at a realistic operating point. Pretrain the thick block with BPTT
for 300 steps (lands near where good optima live, res ~1e-2-1e-3, NO contraction penalty),
then measure cosine(EP gradient, long-horizon-BPTT reference) per parameter group as a
function of free-phase length T1 (= residual level) and nudge length T2.
If cosine is high at res ~1e-2 -> there is NO estimator wall at the operating points that
matter; the EP-vs-BPTT gap is speed + regularization tax, not a convergence mechanism."""
import torch
from lt_ep_train import EQBlock, get_batch, ep_step, bptt_step
import os

# Run on GPU when available; cached weights are moved to this device on load.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)                                    # fixed seed so the probe is repeatable
# NOTE(review): presumably batch size, sequence length, channels, heads -- confirm against
# lt_ep_train; here B, T feed get_batch and C, H, T feed EQBlock.
B, T, C, H = 16, 64, 128, 4
blk = EQBlock(C, H, 256, T, attn_mode='thick')          # c=1.0 default = thick-BPTT baseline setting

# Pretrained weights are cached on disk so repeated probe runs skip the 300-step pretraining.
_PROBE_CACHE = '/tmp/lt_ep/probe_w.pt'
if os.path.exists(_PROBE_CACHE):
    cached = torch.load(_PROBE_CACHE)
    params = list(blk.allp)
    # zip() would silently truncate on a stale cache and leave the model partially loaded.
    if len(cached) != len(params):
        raise RuntimeError(
            f"cached weights in {_PROBE_CACHE} hold {len(cached)} tensors but the block has "
            f"{len(params)} parameters; delete the cache file to re-run pretraining")
    with torch.no_grad():
        for p, w in zip(params, cached):
            p.copy_(w.to(dev))
    print("loaded cached pretrained weights", flush=True)
else:
    # Create the cache directory up front: failing here is cheap, failing in torch.save
    # after 300 training steps is not.
    os.makedirs(os.path.dirname(_PROBE_CACHE), exist_ok=True)
    opt = torch.optim.AdamW(blk.allp, lr=1e-3, weight_decay=1e-4)
    for _ in range(300):
        idx, y = get_batch('train', B, T)
        g = bptt_step(blk, idx, y, 150, 0.1)             # gradients looked up by id(parameter)
        opt.zero_grad(set_to_none=True)
        for p in blk.allp:
            p.grad = g.get(id(p))                        # parameters absent from g keep grad=None
        torch.nn.utils.clip_grad_norm_(blk.allp, 5.0)
        opt.step()
    torch.save([p.detach().cpu() for p in blk.allp], _PROBE_CACHE)
print("pretrained 300 BPTT steps (thick, c=1) -- measuring at this operating point", flush=True)

# Parameter groups over which gradient cosines are reported. Insertion order matters:
# it fixes the column order of the printed table.
# NOTE(review): 'all' uses blk.block rather than blk.allp -- confirm this is the intended
# parameter list for the aggregate column.
groups = dict(
    all=blk.block,
    attn=[blk.WQ, blk.WK, blk.WV, blk.WO],
    ffn=[blk.fc, blk.fcb, blk.pj, blk.pjb],
    ln=[blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b],
    emb=[blk.tok, blk.pos],
    mem=[blk.Wm],
)


def cos(ga, gb, ps):
    """Return the cosine similarity of two gradient sets over the parameters in ``ps``.

    ``ga`` and ``gb`` are mappings looked up with ``id(p)``. A parameter is used only
    when both mappings hold a non-None entry for it; the surviving gradients are
    flattened and concatenated into one vector per side before the cosine is taken.

    Returns a Python float. If no parameter survives the filter, a debug line with
    the missing-entry counts is printed and NaN is returned.
    """
    pairs = []
    for p in ps:
        a, b = ga.get(id(p)), gb.get(id(p))
        if a is not None and b is not None:
            pairs.append((a, b))

    if len(pairs) == 0:
        missing_a = len([p for p in ps if ga.get(id(p)) is None])
        missing_b = len([p for p in ps if gb.get(id(p)) is None])
        print(f"\n[debug] empty keep: |ps|={len(ps)} ga_None={missing_a} gb_None={missing_b} "
              f"ga_keys={len(ga)} gb_keys={len(gb)}", flush=True)
        return float('nan')

    flat_a = torch.cat([a.reshape(-1) for a, _ in pairs])
    flat_b = torch.cat([b.reshape(-1) for _, b in pairs])
    # The small epsilon keeps the division finite when either vector is all zeros.
    denom = flat_a.norm() * flat_b.norm() + 1e-12
    return (torch.dot(flat_a, flat_b) / denom).item()


def _cos_columns(grad, reference):
    """Format the cosine of ``grad`` vs ``reference`` for every group as table cells."""
    return " ".join(f"{cos(grad, reference, ps):>6.3f}" for ps in groups.values())


# Table header: config label, residual, then one column per parameter group.
_group_titles = " ".join(f"{k:>6}" for k in groups)
hdr = f"{'config':>16} {'res':>9} " + _group_titles

# Repeat the measurement on three independently drawn batches.
for batch_no in range(3):
    idx, y = get_batch('train', B, T)
    ref = bptt_step(blk, idx, y, 400, 0.1)               # exact reference: long-horizon BPTT
    # A blank line separates the tables of successive batches.
    separator = "\n" if batch_no else ""
    print(separator + hdr, flush=True)

    # Shorter-horizon BPTT against the reference; it reports no residual, hence '--'.
    g150 = bptt_step(blk, idx, y, 150, 0.1)
    print(f"{'bptt T1=150':>16} {'--':>9} " + _cos_columns(g150, ref), flush=True)

    # EP gradient with the 4th argument T1 swept and the 5th fixed at 20 (labelled T2).
    for T1 in (50, 150, 400):
        gep, res = ep_step(blk, idx, y, T1, 20, 0.1, 0.02, 0.0)
        label = f'ep T1={T1:<3} T2=20'
        print(f"{label:>16} {res:>9.1e} " + _cos_columns(gep, ref), flush=True)

    # One extra point with T2 raised to 60 at T1=150.
    gep, res = ep_step(blk, idx, y, 150, 60, 0.1, 0.02, 0.0)
    print(f"{'ep T1=150 T2=60':>16} {res:>9.1e} " + _cos_columns(gep, ref), flush=True)