"""Causal probe for the stability onset.

Measure the non-conservative Jacobian norm ‖J_nc(z*)‖ (Hutchinson-style random
probe estimate, intended to be the same quantity the jacreg controller
penalizes -- see `jnc_norm` for the exact normalization) vs width C, at init
and after a few training steps. If the ‖J_nc‖ growth per lr-step crosses the
contraction margin at the measured critical scale, the lr_crit onset is
DERIVED from the dynamics rather than being a hyperparameter.
"""
import math
import pickle
from pathlib import Path

import torch

import lt_ep_train as M

# Point the training module at the TinyStories BPE data and record the vocab
# size on it before pulling in its helpers (presumably get_batch / EQBlock read
# these module globals -- TODO confirm against lt_ep_train).
M.DD = Path('/tmp/lt_ep/data/tinystories_bpe')
with open(M.DD / 'meta.pkl', 'rb') as meta_fh:  # trusted local file; never unpickle untrusted data
    M.vocab = pickle.load(meta_fh)['vocab_size']

from lt_ep_train import EQBlock, get_batch, bptt_step, relax  # noqa: E402  (must follow the patching above)

dev = 'cuda'  # NOTE(review): not referenced below; device placement is presumably handled in lt_ep_train -- confirm

# Positional arguments passed to relax()/bptt_step() everywhere in this probe
# (presumably iteration count and step size -- confirm against lt_ep_train).
N_RELAX = 150
RELAX_DT = 0.1
N_TRAIN_STEPS = 100  # BPTT steps between the two measurements ('|Jnc|@100' column)
LR = 1e-3            # fixed learning rate for the growth measurement


def jnc_norm(blk, zs, n=8):
    """Randomized estimate of the normalized Jacobian norm of ``blk.nc_force`` at ``zs``.

    Draws ``n`` Gaussian probe vectors e, forms the Jacobian-vector product
    J e with ``torch.autograd.functional.jvp`` and averages ||J e||^2 / ||e||^2.
    Since e/||e|| is uniform on the unit sphere, each term has expectation
    ||J||_F^2 / d with d = zs.numel(); the returned square root therefore
    estimates ||J_nc(z*)||_F / sqrt(d) (the RMS singular value), NOT the raw
    Frobenius norm. NOTE(review): confirm this is the normalization the jacreg
    controller penalizes.

    Args:
        blk: block exposing a callable ``nc_force(z)`` returning a tensor.
        zs: state at which the Jacobian is evaluated (the relaxed state here).
        n: number of random probes.

    Returns:
        The estimate as a Python float.
    """
    tot = 0.0
    for _ in range(n):
        e = torch.randn_like(zs)
        Jv = torch.autograd.functional.jvp(blk.nc_force, zs, e)[1]
        tot += (Jv.pow(2).sum() / e.pow(2).sum()).item()
    return math.sqrt(tot / n)


def main():
    """Sweep widths and print init residual, ‖J_nc‖ before/after training, and its per-step growth."""
    print(f"{'C':>5} {'H':>3} {'init_res':>9} {'|Jnc|init':>10} {'|Jnc|@100':>10} {'growth/step':>11}")
    for C in (256, 512, 768, 1024, 1536, 2048):
        torch.manual_seed(0)
        H = C // 32  # 32 channels per head, assuming EQBlock splits C evenly over H heads
        blk = EQBlock(C, H, 256, 256, attn_mode='thick')
        # Probe configuration flags (their meaning lives in EQBlock).
        blk.qknorm = True
        blk.track = False
        blk.li_avg = 0
        blk.navg = 1
        blk.fnoise = 0
        blk.nbrake = 0
        blk._cstep = None
        with torch.no_grad():
            # resinit 0.1 (match the sweep)
            blk.WO.mul_(0.1)
            blk.pj.mul_(0.1)

        idx, _ = get_batch('train', 8, 256)
        xin = blk.embed(idx).detach()
        zs = relax(blk, xin.clone(), xin, N_RELAX, RELAX_DT)
        # Relative one-step residual at the relaxed state (fixed-point quality).
        # Pass a copy: the other relax calls hand it xin.clone(), suggesting it
        # may write into its state argument, which would make this read 0.
        res0 = (relax(blk, zs.clone(), xin, 1, RELAX_DT) - zs).norm().item() / zs.norm().item()
        j0 = jnc_norm(blk, zs)

        # N_TRAIN_STEPS BPTT steps at a FIXED lr to see how fast ‖J_nc‖ grows
        # (the destabilizing drive).
        opt = torch.optim.AdamW(blk.allp, lr=LR, weight_decay=1e-4)
        for _ in range(N_TRAIN_STEPS):
            ix, yy = get_batch('train', 8, 256)
            grads = bptt_step(blk, ix, yy, N_RELAX, RELAX_DT)  # mapping id(param) -> grad
            opt.zero_grad(set_to_none=True)
            for p in blk.allp:
                # Missing entries leave grad=None, which the optimizer skips.
                p.grad = grads.get(id(p))
            torch.nn.utils.clip_grad_norm_(blk.allp, 5.0)
            opt.step()

        # NOTE(review): xin was embedded before training; if the embedding
        # weights are in blk.allp, this fixed point uses the stale input --
        # confirm that is the intended comparison.
        zs2 = relax(blk, xin.clone(), xin, N_RELAX, RELAX_DT)
        j1 = jnc_norm(blk, zs2)
        print(f"{C:>5} {H:>3} {res0:>9.1e} {j0:>10.3f} {j1:>10.3f} {(j1 - j0) / N_TRAIN_STEPS:>11.2e}", flush=True)


if __name__ == '__main__':
    main()