summaryrefslogtreecommitdiff
path: root/ep_run/jnc_scaling.py
blob: 2126d9a8652376d6fc87e8f8ce8d90ef79a3fa4b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""Causal probe for the stability onset: measure the non-conservative Jacobian norm ‖J_nc(z*)‖
(Hutchinson, the same quantity the jacreg controller penalizes) vs width C, at init and after a
few training steps. If ‖J_nc‖ growth-per-lr-step crosses the contraction margin at the measured
critical scale, the lr_crit onset is DERIVED from dynamics, not a hyperparameter."""
import math, torch
import lt_ep_train as M
from pathlib import Path
import pickle
# Point the trainer module at the dataset before using its helpers. get_batch / EQBlock presumably
# read these module globals (M.DD, M.vocab) -- TODO confirm in lt_ep_train.
M.DD = Path('/tmp/lt_ep/data/tinystories_bpe')
# NOTE(review): pickle.load can execute arbitrary code from a crafted file, and /tmp is usually
# world-writable; only run this against a meta.pkl you produced yourself.
with open(M.DD / 'meta.pkl', 'rb') as _meta_fh:       # close the handle instead of leaking it
    M.vocab = pickle.load(_meta_fh)['vocab_size']
from lt_ep_train import EQBlock, get_batch, bptt_step, relax
dev = 'cuda'   # NOTE(review): not referenced anywhere in this script's visible code

def jnc_norm(blk, zs, n=8, *, normalize=True):
    """Stochastic estimate of the size of J_nc(z*) = d(blk.nc_force)/dz evaluated at ``zs``.

    Draws ``n`` Gaussian probes e ~ N(0, I) shaped like ``zs`` and pushes each one through the
    Jacobian-vector product J_nc e (via ``torch.autograd.functional.jvp``).

    normalize=True (default, the original behaviour):
        each sample is |J e|^2 / |e|^2. Because e/|e| is uniform on the unit sphere, this has
        expectation ||J_nc||_F^2 / d with d = zs.numel(). The return value is therefore the RMS
        gain ||J_nc||_F / sqrt(d), a dimension-normalised size. It is NOT the raw Frobenius norm.
    normalize=False:
        each sample is |J e|^2, whose expectation is ||J_nc||_F^2 (Hutchinson's trace estimator
        applied to J^T J). The return value then estimates ||J_nc||_F itself.

    Args:
        blk: object exposing a callable ``nc_force(z)``.
        zs: point at which the Jacobian is evaluated.
        n: number of random probes (must be >= 1).
        normalize: see above.

    Returns:
        float, the square root of the mean per-probe estimate.

    Raises:
        ValueError: if ``n < 1`` (the mean would divide by zero).
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    tot = 0.0
    for _ in range(n):
        e = torch.randn_like(zs)
        Jv = torch.autograd.functional.jvp(blk.nc_force, zs, e)[1]
        sq = Jv.pow(2).sum()
        if normalize:
            sq = sq / e.pow(2).sum()     # per-probe gain, removes the |e|^2 ~ d scaling
        tot += sq.item()                 # accumulate in Python float (double precision)
    return math.sqrt(tot / n)

# Sweep settings. N_STEPS drives the training loop, the column label and the growth-per-step divisor,
# so they can no longer drift apart. The other values reproduce the original literals.
N_STEPS = 100                       # BPTT steps at fixed lr between the two ||J_nc|| probes (must be >= 1)
LR = 1e-3
BATCH, SEQ = 8, 256                 # arguments passed to get_batch('train', ...)
RELAX_ITERS, RELAX_DT = 150, 0.1    # the two numeric args given to relax / bptt_step (see lt_ep_train)

_jlabel = f'|Jnc|@{N_STEPS}'
print(f"{'C':>5} {'H':>3} {'init_res':>9} {'|Jnc|init':>10} {_jlabel:>10} {'growth/step':>11}")
for C in (256, 512, 768, 1024, 1536, 2048):
    torch.manual_seed(0)                          # identical RNG stream for every width
    H = C // 32                                   # head count scales with width (C / 32)
    blk = EQBlock(C, H, 256, 256, attn_mode='thick')
    blk.qknorm = True; blk.track = False; blk.li_avg = 0; blk.navg = 1; blk.fnoise = 0; blk.nbrake = 0; blk._cstep = None
    with torch.no_grad():
        blk.WO.mul_(0.1); blk.pj.mul_(0.1)        # resinit 0.1 (match the sweep)
    # Fixed probe batch, embedded once with the initial embedding.
    idx, _ = get_batch('train', BATCH, SEQ)
    xin = blk.embed(idx).detach()
    zs = relax(blk, xin.clone(), xin, RELAX_ITERS, RELAX_DT)
    # Relative change produced by one extra relax step: a convergence check on zs.
    res0 = (relax(blk, zs, xin, 1, RELAX_DT) - zs).norm().item() / zs.norm().item()
    j0 = jnc_norm(blk, zs)
    # N_STEPS BPTT steps at a FIXED lr to see how fast ||J_nc|| grows (the destabilizing drive)
    opt = torch.optim.AdamW(blk.allp, lr=LR, weight_decay=1e-4)
    for _ in range(N_STEPS):
        ix, yy = get_batch('train', BATCH, SEQ)
        g = bptt_step(blk, ix, yy, RELAX_ITERS, RELAX_DT)   # grads looked up by id(param) below
        opt.zero_grad(set_to_none=True)
        for p in blk.allp:
            p.grad = g.get(id(p))
        torch.nn.utils.clip_grad_norm_(blk.allp, 5.0); opt.step()
    # NOTE(review): xin is still the pre-training embedding of the probe batch. If blk.embed is among
    # blk.allp, this z* is not the point the trained model would see; re-embed if that is intended.
    zs2 = relax(blk, xin.clone(), xin, RELAX_ITERS, RELAX_DT)
    j1 = jnc_norm(blk, zs2)
    print(f"{C:>5} {H:>3} {res0:>9.1e} {j0:>10.3f} {j1:>10.3f} {(j1-j0)/N_STEPS:>11.2e}", flush=True)
    # Release this width's model, optimizer state and gradients so they are not still alive while
    # the next (larger) width is allocated.
    del blk, opt, g, zs, zs2, xin