diff options
Diffstat (limited to 'ep_run/lt_ep_train.py')
| -rw-r--r-- | ep_run/lt_ep_train.py | 630 |
1 files changed, 630 insertions, 0 deletions
# Reconstructed from the flattened diff view of ep_run/lt_ep_train.py (new file, index 0000000..9974bd8).
"""option 2 / H2: train a full EP equilibrium transformer block on Shakespeare char-LM.

One block = token state z relaxed to a fixed point of
    F(z) = -(z - x_in)                 (input clamp; x_in = embed(idx))
           - dE_mem/dz                 (Hopfield memory E_mem = -sum relu(z Wm)^2 ; CONSERVATIVE = FFN)
           + s*(causal_attn(z) - c*z)  (damped causal attention ; NON-conservative)
Readout logits = z* Whead ; loss = next-token cross-entropy.

Train modes:
  ep   : free + +/-beta nudged equilibria. Block+embed params via vector-field gradient
         <a, dF/dtheta(z*)> with the AEP correction (clipped) on the attention part; readout
         head via its own local gradient dCE/dWhead. NO backprop through the relaxation.
  bptt : backprop through the unrolled relaxation (exact-gradient reference, same architecture).
Stabilisation from G: damping c, clipped AEP correction, weight-norm caps, best-val checkpoint.
"""
import argparse
import json
import math
import os
import pickle
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

dev = 'cuda' if torch.cuda.is_available() else 'cpu'
DD = Path('/home/yurenh2/ept/ep_run/data/tinystories_bpe')


def _load_vocab_size(data_dir):
    """Return the ``vocab_size`` entry of ``<data_dir>/meta.pkl``."""
    # NOTE(security): pickle.load executes arbitrary code; only point this at trusted data dirs.
    with open(Path(data_dir) / 'meta.pkl', 'rb') as fh:
        return pickle.load(fh)['vocab_size']


# The default data dir is machine-specific.  When it is absent the module must still import
# (main() rebinds DD/vocab from --data), so fall back to None instead of crashing at import.
try:
    vocab = _load_vocab_size(DD)
except OSError:
    vocab = None


def get_batch(split, B, T):
    """Sample ``B`` random windows of length ``T`` from the token file of ``split``.

    Reads ``train.bin`` when ``split == 'train'`` and ``val.bin`` otherwise (uint16 tokens under
    the module-level ``DD``).  Returns ``(x, y)`` int64 tensors on ``dev``; ``y`` is ``x``
    shifted one token to the right (next-token targets).
    """
    fname = 'train.bin' if split == 'train' else 'val.bin'
    data = np.memmap(DD / fname, dtype=np.uint16, mode='r')
    starts = torch.randint(len(data) - T - 1, (B,)).tolist()
    x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in starts])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in starts])
    return x.to(dev), y.to(dev)


