summaryrefslogtreecommitdiff
path: root/ep_run/drift_diag.py
blob: 418010417273839808df10b5b505a27def85aa67 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Late-drift diagnostic. Every stable EP/BPTT recipe peaks mid-run then val CE drifts up 0.1-0.3.
Train the champion recipe on S0 (fast) and log, every 200 steps, quantities that SEPARATE the
competing hypotheses:
  - train_ce vs val_raw : train down + val up => OVERFIT ; both up => OPTIMIZATION INSTABILITY
  - val_raw vs val_deep (T1=150 vs 400) : diverge => DYNAMICAL (fixed point degrading off train depth)
  - res (free-phase) over time : climbing => contraction lost
  - jr (lambda) over time : climbing => CONTROLLER FIGHT
  - cos(EP-grad, BPTT-400) on a fixed probe batch : dropping => ESTIMATOR DEGRADATION
  - |W|/|W_init| per group + cap-bind frac : group-specific => PARAMETRIC RUNAWAY
  - ||ema - raw|| : how far the averaged weights sit from the wandering raw ones
"""
import math, time, torch
import lt_ep_train as M
from lt_ep_train import EQBlock, get_batch, ep_step, bptt_step, relax, evaluate, ce
# ---- run configuration -------------------------------------------------------
dev = 'cuda'  # not referenced anywhere else in the visible part of this script
torch.manual_seed(0)  # seed torch's RNG before the block is constructed
B, T, C, H = 32, 64, 128, 4
STEPS = 9000

# ---- model -------------------------------------------------------------------
blk = EQBlock(C, H, 256, T, attn_mode='thick', c=1.0)

# Every tensor in blk.capw gets a norm cap of three times its norm right now
# (i.e. at initialisation); the training loop rescales tensors that exceed it.
for w in blk.capw:
    blk.caps[id(w)] = w.detach().norm().item() * 3.0

# ---- optimiser and cosine LR schedule over the whole run ---------------------
opt = torch.optim.AdamW(blk.allp, lr=1e-3, weight_decay=1e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=STEPS, eta_min=5e-5)

# ---- bookkeeping for the diagnostics -----------------------------------------
# Detached copy of every parameter; the training loop maintains it as an EMA.
pema = [p.detach().clone() for p in blk.allp]
# Norm of each parameter tensor at initialisation, keyed by the tensor's id().
W0 = {
    id(p): p.detach().norm().item()
    for p in blk.allp
}
# Parameter groups whose |W|/|W_init| ratio is reported in the log.
groups = {
    'emb': [blk.tok, blk.pos],
    'attn': [blk.WQ, blk.WK, blk.WV, blk.WO],
    'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb],
    'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b],
}

# Flags read by lt_ep_train; their exact meaning is defined there, not here.
blk.track = False
blk.li_avg = 0
blk.navg = 1
blk.fnoise = 0.0
blk.nbrake = 0.0

# Probe batch drawn once up front, so the gradient cosine is evaluated on the
# same data at every logging point of the run.
pidx, py = get_batch('train', 16, T)


def grad_cos():
    """Return the cosine between the ep_step and bptt_step gradients on the probe batch.

    Both estimators are run on (pidx, py) with the current weights. Only
    parameters in ``blk.block`` that have a non-None entry in *both* gradient
    dicts contribute; their gradients are flattened and concatenated into one
    vector per estimator before the cosine is taken.

    Returns:
        The cosine similarity as a Python float.
    """
    # Reference gradient first, then the EP estimate (same order as always,
    # in case either call advances RNG or touches state on blk).
    bptt_grads = bptt_step(blk, pidx, py, 400, 0.1)
    ep_grads, _ = ep_step(
        blk, pidx, py, 150, 20, 0.1, 0.02, 0.0,
        holo=2, hr=0.02, t1max=500, res_est=1e-4, t2sel=120,
    )

    ep_parts, bptt_parts = [], []
    for param in blk.block:
        key = id(param)
        ep_g = ep_grads.get(key)
        bptt_g = bptt_grads.get(key)
        if ep_g is None or bptt_g is None:
            continue
        ep_parts.append(ep_g.reshape(-1))
        bptt_parts.append(bptt_g.reshape(-1))

    ep_vec = torch.cat(ep_parts)
    bptt_vec = torch.cat(bptt_parts)
    # Small epsilon keeps the ratio defined when either vector has zero norm.
    denom = ep_vec.norm() * bptt_vec.norm() + 1e-12
    return (ep_vec @ bptt_vec / denom).item()

@torch.no_grad()
def deep_val(T1):
    """Return the mean validation CE over 8 batches, relaxing with depth ``T1``.

    Each batch (32 sequences of length ``T``, obtained from
    ``get_batch('val', ...)``) is embedded, handed to ``relax`` together with a
    clone of the embedding as the starting state, and scored with ``ce``.

    Args:
        T1: value forwarded to ``relax`` as its fourth argument (150 or 400 in
            this script).

    Returns:
        The arithmetic mean of the 8 per-batch CE values as a Python float.
    """
    n_batches = 8
    loss_sum = 0.0
    for _ in range(n_batches):
        tokens, targets = get_batch('val', 32, T)
        emb = blk.embed(tokens).detach()
        settled = relax(blk, emb.clone(), emb, T1, 0.1)
        loss_sum += ce(blk, settled, targets).item()
    return loss_sum / n_batches

# ---- training loop with drift diagnostics ------------------------------------
# jr is the controlled coefficient passed to ep_step (the "lambda" of the module
# docstring); rs is an exponential moving average of the residual ep_step returns.
jr, rs = 1.0, None
flo = 0.1  # lower clamp for jr; constant, so set once instead of on every step
# Cap-bind bookkeeping for the current 200-step logging window: how many
# (applied step, capped tensor) pairs were checked and how many hit their cap.
cap_hits = cap_checks = 0
print(f"{'step':>5} {'train':>6} {'val150':>6} {'val400':>6} {'ema150':>6} {'res':>8} {'jr':>5} "
      f"{'cos':>5} {'|emb|':>5} {'|attn|':>6} {'|ffn|':>5} {'|ln|':>5} {'emaΔ':>6} {'capf':>5}", flush=True)
for step in range(1, STEPS + 1):
    idx, y = get_batch('train', B, T)
    grads, res = ep_step(blk, idx, y, 150, 20, 0.1, 0.02, jr, holo=2, hr=0.02, t1max=500, res_est=1e-4, t2sel=120, res_gate=5e-3)

    # Multiplicative controller: jr is scaled by (rs / 1.5e-3) ** 0.3 and clamped
    # to [flo, 16], so it grows while the smoothed residual sits above 1.5e-3 and
    # shrinks while it sits below. A NaN/inf residual must not enter the EMA:
    # rs would stay non-finite for the rest of the run and jr would be pinned at
    # a clamp bound, so the controller state is simply held on such steps.
    if math.isfinite(res):
        rs = res if rs is None else 0.9 * rs + 0.1 * res
        jr = min(16.0, max(flo, jr * math.exp(0.3 * math.log((rs + 1e-9) / 1.5e-3))))

    # Apply the update only when every returned gradient is finite; otherwise the
    # optimizer, the LR schedule, the caps and the EMA are all skipped this step.
    if all((g is None) or torch.isfinite(g).all() for g in grads.values()):
        opt.zero_grad(set_to_none=True)
        for p in blk.allp:
            p.grad = grads.get(id(p))
        torch.nn.utils.clip_grad_norm_(blk.allp, 5.0)
        opt.step(); sched.step()
        with torch.no_grad():
            # Project capped tensors back onto their norm cap and count how
            # often that happens (reported as 'capf' below).
            for p in blk.capw:
                pn = p.norm(); cap = blk.caps[id(p)]
                cap_checks += 1
                if pn > cap:
                    cap_hits += 1
                    p.mul_(cap / pn)
            # EMA of the weights with decay 0.999.
            for s, p in zip(pema, blk.allp):
                s.mul_(0.999).add_(p.detach(), alpha=1e-3)

    if step % 200 == 0:
        # Train CE on the batch that was just used for the update.
        with torch.no_grad():
            xin = blk.embed(idx).detach()
            tc = ce(blk, relax(blk, xin.clone(), xin, 150, 0.1), y).item()
        # NOTE(review): every deep_val call fetches its own batches via get_batch,
        # so v150, v400 and ve are presumably measured on different samples —
        # confirm against get_batch before reading small gaps between them.
        v150, v400 = deep_val(150), deep_val(400)

        # Temporarily swap the EMA weights in, evaluate, then restore the raw ones.
        raw = [p.detach().clone() for p in blk.allp]
        with torch.no_grad():
            for p, s in zip(blk.allp, pema):
                p.copy_(s)
            ve = deep_val(150)
            # L2 distance between raw and EMA weights over all parameters.
            emad = math.sqrt(sum((r - s).pow(2).sum().item() for r, s in zip(raw, pema)))
            for p, r in zip(blk.allp, raw):
                p.copy_(r)

        # Per-group weight norm relative to its value at initialisation.
        gn = {k: math.sqrt(sum(p.detach().norm().item()**2 for p in ps)) /
              math.sqrt(sum(W0[id(p)]**2 for p in ps)) for k, ps in groups.items()}
        cs = grad_cos()

        # Fraction of cap checks in this window that actually rescaled a tensor;
        # 0.0 when no update was applied in the window. Reset for the next window.
        capf = cap_hits / max(1, cap_checks)
        cap_hits = cap_checks = 0

        print(f"{step:>5} {tc:>6.3f} {v150:>6.3f} {v400:>6.3f} {ve:>6.3f} {res:>8.1e} {jr:>5.1f} "
              f"{cs:>5.2f} {gn['emb']:>5.2f} {gn['attn']:>6.2f} {gn['ffn']:>5.2f} {gn['ln']:>5.2f} {emad:>6.2f} {capf:>5.2f}", flush=True)