diff options
Diffstat (limited to 'ep_run/jnc_scaling.py')
| -rw-r--r-- | ep_run/jnc_scaling.py | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/ep_run/jnc_scaling.py b/ep_run/jnc_scaling.py new file mode 100644 index 0000000..2126d9a --- /dev/null +++ b/ep_run/jnc_scaling.py @@ -0,0 +1,46 @@ +"""Causal probe for the stability onset: measure the non-conservative Jacobian norm ‖J_nc(z*)‖ +(Hutchinson, the same quantity the jacreg controller penalizes) vs width C, at init and after a +few training steps. If ‖J_nc‖ growth-per-lr-step crosses the contraction margin at the measured +critical scale, the lr_crit onset is DERIVED from dynamics, not a hyperparameter.""" +import math, torch +import lt_ep_train as M +from pathlib import Path +import pickle +M.DD = Path('/tmp/lt_ep/data/tinystories_bpe') +M.vocab = pickle.load(open(M.DD / 'meta.pkl', 'rb'))['vocab_size'] +from lt_ep_train import EQBlock, get_batch, bptt_step, relax +dev = 'cuda' + +def jnc_norm(blk, zs, n=8): # Hutchinson estimate of ‖J_nc(z*)‖_F + tot = 0.0 + for _ in range(n): + e = torch.randn_like(zs) + Jv = torch.autograd.functional.jvp(blk.nc_force, zs, e)[1] + tot += (Jv.pow(2).sum() / e.pow(2).sum()).item() + return math.sqrt(tot / n) + +print(f"{'C':>5} {'H':>3} {'init_res':>9} {'|Jnc|init':>10} {'|Jnc|@100':>10} {'growth/step':>11}") +for C in (256, 512, 768, 1024, 1536, 2048): + torch.manual_seed(0) + H = C // 32 + blk = EQBlock(C, H, 256, 256, attn_mode='thick') + blk.qknorm = True; blk.track = False; blk.li_avg = 0; blk.navg = 1; blk.fnoise = 0; blk.nbrake = 0; blk._cstep = None + with torch.no_grad(): + blk.WO.mul_(0.1); blk.pj.mul_(0.1) # resinit 0.1 (match the sweep) + idx, y = get_batch('train', 8, 256) + xin = blk.embed(idx).detach() + zs = relax(blk, xin.clone(), xin, 150, 0.1) + res0 = (relax(blk, zs, xin, 1, 0.1) - zs).norm().item() / zs.norm().item() + j0 = jnc_norm(blk, zs) + # 100 BPTT steps at a FIXED lr to see how fast ‖J_nc‖ grows (the destabilizing drive) + opt = torch.optim.AdamW(blk.allp, lr=1e-3, weight_decay=1e-4) + for _ in range(100): + ix, yy = get_batch('train', 8, 256) + g = bptt_step(blk, ix, yy, 150, 0.1) + opt.zero_grad(set_to_none=True) + for p in blk.allp: + p.grad = g.get(id(p)) + torch.nn.utils.clip_grad_norm_(blk.allp, 5.0); opt.step() + zs2 = relax(blk, xin.clone(), xin, 150, 0.1) + j1 = jnc_norm(blk, zs2) + print(f"{C:>5} {H:>3} {res0:>9.1e} {j0:>10.3f} {j1:>10.3f} {(j1-j0)/100:>11.2e}", flush=True) |