class EQBlock:
    """Parameters and force field of one equilibrium transformer block.

    All parameters are plain leaf tensors (``requires_grad=True``) created on ``dev``.
    ``block`` lists the parameters that enter the force, ``allp`` adds the readout head ``Wh``,
    and ``capw``/``caps`` hold the matrices subject to a Frobenius-norm cap and their caps
    (3x the init norm here).  ``attn_mode`` selects the force variant implemented in
    :meth:`force`: ``'real'``, ``'energy'``, ``'mono'`` or ``'thick'``.
    """

    def __init__(self, C, H, Mm, T, s=1.0, c=1.0, attn_mode='real', gamma=0.25):
        self.C, self.H, self.dh, self.s, self.c, self.T = C, H, C // H, s, c, T
        self.attn_mode, self.gamma = attn_mode, gamma
        self.fnoise = 0.0  # optics model: mult. noise per force eval

        def randn(*shape, sc):
            return (torch.randn(*shape, device=dev) * sc).requires_grad_(True)

        def const(n, v):
            return torch.full((n,), float(v), device=dev).requires_grad_(True)

        # NOTE: creation order fixes the RNG stream for a given seed; do not reorder.
        self.tok = randn(vocab, C, sc=0.02)
        self.pos = randn(T, C, sc=0.02)
        self.WQ = randn(C, C, sc=1 / math.sqrt(C))
        self.WK = randn(C, C, sc=1 / math.sqrt(C))
        self.WV = randn(C, C, sc=1 / math.sqrt(C))
        self.WO = randn(C, C, sc=1 / math.sqrt(C))
        self.Wm = randn(C, Mm, sc=0.3 / math.sqrt(C))
        self.Wh = randn(C, vocab, sc=1 / math.sqrt(C))
        self.P = randn(C, C, sc=1 / math.sqrt(C))  # monDEQ monotone op
        self.Q = randn(C, C, sc=1 / math.sqrt(C))
        self.mono_m = 1.0
        # LN affine
        self.ln1g = const(C, 1)
        self.ln1b = const(C, 0)
        self.ln2g = const(C, 1)
        self.ln2b = const(C, 0)
        # untied 4x FFN
        self.fc = randn(C, 4 * C, sc=1 / math.sqrt(C))
        self.fcb = const(4 * C, 0)
        self.pj = randn(4 * C, C, sc=1 / math.sqrt(4 * C))
        self.pjb = const(C, 0)
        self.cmask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=dev))
        # in the force
        self.block = [self.tok, self.pos, self.WQ, self.WK, self.WV, self.WO, self.Wm, self.P, self.Q,
                      self.ln1g, self.ln1b, self.ln2g, self.ln2b, self.fc, self.fcb, self.pj, self.pjb]
        self.allp = self.block + [self.Wh]
        self.capw = (self.WQ, self.WK, self.WV, self.WO, self.Wm, self.Wh, self.fc, self.pj)
        self.caps = {id(w): w.detach().norm().item() * 3.0 for w in self.capw}

    def embed(self, idx):
        """Token embedding of ``idx`` plus the learned positional table (broadcast over batch)."""
        return self.tok[idx] + self.pos[None]

    def _heads(self, z, W):
        """Project ``z`` with ``W`` and split into heads: (B, T, C) -> (B, H, T, dh)."""
        return (z @ W).view(z.size(0), self.T, self.H, self.dh).transpose(1, 2)

    def attn(self, z):
        """Causal multi-head softmax attention followed by the output projection ``WO``."""
        B = z.size(0)
        q = self._heads(z, self.WQ)
        k = self._heads(z, self.WK)
        v = self._heads(z, self.WV)
        if getattr(self, 'qknorm', False):  # Qwen3-style q/k RMSNorm: bounds logits, tames J
            q = q * torch.rsqrt(q.pow(2).mean(-1, keepdim=True) + 1e-6)
            k = k * torch.rsqrt(k.pow(2).mean(-1, keepdim=True) + 1e-6)
        a = (q @ k.transpose(-2, -1)) / math.sqrt(self.dh)
        a = torch.softmax(a.masked_fill(~self.cmask, float('-inf')), -1)
        return (a @ v).transpose(1, 2).reshape(B, self.T, self.C) @ self.WO

    def attn_energy(self, z):
        """Conservative LSE attention energy (tied value): -(1/gamma) * sum logsumexp(gamma * scores)."""
        q = self._heads(z, self.WQ)
        k = self._heads(z, self.WK)
        a = (q @ k.transpose(-2, -1)) / math.sqrt(self.dh)
        a = a.masked_fill(~self.cmask, float('-inf'))
        return -(1.0 / self.gamma) * torch.logsumexp(self.gamma * a, dim=-1).sum()

    def Emem(self, z):
        """Memory energy ``-sum relu(z Wm)^2``."""
        return -(F.relu(z @ self.Wm) ** 2).sum()

    def _ffn(self, h):
        """Untied 4x feed-forward branch: tanh-approximated GELU between ``fc`` and ``pj``."""
        return F.gelu(h @ self.fc + self.fcb, approximate='tanh') @ self.pj + self.pjb

    def tforce(self, z, xin):
        """Pure thick force (no grad machinery, no noise) -> usable with torch.compile."""
        h1 = F.layer_norm(z, (self.C,), self.ln1g, self.ln1b)
        h2 = F.layer_norm(z, (self.C,), self.ln2g, self.ln2b)
        return -(z - xin) + self.attn(h1) + self._ffn(h2) - self.c * z

    def _noisy(self, t):
        """Optics model: multiply ``t`` by ``1 + fnoise * N(0, 1)`` when ``fnoise > 0``."""
        if self.fnoise > 0:
            return t * (1 + self.fnoise * torch.randn_like(t))
        return t

    def nc_force(self, z):
        """Non-conservative part of the force (used for the AEP correction and jacreg)."""
        if self.attn_mode == 'thick':
            h1 = F.layer_norm(z, (self.C,), self.ln1g, self.ln1b)
            h2 = F.layer_norm(z, (self.C,), self.ln2g, self.ln2b)
            return self._noisy(self.attn(h1) + self._ffn(h2))
        return self._noisy(self.s * self.attn(z))

    def force(self, z, xin, cg=False):
        """Evaluate the force F(z) for the configured ``attn_mode``.

        ``cg=True`` keeps the autograd graph (through ``z`` when it already requires grad, and
        through the energy gradient via ``create_graph``) so callers can differentiate the
        result; otherwise ``z`` is detached first.  Grad mode is force-enabled inside because
        the conservative terms are obtained with ``torch.autograd.grad`` even when the caller
        runs under ``torch.no_grad()``.
        """
        with torch.enable_grad():
            zr = z if (cg and z.requires_grad) else z.detach().requires_grad_(True)
            if self.attn_mode == 'thick':  # DEQ-transformer block: LN + untied 4x FFN + residual
                h1 = F.layer_norm(zr, (self.C,), self.ln1g, self.ln1b)
                h2 = F.layer_norm(zr, (self.C,), self.ln2g, self.ln2b)
                ff = self._ffn(h2)
                return -(zr - xin) + self._noisy(self.attn(h1) + ff) - self.c * zr  # -c*z: initial contraction
            if self.attn_mode == 'mono':  # monDEQ: structurally-monotone contraction
                gm, = torch.autograd.grad(self.Emem(zr), zr, create_graph=cg)
                PtP = self.P.t() @ self.P  # PSD -> sym(J) = -(mI+PtP) < 0 (guaranteed)
                return (-(self.mono_m * zr + zr @ PtP) + zr @ (self.Q - self.Q.t()).t()
                        + xin - gm + self.s * self.attn(zr))
            E = 0.5 * ((zr - xin) ** 2).sum() + self.Emem(zr)
            if self.attn_mode == 'energy':  # attention folded into the energy (conservative)
                E = E + self.attn_energy(zr) + 0.5 * self.c * (zr ** 2).sum()  # confinement -> bounded below
            gz, = torch.autograd.grad(E, zr, create_graph=cg)
            f = -gz
            if self.attn_mode == 'real':  # non-conservative attention + damping
                f = f + self.s * (self.attn(zr) - self.c * zr)
            return f


def relax(blk, z, xin, steps, eps):
    """Run ``steps`` explicit Euler updates ``z <- z + eps * F(z)`` and return the detached state.

    Uses ``blk._cstep`` (a pre-built single-step callable) when it is set and force noise is
    off; otherwise falls back to ``blk.force``.
    """
    cstep = getattr(blk, '_cstep', None)
    if cstep is not None and blk.fnoise == 0.0:  # compiled pure-thick free-phase fast path
        with torch.no_grad():
            for _ in range(steps):
                z = cstep(z, xin)
        return z.detach()
    for _ in range(steps):
        with torch.no_grad():
            z = z + eps * blk.force(z, xin).detach()
    return z.detach()


def ce(blk, z, y):
    """Next-token cross-entropy of the readout logits ``z @ Wh`` against targets ``y``."""
    return F.cross_entropy((z @ blk.Wh).reshape(-1, vocab), y.reshape(-1))


def _rel_residual(blk, z, xin, eps):
    """Relative size of one more relaxation step from ``z``: ||relax_1(z) - z|| / (||z|| + 1e-9)."""
    return (relax(blk, z, xin, 1, eps) - z).norm().item() / (z.norm().item() + 1e-9)


