1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
|
"""Causal probe for the stability onset: measure the non-conservative Jacobian norm ‖J_nc(z*)‖
(Hutchinson, the same quantity the jacreg controller penalizes) vs width C, at init and after a
few training steps. If ‖J_nc‖ growth-per-lr-step crosses the contraction margin at the measured
critical scale, the lr_crit onset is DERIVED from dynamics, not a hyperparameter."""
import math, torch
import lt_ep_train as M
from pathlib import Path
import pickle
M.DD = Path('/tmp/lt_ep/data/tinystories_bpe')
M.vocab = pickle.load(open(M.DD / 'meta.pkl', 'rb'))['vocab_size']
from lt_ep_train import EQBlock, get_batch, bptt_step, relax
dev = 'cuda'
def jnc_norm(blk, zs, n=8):
    """Hutchinson estimate of the RMS gain of J_nc = d(blk.nc_force)/dz at ``zs``.

    Each Gaussian probe e contributes the Rayleigh quotient ‖J e‖² / ‖e‖² of JᵀJ. Its
    expectation is exactly ‖J‖_F² / d with d = zs.numel(). The returned value is therefore
    ≈ ‖J_nc‖_F / sqrt(d) (the RMS singular value), NOT the raw Frobenius norm. That keeps
    results comparable across widths.

    Args:
        blk: object exposing ``nc_force(z) -> Tensor``; it is differentiated with respect to z.
        zs: point at which the Jacobian is evaluated (the relaxed state z*).
        n: number of random probe vectors; must be >= 1.

    Returns:
        float: sqrt of the mean per-probe Rayleigh quotient.

    Raises:
        ValueError: if ``n < 1``. Previously n=0 gave a ZeroDivisionError and n<0 returned -0.0.
    """
    if n < 1:
        raise ValueError(f"jnc_norm needs at least one probe vector, got n={n}")
    tot = 0.0
    for _ in range(n):
        e = torch.randn_like(zs)
        # jvp returns (nc_force(zs), J_nc @ e); only the directional derivative is needed.
        Jv = torch.autograd.functional.jvp(blk.nc_force, zs, e)[1]
        tot += (Jv.pow(2).sum() / e.pow(2).sum()).item()
    return math.sqrt(tot / n)
N_STEPS = 100  # BPTT steps between the init and post-training ‖J_nc‖ measurements
WIDTHS = (256, 512, 768, 1024, 1536, 2048)
# Positional schedule arguments passed to relax() / bptt_step(). The first is the iteration
# count (the residual probe below uses 1); the second's exact meaning lives in lt_ep_train.
RELAX_ITERS, RELAX_ETA = 150, 0.1


def _make_block(C):
    """Build a seeded EQBlock of width C with H = C // 32 heads, configured like the sweep."""
    torch.manual_seed(0)
    H = C // 32
    blk = EQBlock(C, H, 256, 256, attn_mode='thick')
    # Flag settings copied from the sweep configuration; their semantics are defined in lt_ep_train.
    blk.qknorm = True
    blk.track = False
    blk.li_avg = 0
    blk.navg = 1
    blk.fnoise = 0
    blk.nbrake = 0
    blk._cstep = None
    with torch.no_grad():
        # resinit 0.1 (match the sweep)
        blk.WO.mul_(0.1)
        blk.pj.mul_(0.1)
    return blk, H


def _fixed_point(blk, idx):
    """Relax blk on tokens idx using its CURRENT embedding; return (z*, relative 1-step residual)."""
    xin = blk.embed(idx).detach()
    zs = relax(blk, xin.clone(), xin, RELAX_ITERS, RELAX_ETA)
    # Clone so that an in-place relax() cannot alias the residual to zero or advance zs itself.
    step = relax(blk, zs.clone(), xin, 1, RELAX_ETA)
    res = (step - zs).norm().item() / zs.norm().item()
    return zs, res


def _train(blk, n_steps, lr=1e-3):
    """Run n_steps BPTT + AdamW steps at a FIXED lr, clipping the grad norm at 5.0.

    The optimizer is local, so its state is freed before the next width is built.
    """
    opt = torch.optim.AdamW(blk.allp, lr=lr, weight_decay=1e-4)
    for _ in range(n_steps):
        ix, yy = get_batch('train', 8, 256)
        g = bptt_step(blk, ix, yy, RELAX_ITERS, RELAX_ETA)
        opt.zero_grad(set_to_none=True)
        for p in blk.allp:
            # g is looked up by id(param); a missing entry leaves p.grad = None (skipped by the optimizer).
            p.grad = g.get(id(p))
        torch.nn.utils.clip_grad_norm_(blk.allp, 5.0)
        opt.step()


def main():
    """Sweep width C; print ‖J_nc‖ at init and after N_STEPS training steps (the destabilizing drive)."""
    at = f'@{N_STEPS}'
    print(f"{'C':>5} {'H':>3} {'init_res':>9} {'|Jnc|init':>10} {'|Jnc|' + at:>10} "
          f"{'growth/step':>11} {'res' + at:>9}")
    for C in WIDTHS:
        blk, H = _make_block(C)
        idx, _ = get_batch('train', 8, 256)
        zs0, res0 = _fixed_point(blk, idx)
        j0 = jnc_norm(blk, zs0)
        _train(blk, N_STEPS)
        # Re-embed idx with the trained block: reusing the pre-training embedding would
        # evaluate J_nc away from the trained model's actual fixed point.
        zs1, res1 = _fixed_point(blk, idx)
        # A large res1 means relax() did not converge, so j1 is not measured at a fixed point.
        j1 = jnc_norm(blk, zs1)
        print(f"{C:>5} {H:>3} {res0:>9.1e} {j0:>10.3f} {j1:>10.3f} "
              f"{(j1 - j0) / N_STEPS:>11.2e} {res1:>9.1e}", flush=True)


if __name__ == "__main__":
    main()
|