summaryrefslogtreecommitdiff
path: root/ep_run/jnc_scaling.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/jnc_scaling.py')
-rw-r--r--ep_run/jnc_scaling.py46
1 files changed, 46 insertions, 0 deletions
diff --git a/ep_run/jnc_scaling.py b/ep_run/jnc_scaling.py
new file mode 100644
index 0000000..2126d9a
--- /dev/null
+++ b/ep_run/jnc_scaling.py
@@ -0,0 +1,46 @@
+"""Causal probe for the stability onset: measure the non-conservative Jacobian norm ‖J_nc(z*)‖
+(Hutchinson, the same quantity the jacreg controller penalizes) vs width C, at init and after a
+few training steps. If ‖J_nc‖ growth-per-lr-step crosses the contraction margin at the measured
+critical scale, the lr_crit onset is DERIVED from dynamics, not a hyperparameter."""
+import math, torch
+import lt_ep_train as M
+from pathlib import Path
+import pickle
+# Override two module attributes of lt_ep_train (imported above as M) before
+# the helpers below are used. Presumably get_batch reads M.DD to locate the
+# tokenized data -- confirm against lt_ep_train.
+M.DD = Path('/tmp/lt_ep/data/tinystories_bpe')
+# Use a context manager so the metadata file handle is closed as soon as the
+# pickle has been read, instead of being left to the garbage collector.
+# NOTE(review): pickle can execute arbitrary code on load; meta.pkl must come
+# from a trusted preprocessing run.
+with open(M.DD / 'meta.pkl', 'rb') as meta_file:
+    M.vocab = pickle.load(meta_file)['vocab_size']
+from lt_ep_train import EQBlock, get_batch, bptt_step, relax
+# NOTE(review): dev is not referenced anywhere in the visible part of this script.
+dev = 'cuda'
+
+def jnc_norm(blk, zs, n=8):
+    """Return a randomized estimate of the RMS gain of ``blk.nc_force`` at ``zs``.
+
+    For each of ``n`` probes a Gaussian direction ``e`` with the shape of
+    ``zs`` is drawn, the Jacobian-vector product ``J e`` of ``blk.nc_force``
+    at ``zs`` is evaluated, and the ratio ``‖J e‖² / ‖e‖²`` is recorded. The
+    function returns the square root of the mean ratio.
+
+    Because every probe is divided by ``‖e‖²``, the returned value estimates
+    ``‖J‖_F / sqrt(zs.numel())`` (the root-mean-square gain over directions),
+    not the unnormalized Frobenius norm. Keep this in mind when comparing
+    values across widths, since ``zs.numel()`` changes with the width.
+
+    Args:
+        blk: object exposing a callable ``nc_force`` that maps a tensor shaped
+            like ``zs`` to a tensor.
+        zs: point at which the Jacobian is probed.
+        n: number of random probes to average; must be at least 1.
+
+    Returns:
+        The estimate as a Python float.
+
+    Raises:
+        ValueError: if ``n`` is smaller than 1. Previously ``n == 0`` failed
+            with a bare ZeroDivisionError and a negative ``n`` silently
+            returned ``-0.0``.
+    """
+    if n < 1:
+        raise ValueError(f"n must be a positive integer, got {n!r}")
+    tot = 0.0
+    for _ in range(n):
+        e = torch.randn_like(zs)
+        # jvp returns (nc_force(zs), J @ e); only the product is needed.
+        _, Jv = torch.autograd.functional.jvp(blk.nc_force, zs, e)
+        # Accumulate as Python floats so the running sum is kept on the host.
+        tot += (Jv.pow(2).sum() / e.pow(2).sum()).item()
+    return math.sqrt(tot / n)
+
+# Number of optimizer steps taken between the two ‖J_nc‖ measurements. It is
+# defined once because the column label, the training loop bound and the
+# growth/step divisor all have to agree.
+TRAIN_STEPS = 100
+# Fixed learning rate for the probe run (no schedule).
+PROBE_LR = 1e-3
+
+after_label = f'|Jnc|@{TRAIN_STEPS}'
+print(f"{'C':>5} {'H':>3} {'init_res':>9} {'|Jnc|init':>10} {after_label:>10} {'growth/step':>11}")
+for C in (256, 512, 768, 1024, 1536, 2048):
+    # Reseed before building each width so every row starts from the same RNG state.
+    torch.manual_seed(0)
+    H = C // 32  # reported in the H column; presumably the head count -- confirm
+    blk = EQBlock(C, H, 256, 256, attn_mode='thick')
+    # Attribute overrides on the fresh block; their meaning is defined in
+    # lt_ep_train, which is not visible from this script.
+    blk.qknorm = True
+    blk.track = False
+    blk.li_avg = 0
+    blk.navg = 1
+    blk.fnoise = 0
+    blk.nbrake = 0
+    blk._cstep = None
+    with torch.no_grad():
+        # resinit 0.1 (match the sweep)
+        blk.WO.mul_(0.1)
+        blk.pj.mul_(0.1)
+    idx, _targets = get_batch('train', 8, 256)  # targets are not needed for the probe
+    xin = blk.embed(idx).detach()
+    zs = relax(blk, xin.clone(), xin, 150, 0.1)
+    # Relative change produced by one more relax call from zs. The start state
+    # is cloned, like in the other relax calls, so that zs itself cannot be
+    # altered if relax updates its argument in place (relax is defined in
+    # lt_ep_train and not shown here).
+    res0 = (relax(blk, zs.clone(), xin, 1, 0.1) - zs).norm().item() / zs.norm().item()
+    j0 = jnc_norm(blk, zs)
+    # TRAIN_STEPS BPTT steps at a FIXED lr to see how fast ‖J_nc‖ grows (the destabilizing drive)
+    opt = torch.optim.AdamW(blk.allp, lr=PROBE_LR, weight_decay=1e-4)
+    for _ in range(TRAIN_STEPS):
+        ix, yy = get_batch('train', 8, 256)
+        g = bptt_step(blk, ix, yy, 150, 0.1)
+        opt.zero_grad(set_to_none=True)
+        # g is looked up by id(parameter); a parameter without an entry keeps grad None.
+        for p in blk.allp:
+            p.grad = g.get(id(p))
+        torch.nn.utils.clip_grad_norm_(blk.allp, 5.0)
+        opt.step()
+    # Re-relax from the same embedded input with the updated weights and measure again.
+    zs2 = relax(blk, xin.clone(), xin, 150, 0.1)
+    j1 = jnc_norm(blk, zs2)
+    print(f"{C:>5} {H:>3} {res0:>9.1e} {j0:>10.3f} {j1:>10.3f} {(j1-j0)/TRAIN_STEPS:>11.2e}", flush=True)