def _jacreg_grads(blk, z, jacreg):
    """Gradients w.r.t. ``blk.block`` of the soft Jacobian penalty at state ``z``.

    One random probe ``er`` gives the Hutchinson estimate ``||J_nc er||^2 / ||er||^2`` of the
    squared Frobenius norm of the non-conservative Jacobian, scaled by ``jacreg``.  Returns a
    tuple aligned with ``blk.block``; entries are ``None`` for parameters the penalty does not
    touch.
    """
    er = torch.randn_like(z)
    with torch.enable_grad():
        Jv = torch.autograd.functional.jvp(blk.nc_force, z.detach(), er, create_graph=True)[1]
        R = jacreg * (Jv ** 2).sum() / (er ** 2).sum()
        return torch.autograd.grad(R, blk.block, allow_unused=True)


def _accumulate(grads, params, extra, scale=None):
    """Add ``extra`` (aligned with ``params``, ``None`` entries skipped) into the ``grads`` dict."""
    for p, g in zip(params, extra):
        if g is None:
            continue
        if scale is not None:
            g = g * scale
        prev = grads.get(id(p))
        grads[id(p)] = g if prev is None else prev + g


def ep_step(blk, idx, y, T1, T2, eps, beta, jacreg=0.0, holo=0, hr=0.02, t1max=0, res_est=1e-4, t2sel=0,
            corr_every=1, res_gate=0.0, resreg=0.0, eigreg=0.0, eig_margin=1.0):
    """One equilibrium-propagation gradient estimate (no backprop through the relaxation).

    Args:
        blk: the ``EQBlock``.
        idx, y: input tokens and next-token targets.
        T1: free-phase relaxation steps; T2: nudged-phase steps; eps: Euler step size.
        beta: nudge strength of the symmetric +/-beta finite difference.
        jacreg: weight of the Jacobian penalty (0 = off).
        holo, hr, t2sel, corr_every: select and configure the estimators imported lazily from
            the external ``holo_ep`` module (``holo == 0`` uses the built-in nudge).
        t1max, res_est: when ``t1max > T1``, keep relaxing in chunks of 50 steps until the
            relative residual drops to ``res_est`` or ``t1max`` steps are reached.
        res_gate: if > 0 and the (refined) residual exceeds it, skip the task gradient and
            return only the Jacobian-penalty gradients.
        resreg: weight of the T1-residual penalty (active only when the T1 residual > 7e-4).
        eigreg, eig_margin: forwarded to the external ``eig_control.eig_penalty``.

    Returns:
        ``(grads, res)`` where ``grads`` maps ``id(param)`` to a gradient tensor (or ``None``)
        and ``res`` is the relative residual measured after the first ``T1`` steps.
    """
    xin0 = blk.embed(idx).detach()
    zs = relax(blk, xin0.clone(), xin0, T1, eps)
    res = _rel_residual(blk, zs, xin0, eps)
    res_used = res
    zT, resT1 = zs, res  # the T1 free-phase state (what eval/BPTT use), BEFORE refinement
    if t1max > T1:  # estimator refinement: relax further until tight
        rnow, t = res, T1  # (controller signal `res` stays measured at T1)
        while t < t1max and rnow > res_est:
            zs = relax(blk, zs, xin0, 50, eps)
            t += 50
            rnow = _rel_residual(blk, zs, xin0, eps)
        res_used = rnow

    if res_gate > 0 and res_used > res_gate:
        # validity gate: off-equilibrium the EP update is undefined -> apply ONLY the homeostat
        # (jacreg) and skip the nudge entirely (fast recovery steps)
        grads = {}
        if jacreg > 0:
            gr = _jacreg_grads(blk, zs, jacreg)
            grads = {id(p): g for p, g in zip(blk.block, gr) if g is not None}
        return grads, res

    def nudge(sign):
        """Relax T2 steps from ``zs`` with the loss gradient added at strength ``sign * beta``."""
        z = zs.clone()
        for _ in range(T2):
            with torch.enable_grad():
                zz = z.detach().requires_grad_(True)
                g, = torch.autograd.grad(ce(blk, zz, y), zz)
            g = g.clamp(-2.0, 2.0)  # clip nudge so it can't blow up relax
            with torch.no_grad():
                f = blk.force(z, xin0).detach() - sign * beta * g
            if blk.attn_mode in ('real', 'thick'):  # AEP correction (full non-conservative part)
                v = (z - zs).detach()
                Jv = torch.autograd.functional.jvp(blk.nc_force, zs, v)[1]
                JTv = torch.autograd.functional.vjp(blk.nc_force, zs, v)[1]
                corr = Jv - JTv
                cn, fn = corr.norm(), f.norm() + 1e-8
                # clip the correction so it never exceeds the force in norm
                f = f - (corr * (fn / cn) if cn > fn else corr)
            z = z + eps * f
        return z.detach()

    if holo == 2 and t2sel > 0:  # adaptive-T2, phase-batched fast path (validated ==)
        from holo_ep import holo_a_select2, holo_a_track
        K = max(1, getattr(blk, 'navg', 1))  # restart-averaging: noise / sqrt(K)
        acc = None
        for _ in range(K):
            if getattr(blk, 'track', False):  # common-mode-tracking AEP (loose-tolerant)
                ai, _ = holo_a_track(blk, zs, xin0, y, hr, t2sel, eps)
            else:
                ai, _ = holo_a_select2(blk, zs, xin0, y, hr, t2sel, eps, li=getattr(blk, 'li_avg', 0))
            acc = ai if acc is None else acc + ai
        a = acc / K
    elif holo > 0 and t2sel > 0:  # adaptive-T2 via hindsight snapshot selection
        from holo_ep import holo_a_select
        a, _ = holo_a_select(blk, zs, xin0, y, holo, hr, t2sel, eps, corr_every=corr_every)
    elif holo > 0:  # holomorphic nudge (clamp-free, Cauchy readout)
        from holo_ep import holo_a
        a, _ = holo_a(blk, zs, xin0, y, holo, hr, T2, eps)
    else:
        zp, zm = nudge(+1), nudge(-1)
        a = ((zm - zp) / (2 * beta)).detach()

    grads = {}
    with torch.enable_grad():
        xin = blk.embed(idx)  # live (for tok/pos grad through clamp)
        f = blk.force(zs.detach(), xin, cg=True)
        gblk = torch.autograd.grad((a * f).sum(), blk.block, allow_unused=True)
    for p, gv in zip(blk.block, gblk):
        grads[id(p)] = gv
    with torch.enable_grad():
        gh, = torch.autograd.grad(ce(blk, zs.detach(), y), blk.Wh)  # readout local gradient
    grads[id(blk.Wh)] = gh

    if jacreg > 0:  # soft Lyapunov: penalize non-conservative Jacobian norm
        _accumulate(grads, blk.block, _jacreg_grads(blk, zs, jacreg))

    if resreg > 0 and resT1 > 7e-4:  # defend z_T1 (BPTT gets this implicitly; EP at z* doesn't)
        with torch.enable_grad():
            Fz = blk.tforce(zT, xin0)  # deterministic thick force at z_T1 (params live, zT/xin0 detached)
            Rr = (eps * Fz).pow(2).sum() / (zT.pow(2).sum() + 1e-9)  # ~ (T1 residual)^2
            grr = torch.autograd.grad(Rr, blk.block, allow_unused=True)
        ratio = resreg * min(1.0, resT1 / 2e-2)  # ramp 0->resreg as res 7e-4->2e-2, capped
        gtask = math.sqrt(sum(float((grads[id(p)] ** 2).sum())
                              for p in blk.block if grads.get(id(p)) is not None) + 1e-20)
        gres = math.sqrt(sum(float((g ** 2).sum()) for g in grr if g is not None) + 1e-20)
        lam = ratio * gtask / gres  # scale penalty to `ratio` of the task-grad norm
        _accumulate(grads, blk.block, grr, scale=lam)

    if eigreg > 0:  # #2: leading-abscissa control (surgical, one-sided; alt to jacreg)
        from eig_control import eig_penalty
        ge, _om = eig_penalty(blk, zs, eigreg, eig_margin, blk.__dict__.setdefault('_eigcache', {}))
        for pid, g in ge.items():
            grads[pid] = g if grads.get(pid) is None else grads[pid] + g
    return grads, res


class Lion(torch.optim.Optimizer):
    """Lion optimizer (Chen et al. 2023) with an optional per-tensor trust ratio.

    Analog-hardware rationale: sign updates = fixed-amplitude pulses (kills device
    write-nonlinearity), magnitude-noise immune, one momentum cap per weight.

    Args:
        params: iterable of parameters or parameter groups.
        lr: step size of the sign update.
        betas: ``(b1, b2)``; ``b1`` interpolates momentum and gradient for the update
            direction, ``b2`` is the momentum decay.
        weight_decay: decoupled weight decay, applied as ``p *= 1 - lr * weight_decay``.
        lars: when true, scale ``lr`` per tensor by ``||p|| / ||update||``.
    """

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0, lars=False):
        super().__init__(params, dict(lr=lr, betas=betas, weight_decay=weight_decay, lars=lars))

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one update; returns the closure's loss, or ``None`` when no closure is given."""
        loss = None
        if closure is not None:  # standard Optimizer.step contract
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            b1, b2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                st = self.state.setdefault(p, {})
                if 'm' not in st:
                    st['m'] = torch.zeros_like(p)
                u = (b1 * st['m'] + (1 - b1) * p.grad).sign()
                lr = group['lr']
                if group['lars']:  # per-tensor trust ratio: one gain line per array
                    lr = lr * (p.norm() / (u.norm() + 1e-12)).item()
                p.mul_(1 - lr * group['weight_decay'])
                p.add_(u, alpha=-lr)
                st['m'].mul_(b2).add_(p.grad, alpha=1 - b2)
        return loss


def bptt_step(blk, idx, y, T1, eps, jacreg=0.0):
    """Exact-gradient reference: backprop the loss through ``T1`` unrolled relaxation steps.

    Returns a dict mapping ``id(param)`` to its gradient for every tensor in ``blk.allp``
    (``None`` where unused), with the Jacobian-penalty gradients added when ``jacreg > 0``.
    """
    xin = blk.embed(idx)
    z = xin.detach().requires_grad_(True) * 0 + xin  # init = embedding (keeps graph to emb)
    for _ in range(T1):
        z = z + eps * blk.force(z, xin, cg=True)
    g = torch.autograd.grad(ce(blk, z, y), blk.allp, allow_unused=True)
    gd = {id(p): gv for p, gv in zip(blk.allp, g)}
    if jacreg > 0:  # same soft Lyapunov penalty as ep mode (fair control)
        _accumulate(gd, blk.block, _jacreg_grads(blk, z, jacreg))
    return gd

@torch.no_grad()
def evaluate(blk, T1, eps, nb=8, B=32):
    """Mean validation cross-entropy over ``nb`` random batches of size ``B``.

    Each batch is relaxed for ``T1`` free-phase steps from its embedding before the readout.
    """
    tot = 0.0
    for _ in range(nb):
        idx, y = get_batch('val', B, blk.T)
        xin = blk.embed(idx).detach()
        z = relax(blk, xin.clone(), xin, T1, eps)
        tot += ce(blk, z, y).item()
    return tot / nb


def specnorm_weight_items(blk):
    """Collect the ``(name, tensor)`` pairs that the spectral-norm projection acts on.

    A fused QKV weight (first of ``WQKV``/``Wqkv``/``W_qkv``/``qkv`` present on ``blk``) is
    preferred over separate ``WQ``/``WK``/``WV``; ``WO``, ``fc`` and ``pj`` are appended when
    present.  Tensors with fewer than two dimensions are dropped.
    """
    items = []
    fused = next((nm for nm in ('WQKV', 'Wqkv', 'W_qkv', 'qkv') if hasattr(blk, nm)), None)
    if fused is not None:
        items.append((fused, getattr(blk, fused)))
    else:
        items.extend((nm, getattr(blk, nm)) for nm in ('WQ', 'WK', 'WV') if hasattr(blk, nm))
    items.extend((nm, getattr(blk, nm)) for nm in ('WO', 'fc', 'pj') if hasattr(blk, nm))
    return [(nm, w) for nm, w in items if w.ndim >= 2]


@torch.no_grad()
def power_sigma(W, u, iters=2):
    """Estimate the largest singular value of ``W`` (flattened to 2-D) by power iteration.

    ``u`` is a warm-start left vector from a previous call; it is re-drawn at random when it
    is ``None`` or does not match ``W`` in shape, device or dtype.  Returns ``(sigma, u)``
    with the updated, detached vector for the next call.
    """
    M = W.detach().reshape(W.shape[0], -1)
    if u is None or u.shape != (M.shape[0],) or u.device != W.device or u.dtype != W.dtype:
        u = F.normalize(torch.randn(M.shape[0], device=W.device, dtype=W.dtype), dim=0, eps=1e-12)
    for _ in range(iters):
        v = F.normalize(M.t().mv(u), dim=0, eps=1e-12)
        u = F.normalize(M.mv(v), dim=0, eps=1e-12)
    sigma = u.dot(M.mv(v)).abs()
    return sigma, u.detach()


@torch.no_grad()
def project_specnorm_(items, cache, bound):
    """Rescale in place every weight whose estimated spectral norm exceeds ``bound``.

    ``cache`` maps ``id(W)`` to the power-iteration vector and is updated on every call.
    Returns ``(max_before, max_after, clamped_names)`` over all ``items``.
    """
    max_before, max_after, clamped = 0.0, 0.0, []
    for name, W in items:
        sigma, u = power_sigma(W, cache.get(id(W)))
        cache[id(W)] = u
        before = float(sigma)
        after = before
        if before > bound:
            W.mul_(bound / (before + 1e-12))
            after = bound
            clamped.append(name)
        max_before = max(max_before, before)
        max_after = max(max_after, after)
    return max_before, max_after, clamped


def main():
    """CLI entry point: build the block, train it in ``ep`` or ``bptt`` mode, log and checkpoint.

    Rebinds the module globals ``DD`` and ``vocab`` from ``--data``, writes best weights to
    ``--ckpt`` and full resumable state to ``--state`` (both optional), and dumps the best
    validation loss to ``runs/H2_<mode>.json`` at the end.
    """
    global DD, vocab
    ap = argparse.ArgumentParser()
    ap.add_argument('--mode', choices=['ep', 'bptt'], required=True)
    ap.add_argument('--steps', type=int, default=2000)
    ap.add_argument('--B', type=int, default=32)
    ap.add_argument('--T', type=int, default=64)
    ap.add_argument('--C', type=int, default=128)
    ap.add_argument('--H', type=int, default=4)
    ap.add_argument('--Mm', type=int, default=256)
    ap.add_argument('--T1', type=int, default=80)
    ap.add_argument('--T2', type=int, default=15)
    ap.add_argument('--eps', type=float, default=0.1)
    ap.add_argument('--beta', type=float, default=0.02)
    ap.add_argument('--lr', type=float, default=1e-3)
    ap.add_argument('--log', type=int, default=100)
    ap.add_argument('--warmup', type=int, default=0)  # linear lr warmup steps (big-model stability)
    ap.add_argument('--state', type=str, default='')  # periodic FULL-state path (weights+opt+sched+step)
    ap.add_argument('--resume', action='store_true')  # resume from --state if it exists (Colab timeouts)
    ap.add_argument('--save_every', type=int, default=0)  # full-state save cadence (0=every --log); set small on Colab
    ap.add_argument('--c', type=float, default=1.0)
    ap.add_argument('--capx', type=float, default=3.0)
    ap.add_argument('--attn_mode', choices=['real', 'energy', 'mono', 'thick'], default='real')
    ap.add_argument('--ccap', type=float, default=8.0)
    ap.add_argument('--specnorm', type=float, default=0.0)  # hard post-step spectral projection bound; 0=off
    ap.add_argument('--jacreg', type=float, default=0.0)  # Bai 2021: soft Jacobian-norm (Lyapunov) penalty
    ap.add_argument('--jr_max', type=float, default=16.0)  # adaptive jacreg ceiling (ramps up vs residual)
    ap.add_argument('--res_target', type=float, default=5e-3)  # continuous controller target residual
    ap.add_argument('--jr_floor', type=float, default=None)  # controller floor; default=--jacreg (legacy: never off)
    ap.add_argument('--res_ema', type=float, default=0.0)  # EMA on residual signal (0=off); kills controller thrash
    ap.add_argument('--jr_lrcouple', action='store_true')  # anneal the λ floor with the lr schedule (late-drift fix)
    ap.add_argument('--holo', type=int, default=0)  # holomorphic EP: N circle points (0=off)
    ap.add_argument('--hr', type=float, default=0.02)  # holomorphic nudge radius |beta|
    ap.add_argument('--pema', type=float, default=0.0)  # parameter EMA decay (0=off); tames late wander
    ap.add_argument('--t1max', type=int, default=0)  # adaptive free phase: extend up to t1max...
    ap.add_argument('--res_est', type=float, default=1e-4)  # ...until this residual (estimator validity)
    ap.add_argument('--t2sel', type=int, default=0)  # adaptive T2: snapshot-selection cap (0=off)
    ap.add_argument('--seed', type=int, default=0)
    ap.add_argument('--data', type=str, default='/tmp/lt_ep/data/shakespeare_char')
    ap.add_argument('--ckpt', type=str, default='')  # save best weights (raw+ema) here
    ap.add_argument('--corr_every', type=int, default=1)  # recompute AEP corr every k nudge steps
    ap.add_argument('--tf32', action='store_true')  # tf32 matmuls (check res floor first!)
    ap.add_argument('--abort_res', type=float, default=0.1)  # kill switch: res above this 100 steps straight
    ap.add_argument('--res_gate', type=float, default=0.0)  # validity gate: skip task grads above this res
    ap.add_argument('--wsd', type=float, default=0.0)  # WSD: hold peak lr, cosine-decay only the last wsd fraction
    ap.add_argument('--resreg', type=float, default=0.0)  # T1-residual penalty: defend z_T1 (cap ratio vs task grad); run res_gate=0
    ap.add_argument('--eigreg', type=float, default=0.0)  # #2: leading-abscissa (numerical-abscissa) control — surgical alt to jacreg
    ap.add_argument('--eig_margin', type=float, default=1.0)  # penalize omega(J_nc) above this (free-phase Hopf boundary ~ 1+c)
    ap.add_argument('--diag_cos', type=int, default=0)  # #1: every N steps, log cos(EP grad, exact BPTT grad) + res
    ap.add_argument('--fingerprint', action='store_true')  # load --init_ckpt, print (res,cos,abscissa,val) fingerprint, exit
    ap.add_argument('--opt', choices=['adamw', 'lion', 'lionlars', 'sgdm', 'sgdsai'], default='adamw')
    ap.add_argument('--wd', type=float, default=1e-4)
    ap.add_argument('--fnoise', type=float, default=0.0)  # optics/device twin: mult. noise per force eval
    ap.add_argument('--wq_bits', type=int, default=0)  # weights projected to N bits each step (0=off)
    ap.add_argument('--wmis', type=float, default=0.0)  # static per-device mismatch sigma (0=off)
    ap.add_argument('--li_avg', type=int, default=0)  # lock-in integration window (0=snapshot mode)
    ap.add_argument('--navg', type=int, default=1)  # restart-averaged contrast estimates per update
    ap.add_argument('--track', action='store_true')  # common-mode-tracking AEP correction
    ap.add_argument('--rt_final', type=float, default=0.0)  # anneal res_target to this (0=off), 25%-75% of run
    ap.add_argument('--nudge_brake', type=float, default=0.0)  # kappa: anchor spring during nudge (Tikhonov adjoint)
    ap.add_argument('--init_ckpt', type=str, default='')  # warm-start weights from a saved ckpt
    ap.add_argument('--qknorm', action='store_true')  # Qwen3-style q/k RMSNorm in attention
    ap.add_argument('--compile', action='store_true')  # torch.compile the free-phase relaxation (thick)
    ap.add_argument('--resinit', type=float, default=1.0)  # scale WO,pj at init (ReZero/Fixup: small=near-identity block)
    cfg = ap.parse_args()
    if cfg.specnorm < 0:
        raise SystemExit("--specnorm must be non-negative")

    DD = Path(cfg.data)
    # NOTE(security): pickle.load executes arbitrary code; --data must point at trusted files.
    with open(DD / 'meta.pkl', 'rb') as fh:
        vocab = pickle.load(fh)['vocab_size']
    if cfg.tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    torch.manual_seed(cfg.seed)
    blk = EQBlock(cfg.C, cfg.H, cfg.Mm, cfg.T, s=1.0, c=cfg.c, attn_mode=cfg.attn_mode)
    for w in blk.capw:  # re-derive the norm caps with the requested multiplier
        blk.caps[id(w)] = w.detach().norm().item() * cfg.capx

    # ---- optimizer -------------------------------------------------------------------------
    if cfg.opt in ('lion', 'lionlars'):
        opt = Lion(blk.allp, lr=cfg.lr, weight_decay=cfg.wd, lars=(cfg.opt == 'lionlars'))
    elif cfg.opt == 'sgdm':
        opt = torch.optim.SGD(blk.allp, lr=cfg.lr, momentum=0.95, weight_decay=cfg.wd, nesterov=True)
    elif cfg.opt == 'sgdsai':
        # EP-SaI (SGD-SaI, arXiv:2412.11768, adapted to EP gradients): per-tensor lr from the
        # init-time gradient SNR, frozen — hardware: one gain line per array, set at calibration.
        gs = {id(p): [] for p in blk.allp}
        for _ in range(12):
            idx0, y0 = get_batch('train', cfg.B, cfg.T)
            g0, _ = ep_step(blk, idx0, y0, cfg.T1, cfg.T2, cfg.eps, cfg.beta, 0.0, cfg.holo, cfg.hr,
                            cfg.t1max, cfg.res_est, cfg.t2sel, cfg.corr_every, 0.0)
            for p in blk.allp:
                if g0.get(id(p)) is not None:
                    gs[id(p)].append(g0[id(p)].detach().clone())
        sc = {}
        for p in blk.allp:
            if gs[id(p)]:
                S = torch.stack(gs[id(p)])
                sc[id(p)] = (S.mean(0).norm() / (S.std(0).norm() + 1e-12)).item()
        mx = max(sc.values())
        print("[sgdsai] per-tensor lr scales: " + " ".join(f"{v/mx:.3f}" for v in sc.values()), flush=True)
        opt = torch.optim.SGD([dict(params=[p], lr=cfg.lr * sc.get(id(p), mx) / mx) for p in blk.allp],
                              momentum=0.95, weight_decay=cfg.wd, nesterov=True)
    else:
        opt = torch.optim.AdamW(blk.allp, lr=cfg.lr, weight_decay=cfg.wd)

    # ---- lr schedule -----------------------------------------------------------------------
    if cfg.warmup > 0 or cfg.wsd > 0:  # warmup -> (WSD hold peak) -> cosine decay
        _w = cfg.warmup  # contraction before large steps kick weights out of basin

        def _lrl(s):
            if s < _w:
                return (s + 1) / _w
            _ds = int((1 - cfg.wsd) * cfg.steps) if cfg.wsd > 0 else _w  # WSD decay-start: hold peak lr until here
            if s < _ds:
                return 1.0
            p = (s - _ds) / max(1, cfg.steps - _ds)
            return 0.05 + 0.475 * (1 + math.cos(math.pi * min(1.0, p)))

        sched = torch.optim.lr_scheduler.LambdaLR(opt, _lrl)
    else:
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr * 0.05)

    if cfg.init_ckpt:
        ckl = torch.load(cfg.init_ckpt, map_location=dev)
        with torch.no_grad():
            for p, w in zip(blk.allp, ckl['allp']):
                p.copy_(w.to(dev))
        print(f"[init] warm-start from {cfg.init_ckpt} (step {ckl.get('step')}, best {ckl.get('best', float('nan')):.4f})", flush=True)

    # Start-up probe: absolute size of one more step after 200 free-phase steps.  The 200-step
    # state is computed once and reused (noise and the compiled path are still off here, so the
    # relaxation is deterministic and recomputing it would give the same tensor).
    xin = blk.embed(*get_batch('train', cfg.B, cfg.T)[:1]).detach()
    z200 = relax(blk, xin.clone(), xin, 200, cfg.eps)
    r = (relax(blk, z200, xin, 1, cfg.eps) - z200).norm().item()
    print(f"[{cfg.mode}] residual~{r:.1e} | C={cfg.C} H={cfg.H} Mm={cfg.Mm} T1={cfg.T1} T2={cfg.T2}", flush=True)

    # `best` starts at +inf so the first evaluation always counts, whatever the vocabulary size
    # (a finite sentinel silently disables checkpointing once the loss starts above it).
    best, t0, jr, rs = float('inf'), time.time(), cfg.jacreg, None
    pema = [p.detach().clone() for p in blk.allp] if cfg.pema > 0 else None
    badct = 0
    blk.fnoise = cfg.fnoise
    blk.li_avg = cfg.li_avg
    blk.navg = cfg.navg
    blk.track = cfg.track
    blk.nbrake = cfg.nudge_brake
    blk.qknorm = cfg.qknorm
    if cfg.resinit != 1.0:  # near-identity block at init (contractive) -> stable big-width start
        with torch.no_grad():
            blk.WO.mul_(cfg.resinit)
            blk.pj.mul_(cfg.resinit)
    spec_items = specnorm_weight_items(blk)
    spec_cache = {}
    if cfg.specnorm > 0:
        shapes = " ".join(f"{name}{tuple(W.shape)}" for name, W in spec_items)
        print(f"[specnorm] hard post-step projection: sigma_max <= {cfg.specnorm:g} on {shapes}", flush=True)
    blk._cstep = None
    if cfg.compile and cfg.attn_mode == 'thick':
        _ee = cfg.eps
        blk._cstep = torch.compile(lambda z, xin: z + _ee * blk.tforce(z, xin))
    mis = None
    if cfg.wmis > 0:  # fixed fabrication mismatch (same devices all run)
        gm = torch.Generator().manual_seed(1234)
        mis = [(1 + cfg.wmis * torch.randn(p.shape, generator=gm)).clamp(0.2, 5.0).to(dev) for p in blk.allp]
    hw_on = cfg.wq_bits > 0 or mis is not None

    def hw_swap():
        """Overwrite the weights with their imperfect-device copy; return the fp32 masters."""
        # measure physics on the imperfect device copy; masters stay fp32 (program-verify model)
        saved = [p.detach().clone() for p in blk.allp]
        with torch.no_grad():
            for i, p in enumerate(blk.allp):
                w = p * mis[i] if mis is not None else p.detach().clone()
                if cfg.wq_bits > 0:
                    d = w.abs().max() / (2 ** (cfg.wq_bits - 1) - 1) + 1e-12
                    w = torch.round(w / d) * d
                p.copy_(w)
        return saved

    def hw_restore(saved):
        """Copy the master weights saved by ``hw_swap`` back into the parameters."""
        with torch.no_grad():
            for p, s in zip(blk.allp, saved):
                p.copy_(s)

    start_step = 1
    if cfg.resume and cfg.state and os.path.exists(cfg.state):  # Colab-timeout resume: full state
        st = torch.load(cfg.state, map_location=dev)
        with torch.no_grad():
            for p, w in zip(blk.allp, st['allp']):
                p.copy_(w.to(dev))
        if pema is not None and st.get('pema') is not None:
            pema = [s.to(dev) for s in st['pema']]
        opt.load_state_dict(st['opt'])
        sched.load_state_dict(st['sched'])
        start_step = st['step'] + 1
        jr = st['jr']
        rs = st['rs']
        best = st['best']
        print(f"[resume] from {cfg.state}: step {start_step}, best {best:.4f}, jr {jr:.1f}", flush=True)

    def save_state(step):
        """Write the full resumable state to ``--state`` (no-op when the flag is empty)."""
        if not cfg.state:
            return
        torch.save({'allp': [p.detach().cpu() for p in blk.allp],
                    'pema': [s.cpu() for s in pema] if pema is not None else None,
                    'opt': opt.state_dict(), 'sched': sched.state_dict(),
                    'step': step, 'jr': jr, 'rs': rs, 'best': best}, cfg.state + '.tmp')
        os.replace(cfg.state + '.tmp', cfg.state)  # atomic: survive a mid-write timeout

    if cfg.fingerprint:  # study s2000 vs other ckpts: print the operator's 4-D fingerprint
        from diag_cos import fingerprint
        fp = fingerprint(blk, cfg.T1, cfg.T2, cfg.eps, cfg.beta, cfg.holo, cfg.hr, cfg.t1max, cfg.res_est, cfg.t2sel)
        print(f"[fingerprint] ckpt={cfg.init_ckpt or 'scratch'} | res={fp['res']:.2e} cos(EP,BPTT)={fp['cos']:.4f} "
              f"num_abscissa={fp['num_abscissa']:+.4f} val={fp['val']:.4f}", flush=True)
        return

    # ---- training loop ---------------------------------------------------------------------
    for step in range(start_step, cfg.steps + 1):
        idx, y = get_batch('train', cfg.B, cfg.T)
        if cfg.mode == 'ep':
            sw = hw_swap() if hw_on else None
            grads, res = ep_step(blk, idx, y, cfg.T1, cfg.T2, cfg.eps, cfg.beta, jr, cfg.holo, cfg.hr,
                                 cfg.t1max, cfg.res_est, cfg.t2sel, cfg.corr_every, cfg.res_gate, cfg.resreg,
                                 cfg.eigreg, cfg.eig_margin)
            if sw is not None:
                hw_restore(sw)
            if cfg.jacreg > 0:  # continuous controller: drive residual -> res_target (smooth)
                flo = cfg.jacreg if cfg.jr_floor is None else cfg.jr_floor
                if cfg.jr_lrcouple:
                    flo *= sched.get_last_lr()[0] / cfg.lr
                rtgt = cfg.res_target
                if cfg.rt_final > 0:  # stiffness anneal: tight start -> loose mid/late
                    u = min(1.0, max(0.0, (step / cfg.steps - 0.25) / 0.5))
                    rtgt = math.exp((1 - u) * math.log(cfg.res_target) + u * math.log(cfg.rt_final))
                rs = res if rs is None else cfg.res_ema * rs + (1 - cfg.res_ema) * res
                jr = min(cfg.jr_max, max(flo, jr * math.exp(0.3 * math.log((rs + 1e-9) / rtgt))))
            else:  # damping feedback (no jacreg)
                if res > 1e-3:
                    blk.c = min(cfg.ccap, blk.c * 1.3)
                elif res < 2e-4:
                    blk.c = max(0.5, blk.c * 0.97)
        else:
            grads = bptt_step(blk, idx, y, cfg.T1, cfg.eps, jr if cfg.jacreg > 0 else 0.0)
            with torch.no_grad():  # is BPTT's optimum contractive? (free-phase residual)
                xinb = blk.embed(idx).detach()
                zsb = relax(blk, xinb.clone(), xinb, cfg.T1, cfg.eps)
                res = (relax(blk, zsb, xinb, 1, cfg.eps) - zsb).norm().item() / (zsb.norm().item() + 1e-9)
            if cfg.jacreg > 0:  # same residual-driven λ controller as ep mode
                flo = cfg.jacreg if cfg.jr_floor is None else cfg.jr_floor
                if cfg.jr_lrcouple:
                    flo *= sched.get_last_lr()[0] / cfg.lr
                rs = res if rs is None else cfg.res_ema * rs + (1 - cfg.res_ema) * res
                jr = min(cfg.jr_max, max(flo, jr * math.exp(0.3 * math.log((rs + 1e-9) / cfg.res_target))))

        badct = badct + 1 if (cfg.abort_res > 0 and res > cfg.abort_res) else 0
        if badct >= 100:  # containment lost and not recovering: stop, keep best ckpt
            print(f" ABORT at step {step}: res>{cfg.abort_res} for 100 consecutive steps (best {best:.4f})", flush=True)
            break
        ok = all((g is None) or torch.isfinite(g).all() for g in grads.values())
        if not ok:
            print(f" step {step}: non-finite, skip", flush=True)
            continue

        opt.zero_grad(set_to_none=True)
        for p in blk.allp:
            p.grad = grads.get(id(p))
        torch.nn.utils.clip_grad_norm_(blk.allp, 5.0)
        opt.step()
        spec_stats = None
        with torch.no_grad():
            if cfg.specnorm > 0:
                spec_stats = project_specnorm_(spec_items, spec_cache, cfg.specnorm)
            else:  # Frobenius-norm caps on the capped matrices
                for p in blk.capw:
                    pn = p.norm()
                    cap = blk.caps[id(p)]
                    if pn > cap:
                        p.mul_(cap / pn)
        sched.step()
        if spec_stats is not None and (step == start_step or step % cfg.log == 0):
            sb, sa, names = spec_stats
            cname = ",".join(names) if names else "none"
            print(f" specnorm step {step}: max sigma before={sb:.4f} after={sa:.4f} bound={cfg.specnorm:.4f} clamped={cname}", flush=True)
        if pema is not None:
            with torch.no_grad():
                for s, p in zip(pema, blk.allp):
                    s.mul_(cfg.pema).add_(p.detach(), alpha=1 - cfg.pema)
        if cfg.save_every and step % cfg.save_every == 0 and step % cfg.log != 0:
            save_state(step)  # mid-interval state save (Colab: cap worst-case loss)
        if cfg.diag_cos and step % cfg.diag_cos == 0:  # #1: gradient-alignment trajectory (scratch vs warm)
            from diag_cos import cos_ep_bptt
            _c, _r = cos_ep_bptt(blk, idx, y, cfg.T1, cfg.T2, cfg.eps, cfg.beta, cfg.holo, cfg.hr,
                                 cfg.t1max, cfg.res_est, cfg.t2sel)
            print(f" [diag] step {step}: cos(EP,BPTT)={_c:.4f} res={_r:.1e}", flush=True)

        if step % cfg.log == 0:
            prevb = best
            sw = hw_swap() if hw_on else None
            v = evaluate(blk, cfg.T1, cfg.eps)
            if sw is not None:
                hw_restore(sw)
            best = min(best, v)
            etag = ""
            if pema is not None:  # also score the EMA weights, then put the raw ones back
                with torch.no_grad():
                    raw = [p.detach().clone() for p in blk.allp]
                    for p, s in zip(blk.allp, pema):
                        p.copy_(s)
                    ve = evaluate(blk, cfg.T1, cfg.eps)
                    for p, w in zip(blk.allp, raw):
                        p.copy_(w)
                best = min(best, ve)
                etag = f" ema={ve:.4f}"
            if cfg.ckpt and best < prevb:
                torch.save({'allp': [p.detach().cpu() for p in blk.allp],
                            'pema': [s.cpu() for s in pema] if pema is not None else None,
                            'step': step, 'best': best}, cfg.ckpt)
            # throughput of THIS process: count only the steps run since start_step (resume-safe)
            rate = (step - start_step + 1) / max(time.time() - t0, 1e-9)
            print(f"step {step:4d}/{cfg.steps} | val CE {v:.4f}{etag} (best {best:.4f}) | jr={jr:.1f} res={res:.1e} | {rate:.2f} it/s", flush=True)
            save_state(step)  # full-state checkpoint each log interval (Colab resume)

    print(f"[{cfg.mode}] DONE best val CE {best:.4f} (random baseline ln({vocab})={math.log(vocab):.3f})", flush=True)
    out_dir = Path('runs')
    out_dir.mkdir(exist_ok=True)
    # `best` stays +inf when no evaluation ran; write null rather than a non-standard Infinity.
    with open(out_dir / f'H2_{cfg.mode}.json', 'w') as fh:
        json.dump({'mode': cfg.mode, 'best_val_ce': best if math.isfinite(best) else None}, fh)


if __name__ == '__main__':
    main()
