"""option 2 / H2: train a full EP equilibrium transformer block on Shakespeare char-LM. One block = token state z relaxed to a fixed point of F(z) = -(z - x_in) (input clamp; x_in = embed(idx)) - dE_mem/dz (Hopfield memory E_mem = -sum relu(z Wm)^2 ; CONSERVATIVE = FFN) + s*(causal_attn(z) - c*z) (damped causal attention ; NON-conservative) Readout logits = z* Whead ; loss = next-token cross-entropy. Train modes: ep : free + +/-beta nudged equilibria. Block+embed params via vector-field gradient with the AEP correction (clipped) on the attention part; readout head via its own local gradient dCE/dWhead. NO backprop through the relaxation. bptt : backprop through the unrolled relaxation (exact-gradient reference, same architecture). Stabilisation from G: damping c, clipped AEP correction, weight-norm caps, best-val checkpoint. """ import argparse, math, pickle, time, json, os, numpy as np, torch, torch.nn.functional as F from pathlib import Path dev = 'cuda' if torch.cuda.is_available() else 'cpu' DD = Path('/home/yurenh2/ept/ep_run/data/tinystories_bpe') vocab = pickle.load(open(DD / 'meta.pkl', 'rb'))['vocab_size'] def get_batch(split, B, T): data = np.memmap(DD / ('train.bin' if split == 'train' else 'val.bin'), dtype=np.uint16, mode='r') ix = torch.randint(len(data) - T - 1, (B,)) x = torch.stack([torch.from_numpy(data[i:i + T].astype(np.int64)) for i in ix]) y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + T].astype(np.int64)) for i in ix]) return x.to(dev), y.to(dev) class EQBlock: def __init__(self, C, H, Mm, T, s=1.0, c=1.0, attn_mode='real', gamma=0.25): self.C, self.H, self.dh, self.s, self.c, self.T = C, H, C // H, s, c, T self.attn_mode, self.gamma = attn_mode, gamma self.fnoise = 0.0 # optics model: mult. noise per force eval g = lambda *sh, sc: (torch.randn(*sh, device=dev) * sc).requires_grad_(True) self.tok = g(vocab, C, sc=0.02); self.pos = g(T, C, sc=0.02) self.WQ = g(C, C, sc=1 / math.sqrt(C)); self.WK = g(C, C, sc=1 / math.sqrt(C)) self.WV = g(C, C, sc=1 / math.sqrt(C)); self.WO = g(C, C, sc=1 / math.sqrt(C)) self.Wm = g(C, Mm, sc=0.3 / math.sqrt(C)); self.Wh = g(C, vocab, sc=1 / math.sqrt(C)) self.P = g(C, C, sc=1 / math.sqrt(C)); self.Q = g(C, C, sc=1 / math.sqrt(C)) # monDEQ monotone op self.mono_m = 1.0 z1 = lambda n, v: torch.full((n,), float(v), device=dev).requires_grad_(True) self.ln1g = z1(C, 1); self.ln1b = z1(C, 0); self.ln2g = z1(C, 1); self.ln2b = z1(C, 0) # LN affine self.fc = g(C, 4 * C, sc=1 / math.sqrt(C)); self.fcb = z1(4 * C, 0) # untied 4x FFN self.pj = g(4 * C, C, sc=1 / math.sqrt(4 * C)); self.pjb = z1(C, 0) self.cmask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=dev)) self.block = [self.tok, self.pos, self.WQ, self.WK, self.WV, self.WO, self.Wm, self.P, self.Q, self.ln1g, self.ln1b, self.ln2g, self.ln2b, self.fc, self.fcb, self.pj, self.pjb] # in the force self.allp = self.block + [self.Wh] self.capw = (self.WQ, self.WK, self.WV, self.WO, self.Wm, self.Wh, self.fc, self.pj) self.caps = {id(w): w.detach().norm().item() * 3.0 for w in self.capw} def embed(self, idx): return self.tok[idx] + self.pos[None] def attn(self, z): B = z.size(0) q = (z @ self.WQ).view(B, self.T, self.H, self.dh).transpose(1, 2) k = (z @ self.WK).view(B, self.T, self.H, self.dh).transpose(1, 2) v = (z @ self.WV).view(B, self.T, self.H, self.dh).transpose(1, 2) if getattr(self, 'qknorm', False): # Qwen3-style q/k RMSNorm: bounds logits, tames J q = q * torch.rsqrt(q.pow(2).mean(-1, keepdim=True) + 1e-6) k = k * torch.rsqrt(k.pow(2).mean(-1, keepdim=True) + 1e-6) a = (q @ k.transpose(-2, -1)) / math.sqrt(self.dh) a = torch.softmax(a.masked_fill(~self.cmask, float('-inf')), -1) return (a @ v).transpose(1, 2).reshape(B, self.T, self.C) @ self.WO def attn_energy(self, z): # conservative LSE attention energy (tied value) B = z.size(0) q = (z @ self.WQ).view(B, self.T, self.H, self.dh).transpose(1, 2) k = (z @ self.WK).view(B, self.T, self.H, self.dh).transpose(1, 2) a = (q @ k.transpose(-2, -1)) / math.sqrt(self.dh) a = a.masked_fill(~self.cmask, float('-inf')) return -(1.0 / self.gamma) * torch.logsumexp(self.gamma * a, dim=-1).sum() def Emem(self, z): return -(F.relu(z @ self.Wm) ** 2).sum() def tforce(self, z, xin): # pure thick force (no grad machinery) -> torch.compile h1 = F.layer_norm(z, (self.C,), self.ln1g, self.ln1b) h2 = F.layer_norm(z, (self.C,), self.ln2g, self.ln2b) ff = F.gelu(h2 @ self.fc + self.fcb, approximate='tanh') @ self.pj + self.pjb return -(z - xin) + self.attn(h1) + ff - self.c * z def _noisy(self, t): # optics model: per-pass multiplicative noise if self.fnoise > 0: return t * (1 + self.fnoise * torch.randn_like(t)) return t def nc_force(self, z): # non-conservative part of the force (for AEP/jacreg) if self.attn_mode == 'thick': h1 = F.layer_norm(z, (self.C,), self.ln1g, self.ln1b) h2 = F.layer_norm(z, (self.C,), self.ln2g, self.ln2b) return self._noisy(self.attn(h1) + (F.gelu(h2 @ self.fc + self.fcb, approximate='tanh') @ self.pj + self.pjb)) return self._noisy(self.s * self.attn(z)) def force(self, z, xin, cg=False): with torch.enable_grad(): zr = z if (cg and z.requires_grad) else z.detach().requires_grad_(True) if self.attn_mode == 'thick': # DEQ-transformer block: LN + untied 4x FFN + residual h1 = F.layer_norm(zr, (self.C,), self.ln1g, self.ln1b) h2 = F.layer_norm(zr, (self.C,), self.ln2g, self.ln2b) ff = F.gelu(h2 @ self.fc + self.fcb, approximate='tanh') @ self.pj + self.pjb return -(zr - xin) + self._noisy(self.attn(h1) + ff) - self.c * zr # -c*z: initial contraction if self.attn_mode == 'mono': # monDEQ: structurally-monotone contraction gm, = torch.autograd.grad(self.Emem(zr), zr, create_graph=cg) PtP = self.P.t() @ self.P # PSD -> sym(J) = -(mI+PtP) < 0 (guaranteed) f = (-(self.mono_m * zr + zr @ PtP) + zr @ (self.Q - self.Q.t()).t() + xin - gm + self.s * self.attn(zr)) return f E = 0.5 * ((zr - xin) ** 2).sum() + self.Emem(zr) if self.attn_mode == 'energy': # attention folded into the energy (conservative) E = E + self.attn_energy(zr) + 0.5 * self.c * (zr ** 2).sum() # confinement -> bounded below gz, = torch.autograd.grad(E, zr, create_graph=cg) f = -gz if self.attn_mode == 'real': # non-conservative attention + damping f = f + self.s * (self.attn(zr) - self.c * zr) return f def relax(blk, z, xin, steps, eps): cstep = getattr(blk, '_cstep', None) if cstep is not None and blk.fnoise == 0.0: # compiled pure-thick free-phase fast path with torch.no_grad(): for _ in range(steps): z = cstep(z, xin) return z.detach() for _ in range(steps): with torch.no_grad(): z = z + eps * blk.force(z, xin).detach() return z.detach() def ce(blk, z, y): return F.cross_entropy((z @ blk.Wh).reshape(-1, vocab), y.reshape(-1)) def ep_step(blk, idx, y, T1, T2, eps, beta, jacreg=0.0, holo=0, hr=0.02, t1max=0, res_est=1e-4, t2sel=0, corr_every=1, res_gate=0.0, resreg=0.0, eigreg=0.0, eig_margin=1.0): xin0 = blk.embed(idx).detach() zs = relax(blk, xin0.clone(), xin0, T1, eps) res = (relax(blk, zs, xin0, 1, eps) - zs).norm().item() / (zs.norm().item() + 1e-9) res_used = res zT, resT1 = zs, res # the T1 free-phase state (what eval/BPTT use), BEFORE refinement if t1max > T1: # estimator refinement: relax further until tight rnow, t = res, T1 # (controller signal `res` stays measured at T1) while t < t1max and rnow > res_est: zs = relax(blk, zs, xin0, 50, eps); t += 50 rnow = (relax(blk, zs, xin0, 1, eps) - zs).norm().item() / (zs.norm().item() + 1e-9) res_used = rnow if res_gate > 0 and res_used > res_gate: # validity gate: off-equilibrium the EP update is grads = {} # undefined -> apply ONLY the homeostat (jacreg) and if jacreg > 0: # skip the nudge entirely (fast recovery steps) er = torch.randn_like(zs) with torch.enable_grad(): Jv = torch.autograd.functional.jvp(blk.nc_force, zs.detach(), er, create_graph=True)[1] R = jacreg * (Jv ** 2).sum() / (er ** 2).sum() gr = torch.autograd.grad(R, blk.block, allow_unused=True) grads = {id(p): g for p, g in zip(blk.block, gr) if g is not None} return grads, res def nudge(sign): z = zs.clone() for _ in range(T2): with torch.enable_grad(): zz = z.detach().requires_grad_(True) g, = torch.autograd.grad(ce(blk, zz, y), zz) g = g.clamp(-2.0, 2.0) # clip nudge so it can't blow up relax with torch.no_grad(): f = blk.force(z, xin0).detach() - sign * beta * g if blk.attn_mode in ('real', 'thick'): # AEP correction (full non-conservative part) v = (z - zs).detach() Jv = torch.autograd.functional.jvp(blk.nc_force, zs, v)[1] JTv = torch.autograd.functional.vjp(blk.nc_force, zs, v)[1] corr = Jv - JTv cn, fn = corr.norm(), f.norm() + 1e-8 f = f - (corr * (fn / cn) if cn > fn else corr) z = z + eps * f return z.detach() if holo == 2 and t2sel > 0: # adaptive-T2, phase-batched fast path (validated ==) from holo_ep import holo_a_select2, holo_a_track K = max(1, getattr(blk, 'navg', 1)) # restart-averaging: noise / sqrt(K) acc = None for _ in range(K): if getattr(blk, 'track', False): # common-mode-tracking AEP (loose-tolerant) ai, _ = holo_a_track(blk, zs, xin0, y, hr, t2sel, eps) else: ai, _ = holo_a_select2(blk, zs, xin0, y, hr, t2sel, eps, li=getattr(blk, 'li_avg', 0)) acc = ai if acc is None else acc + ai a = acc / K elif holo > 0 and t2sel > 0: # adaptive-T2 via hindsight snapshot selection from holo_ep import holo_a_select a, _ = holo_a_select(blk, zs, xin0, y, holo, hr, t2sel, eps, corr_every=corr_every) elif holo > 0: # holomorphic nudge (clamp-free, Cauchy readout) from holo_ep import holo_a a, _ = holo_a(blk, zs, xin0, y, holo, hr, T2, eps) else: zp, zm = nudge(+1), nudge(-1) a = ((zm - zp) / (2 * beta)).detach() grads = {} with torch.enable_grad(): xin = blk.embed(idx) # live (for tok/pos grad through clamp) f = blk.force(zs.detach(), xin, cg=True) gblk = torch.autograd.grad((a * f).sum(), blk.block, allow_unused=True) for p, gv in zip(blk.block, gblk): grads[id(p)] = gv with torch.enable_grad(): gh, = torch.autograd.grad(ce(blk, zs.detach(), y), blk.Wh) # readout local gradient grads[id(blk.Wh)] = gh if jacreg > 0: # soft Lyapunov: penalize non-conservative Jacobian norm er = torch.randn_like(zs) with torch.enable_grad(): Jv = torch.autograd.functional.jvp(blk.nc_force, zs.detach(), er, create_graph=True)[1] R = jacreg * (Jv ** 2).sum() / (er ** 2).sum() # Hutchinson est of ||J_nc||_F^2 gr = torch.autograd.grad(R, blk.block, allow_unused=True) for p, g in zip(blk.block, gr): if g is not None: grads[id(p)] = g if grads.get(id(p)) is None else grads[id(p)] + g if resreg > 0 and resT1 > 7e-4: # defend z_T1 (BPTT gets this implicitly; EP at z* doesn't) with torch.enable_grad(): Fz = blk.tforce(zT, xin0) # deterministic thick force at z_T1 (params live, zT/xin0 detached) Rr = (eps * Fz).pow(2).sum() / (zT.pow(2).sum() + 1e-9) # ~ (T1 residual)^2 grr = torch.autograd.grad(Rr, blk.block, allow_unused=True) ratio = resreg * min(1.0, resT1 / 2e-2) # ramp 0->resreg as res 7e-4->2e-2, capped gtask = math.sqrt(sum(float((grads[id(p)] ** 2).sum()) for p in blk.block if grads.get(id(p)) is not None) + 1e-20) gres = math.sqrt(sum(float((g ** 2).sum()) for g in grr if g is not None) + 1e-20) lam = ratio * gtask / gres # scale penalty to `ratio` of the task-grad norm for p, g in zip(blk.block, grr): if g is not None: grads[id(p)] = g * lam if grads.get(id(p)) is None else grads[id(p)] + lam * g if eigreg > 0: # #2 v2: TRUE leading map-eigenvalue control (aep 'spectral', soft one-sided) from eig_control import spec_penalty # (omega/numerical-abscissa version refuted 2026-07-03, eig_recheck) ge, _rho, _mu = spec_penalty(blk, zs, eps, blk.c, eigreg, eig_margin, blk.__dict__.setdefault('_eigcache', {})) for pid, g in ge.items(): grads[pid] = g if grads.get(pid) is None else grads[pid] + g return grads, res class Lion(torch.optim.Optimizer): """Chen et al. 2023. Analog-hardware rationale: sign updates = fixed-amplitude pulses (kills device write-nonlinearity), magnitude-noise immune, one momentum cap per weight.""" def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0, lars=False): super().__init__(params, dict(lr=lr, betas=betas, weight_decay=weight_decay, lars=lars)) @torch.no_grad() def step(self): for group in self.param_groups: for p in group['params']: if p.grad is None: continue b1, b2 = group['betas'] st = self.state.setdefault(p, {}) if 'm' not in st: st['m'] = torch.zeros_like(p) u = (b1 * st['m'] + (1 - b1) * p.grad).sign() lr = group['lr'] if group['lars']: # per-tensor trust ratio: one gain line per array lr = lr * (p.norm() / (u.norm() + 1e-12)).item() p.mul_(1 - lr * group['weight_decay']) p.add_(u, alpha=-lr) st['m'].mul_(b2).add_(p.grad, alpha=1 - b2) def bptt_step(blk, idx, y, T1, eps, jacreg=0.0): xin = blk.embed(idx) z = xin.detach().requires_grad_(True) * 0 + xin # init = embedding (keeps graph to emb) for _ in range(T1): z = z + eps * blk.force(z, xin, cg=True) g = torch.autograd.grad(ce(blk, z, y), blk.allp, allow_unused=True) gd = {id(p): gv for p, gv in zip(blk.allp, g)} if jacreg > 0: # same soft Lyapunov penalty as ep mode (fair control) er = torch.randn_like(z) with torch.enable_grad(): Jv = torch.autograd.functional.jvp(blk.nc_force, z.detach(), er, create_graph=True)[1] R = jacreg * (Jv ** 2).sum() / (er ** 2).sum() gr = torch.autograd.grad(R, blk.block, allow_unused=True) for p, gv in zip(blk.block, gr): if gv is not None: gd[id(p)] = gv if gd.get(id(p)) is None else gd[id(p)] + gv return gd @torch.no_grad() def evaluate(blk, T1, eps, nb=8, B=32): tot = 0.0 for _ in range(nb): idx, y = get_batch('val', B, blk.T) xin = blk.embed(idx).detach() z = relax(blk, xin.clone(), xin, T1, eps) tot += ce(blk, z, y).item() return tot / nb def specnorm_weight_items(blk): items = [] qkv = None for name in ('WQKV', 'Wqkv', 'W_qkv', 'qkv'): if hasattr(blk, name): qkv = (name, getattr(blk, name)) break if qkv is not None: items.append(qkv) else: items.extend((name, getattr(blk, name)) for name in ('WQ', 'WK', 'WV') if hasattr(blk, name)) items.extend((name, getattr(blk, name)) for name in ('WO', 'fc', 'pj') if hasattr(blk, name)) return [(name, w) for name, w in items if w.ndim >= 2] @torch.no_grad() def power_sigma(W, u, iters=2): M = W.detach().reshape(W.shape[0], -1) if u is None or u.shape != (M.shape[0],) or u.device != W.device or u.dtype != W.dtype: u = F.normalize(torch.randn(M.shape[0], device=W.device, dtype=W.dtype), dim=0, eps=1e-12) for _ in range(iters): v = F.normalize(M.t().mv(u), dim=0, eps=1e-12) u = F.normalize(M.mv(v), dim=0, eps=1e-12) sigma = u.dot(M.mv(v)).abs() return sigma, u.detach() @torch.no_grad() def project_specnorm_(items, cache, bound): max_before, max_after, clamped = 0.0, 0.0, [] for name, W in items: sigma, u = power_sigma(W, cache.get(id(W))) cache[id(W)] = u before = float(sigma) after = before if before > bound: scale = bound / (before + 1e-12) W.mul_(scale) after = bound clamped.append(name) max_before = max(max_before, before) max_after = max(max_after, after) return max_before, max_after, clamped def main(): ap = argparse.ArgumentParser() ap.add_argument('--mode', choices=['ep', 'bptt'], required=True) ap.add_argument('--steps', type=int, default=2000); ap.add_argument('--B', type=int, default=32) ap.add_argument('--T', type=int, default=64); ap.add_argument('--C', type=int, default=128) ap.add_argument('--H', type=int, default=4); ap.add_argument('--Mm', type=int, default=256) ap.add_argument('--T1', type=int, default=80); ap.add_argument('--T2', type=int, default=15) ap.add_argument('--eps', type=float, default=0.1); ap.add_argument('--beta', type=float, default=0.02) ap.add_argument('--lr', type=float, default=1e-3); ap.add_argument('--log', type=int, default=100) ap.add_argument('--warmup', type=int, default=0) # linear lr warmup steps (big-model stability) ap.add_argument('--state', type=str, default='') # periodic FULL-state path (weights+opt+sched+step) ap.add_argument('--resume', action='store_true') # resume from --state if it exists (Colab timeouts) ap.add_argument('--save_every', type=int, default=0) # full-state save cadence (0=every --log); set small on Colab ap.add_argument('--c', type=float, default=1.0); ap.add_argument('--capx', type=float, default=3.0) ap.add_argument('--attn_mode', choices=['real', 'energy', 'mono', 'thick'], default='real') ap.add_argument('--ccap', type=float, default=8.0) ap.add_argument('--specnorm', type=float, default=0.0) # hard post-step spectral projection bound; 0=off ap.add_argument('--jacreg', type=float, default=0.0) # Bai 2021: soft Jacobian-norm (Lyapunov) penalty ap.add_argument('--jr_max', type=float, default=16.0) # adaptive jacreg ceiling (ramps up vs residual) ap.add_argument('--res_target', type=float, default=5e-3) # continuous controller target residual ap.add_argument('--jr_floor', type=float, default=None) # controller floor; default=--jacreg (legacy: never off) ap.add_argument('--res_ema', type=float, default=0.0) # EMA on residual signal (0=off); kills controller thrash ap.add_argument('--jr_lrcouple', action='store_true') # anneal the λ floor with the lr schedule (late-drift fix) ap.add_argument('--holo', type=int, default=0) # holomorphic EP: N circle points (0=off) ap.add_argument('--hr', type=float, default=0.02) # holomorphic nudge radius |beta| ap.add_argument('--pema', type=float, default=0.0) # parameter EMA decay (0=off); tames late wander ap.add_argument('--t1max', type=int, default=0) # adaptive free phase: extend up to t1max... ap.add_argument('--res_est', type=float, default=1e-4) # ...until this residual (estimator validity) ap.add_argument('--t2sel', type=int, default=0) # adaptive T2: snapshot-selection cap (0=off) ap.add_argument('--seed', type=int, default=0) ap.add_argument('--data', type=str, default='/tmp/lt_ep/data/shakespeare_char') ap.add_argument('--ckpt', type=str, default='') # save best weights (raw+ema) here ap.add_argument('--corr_every', type=int, default=1) # recompute AEP corr every k nudge steps ap.add_argument('--tf32', action='store_true') # tf32 matmuls (check res floor first!) ap.add_argument('--abort_res', type=float, default=0.1) # kill switch: res above this 100 steps straight ap.add_argument('--res_gate', type=float, default=0.0) # validity gate: skip task grads above this res ap.add_argument('--wsd', type=float, default=0.0) # WSD: hold peak lr, cosine-decay only the last wsd fraction ap.add_argument('--resreg', type=float, default=0.0) # T1-residual penalty: defend z_T1 (cap ratio vs task grad); run res_gate=0 ap.add_argument('--eigreg', type=float, default=0.0) # #2 v2: soft penalty on TRUE |lam|_lead(I+eps*J_F) — aep 'spectral' at C512 ap.add_argument('--eig_margin', type=float, default=0.995) # rho target: penalize |lam|_lead above this (<1 = contracting relaxation map) ap.add_argument('--diag_cos', type=int, default=0) # #1: every N steps, log cos(EP grad, exact BPTT grad) + res ap.add_argument('--fingerprint', action='store_true') # load --init_ckpt, print (res,cos,abscissa,val) fingerprint, exit ap.add_argument('--opt', choices=['adamw', 'lion', 'lionlars', 'sgdm', 'sgdsai'], default='adamw') ap.add_argument('--wd', type=float, default=1e-4) ap.add_argument('--fnoise', type=float, default=0.0) # optics/device twin: mult. noise per force eval ap.add_argument('--wq_bits', type=int, default=0) # weights projected to N bits each step (0=off) ap.add_argument('--wmis', type=float, default=0.0) # static per-device mismatch sigma (0=off) ap.add_argument('--li_avg', type=int, default=0) # lock-in integration window (0=snapshot mode) ap.add_argument('--navg', type=int, default=1) # restart-averaged contrast estimates per update ap.add_argument('--track', action='store_true') # common-mode-tracking AEP correction ap.add_argument('--rt_final', type=float, default=0.0) # anneal res_target to this (0=off), 25%-75% of run ap.add_argument('--nudge_brake', type=float, default=0.0) # kappa: anchor spring during nudge (Tikhonov adjoint) ap.add_argument('--init_ckpt', type=str, default='') # warm-start weights from a saved ckpt ap.add_argument('--qknorm', action='store_true') # Qwen3-style q/k RMSNorm in attention ap.add_argument('--compile', action='store_true') # torch.compile the free-phase relaxation (thick) ap.add_argument('--resinit', type=float, default=1.0) # scale WO,pj at init (ReZero/Fixup: small=near-identity block) cfg = ap.parse_args() if cfg.specnorm < 0: raise SystemExit("--specnorm must be non-negative") global DD, vocab DD = Path(cfg.data) vocab = pickle.load(open(DD / 'meta.pkl', 'rb'))['vocab_size'] if cfg.tf32: torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.manual_seed(cfg.seed) blk = EQBlock(cfg.C, cfg.H, cfg.Mm, cfg.T, s=1.0, c=cfg.c, attn_mode=cfg.attn_mode) for w in blk.capw: blk.caps[id(w)] = w.detach().norm().item() * cfg.capx if cfg.opt in ('lion', 'lionlars'): opt = Lion(blk.allp, lr=cfg.lr, weight_decay=cfg.wd, lars=(cfg.opt == 'lionlars')) elif cfg.opt == 'sgdm': opt = torch.optim.SGD(blk.allp, lr=cfg.lr, momentum=0.95, weight_decay=cfg.wd, nesterov=True) elif cfg.opt == 'sgdsai': # EP-SaI (SGD-SaI, arXiv:2412.11768, adapted to EP gradients): per-tensor lr from the # init-time gradient SNR, frozen — hardware: one gain line per array, set at calibration. gs = {id(p): [] for p in blk.allp} for _ in range(12): idx0, y0 = get_batch('train', cfg.B, cfg.T) g0, _ = ep_step(blk, idx0, y0, cfg.T1, cfg.T2, cfg.eps, cfg.beta, 0.0, cfg.holo, cfg.hr, cfg.t1max, cfg.res_est, cfg.t2sel, cfg.corr_every, 0.0) for p in blk.allp: if g0.get(id(p)) is not None: gs[id(p)].append(g0[id(p)].detach().clone()) sc = {} for p in blk.allp: if gs[id(p)]: S = torch.stack(gs[id(p)]) sc[id(p)] = (S.mean(0).norm() / (S.std(0).norm() + 1e-12)).item() mx = max(sc.values()) print("[sgdsai] per-tensor lr scales: " + " ".join(f"{v/mx:.3f}" for v in sc.values()), flush=True) opt = torch.optim.SGD([dict(params=[p], lr=cfg.lr * sc.get(id(p), mx) / mx) for p in blk.allp], momentum=0.95, weight_decay=cfg.wd, nesterov=True) else: opt = torch.optim.AdamW(blk.allp, lr=cfg.lr, weight_decay=cfg.wd) if cfg.warmup > 0 or cfg.wsd > 0: # warmup -> (WSD hold peak) -> cosine decay _w = cfg.warmup # contraction before large steps kick weights out of basin def _lrl(s): if s < _w: return (s + 1) / _w _ds = int((1 - cfg.wsd) * cfg.steps) if cfg.wsd > 0 else _w # WSD decay-start: hold peak lr until here if s < _ds: return 1.0 p = (s - _ds) / max(1, cfg.steps - _ds) return 0.05 + 0.475 * (1 + math.cos(math.pi * min(1.0, p))) sched = torch.optim.lr_scheduler.LambdaLR(opt, _lrl) else: sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr * 0.05) if cfg.init_ckpt: ckl = torch.load(cfg.init_ckpt, map_location=dev) with torch.no_grad(): for p, w in zip(blk.allp, ckl['allp']): p.copy_(w.to(dev)) print(f"[init] warm-start from {cfg.init_ckpt} (step {ckl.get('step')}, best {ckl.get('best', float('nan')):.4f})", flush=True) xin = blk.embed(*get_batch('train', cfg.B, cfg.T)[:1]).detach() r = (relax(blk, relax(blk, xin.clone(), xin, 200, cfg.eps), xin, 1, cfg.eps) - relax(blk, xin.clone(), xin, 200, cfg.eps)).norm().item() print(f"[{cfg.mode}] residual~{r:.1e} | C={cfg.C} H={cfg.H} Mm={cfg.Mm} T1={cfg.T1} T2={cfg.T2}", flush=True) best, t0, jr, rs = 9.9, time.time(), cfg.jacreg, None pema = [p.detach().clone() for p in blk.allp] if cfg.pema > 0 else None badct = 0 blk.fnoise = cfg.fnoise blk.li_avg = cfg.li_avg blk.navg = cfg.navg blk.track = cfg.track blk.nbrake = cfg.nudge_brake blk.qknorm = cfg.qknorm if cfg.resinit != 1.0: # near-identity block at init (contractive) -> stable big-width start with torch.no_grad(): blk.WO.mul_(cfg.resinit); blk.pj.mul_(cfg.resinit) spec_items = specnorm_weight_items(blk) spec_cache = {} if cfg.specnorm > 0: shapes = " ".join(f"{name}{tuple(W.shape)}" for name, W in spec_items) print(f"[specnorm] hard post-step projection: sigma_max <= {cfg.specnorm:g} on {shapes}", flush=True) blk._cstep = None if cfg.compile and cfg.attn_mode == 'thick': _ee = cfg.eps blk._cstep = torch.compile(lambda z, xin: z + _ee * blk.tforce(z, xin)) mis = None if cfg.wmis > 0: # fixed fabrication mismatch (same devices all run) gm = torch.Generator().manual_seed(1234) mis = [(1 + cfg.wmis * torch.randn(p.shape, generator=gm)).clamp(0.2, 5.0).to(dev) for p in blk.allp] hw_on = cfg.wq_bits > 0 or mis is not None def hw_swap(): # measure physics on the imperfect device copy; saved = [p.detach().clone() for p in blk.allp] # masters stay fp32 (program-verify model) with torch.no_grad(): for i, p in enumerate(blk.allp): w = p * mis[i] if mis is not None else p.detach().clone() if cfg.wq_bits > 0: d = w.abs().max() / (2 ** (cfg.wq_bits - 1) - 1) + 1e-12 w = torch.round(w / d) * d p.copy_(w) return saved def hw_restore(saved): with torch.no_grad(): for p, s in zip(blk.allp, saved): p.copy_(s) start_step = 1 if cfg.resume and cfg.state and os.path.exists(cfg.state): # Colab-timeout resume: full state st = torch.load(cfg.state, map_location=dev) with torch.no_grad(): for p, w in zip(blk.allp, st['allp']): p.copy_(w.to(dev)) if pema is not None and st.get('pema') is not None: pema = [s.to(dev) for s in st['pema']] opt.load_state_dict(st['opt']); sched.load_state_dict(st['sched']) start_step = st['step'] + 1; jr = st['jr']; rs = st['rs']; best = st['best'] print(f"[resume] from {cfg.state}: step {start_step}, best {best:.4f}, jr {jr:.1f}", flush=True) def save_state(step): if not cfg.state: return torch.save({'allp': [p.detach().cpu() for p in blk.allp], 'pema': [s.cpu() for s in pema] if pema is not None else None, 'opt': opt.state_dict(), 'sched': sched.state_dict(), 'step': step, 'jr': jr, 'rs': rs, 'best': best}, cfg.state + '.tmp') os.replace(cfg.state + '.tmp', cfg.state) # atomic: survive a mid-write timeout if cfg.fingerprint: # study s2000 vs other ckpts: print the operator's 4-D fingerprint from diag_cos import fingerprint fp = fingerprint(blk, cfg.T1, cfg.T2, cfg.eps, cfg.beta, cfg.holo, cfg.hr, cfg.t1max, cfg.res_est, cfg.t2sel) print(f"[fingerprint] ckpt={cfg.init_ckpt or 'scratch'} | res={fp['res']:.2e} cos(EP,BPTT)={fp['cos']:.4f} " f"rho={fp['rho']:.5f} Re_mu={fp['mu_re']:+.4f} val={fp['val']:.4f}", flush=True) return for step in range(start_step, cfg.steps + 1): idx, y = get_batch('train', cfg.B, cfg.T) if cfg.mode == 'ep': sw = hw_swap() if hw_on else None grads, res = ep_step(blk, idx, y, cfg.T1, cfg.T2, cfg.eps, cfg.beta, jr, cfg.holo, cfg.hr, cfg.t1max, cfg.res_est, cfg.t2sel, cfg.corr_every, cfg.res_gate, cfg.resreg, cfg.eigreg, cfg.eig_margin) if sw is not None: hw_restore(sw) if cfg.jacreg > 0: # continuous controller: drive residual -> res_target (smooth) flo = cfg.jacreg if cfg.jr_floor is None else cfg.jr_floor if cfg.jr_lrcouple: flo *= sched.get_last_lr()[0] / cfg.lr rtgt = cfg.res_target if cfg.rt_final > 0: # stiffness anneal: tight start -> loose mid/late u = min(1.0, max(0.0, (step / cfg.steps - 0.25) / 0.5)) rtgt = math.exp((1 - u) * math.log(cfg.res_target) + u * math.log(cfg.rt_final)) rs = res if rs is None else cfg.res_ema * rs + (1 - cfg.res_ema) * res jr = min(cfg.jr_max, max(flo, jr * math.exp(0.3 * math.log((rs + 1e-9) / rtgt)))) else: # damping feedback (no jacreg) if res > 1e-3: blk.c = min(cfg.ccap, blk.c * 1.3) elif res < 2e-4: blk.c = max(0.5, blk.c * 0.97) else: grads = bptt_step(blk, idx, y, cfg.T1, cfg.eps, jr if cfg.jacreg > 0 else 0.0) with torch.no_grad(): # is BPTT's optimum contractive? (free-phase residual) xinb = blk.embed(idx).detach() zsb = relax(blk, xinb.clone(), xinb, cfg.T1, cfg.eps) res = (relax(blk, zsb, xinb, 1, cfg.eps) - zsb).norm().item() / (zsb.norm().item() + 1e-9) if cfg.jacreg > 0: # same residual-driven λ controller as ep mode flo = cfg.jacreg if cfg.jr_floor is None else cfg.jr_floor if cfg.jr_lrcouple: flo *= sched.get_last_lr()[0] / cfg.lr rs = res if rs is None else cfg.res_ema * rs + (1 - cfg.res_ema) * res jr = min(cfg.jr_max, max(flo, jr * math.exp(0.3 * math.log((rs + 1e-9) / cfg.res_target)))) badct = badct + 1 if (cfg.abort_res > 0 and res > cfg.abort_res) else 0 if badct >= 100: # containment lost and not recovering: stop, keep best ckpt print(f" ABORT at step {step}: res>{cfg.abort_res} for 100 consecutive steps (best {best:.4f})", flush=True) break ok = all((g is None) or torch.isfinite(g).all() for g in grads.values()) if not ok: print(f" step {step}: non-finite, skip", flush=True); continue opt.zero_grad(set_to_none=True) for p in blk.allp: p.grad = grads.get(id(p)) torch.nn.utils.clip_grad_norm_(blk.allp, 5.0) opt.step() spec_stats = None with torch.no_grad(): if cfg.specnorm > 0: spec_stats = project_specnorm_(spec_items, spec_cache, cfg.specnorm) else: for p in blk.capw: pn = p.norm(); cap = blk.caps[id(p)] if pn > cap: p.mul_(cap / pn) sched.step() if spec_stats is not None and (step == start_step or step % cfg.log == 0): sb, sa, names = spec_stats cname = ",".join(names) if names else "none" print(f" specnorm step {step}: max sigma before={sb:.4f} after={sa:.4f} bound={cfg.specnorm:.4f} clamped={cname}", flush=True) if pema is not None: with torch.no_grad(): for s, p in zip(pema, blk.allp): s.mul_(cfg.pema).add_(p.detach(), alpha=1 - cfg.pema) if cfg.save_every and step % cfg.save_every == 0 and step % cfg.log != 0: save_state(step) # mid-interval state save (Colab: cap worst-case loss) if cfg.diag_cos and step % cfg.diag_cos == 0: # #1: gradient-alignment trajectory (scratch vs warm) from diag_cos import cos_ep_bptt _c, _r = cos_ep_bptt(blk, idx, y, cfg.T1, cfg.T2, cfg.eps, cfg.beta, cfg.holo, cfg.hr, cfg.t1max, cfg.res_est, cfg.t2sel) print(f" [diag] step {step}: cos(EP,BPTT)={_c:.4f} res={_r:.1e}", flush=True) if step % cfg.log == 0: prevb = best sw = hw_swap() if hw_on else None v = evaluate(blk, cfg.T1, cfg.eps) if sw is not None: hw_restore(sw) best = min(best, v) etag = "" if pema is not None: with torch.no_grad(): raw = [p.detach().clone() for p in blk.allp] for p, s in zip(blk.allp, pema): p.copy_(s) ve = evaluate(blk, cfg.T1, cfg.eps) for p, r in zip(blk.allp, raw): p.copy_(r) best = min(best, ve); etag = f" ema={ve:.4f}" if cfg.ckpt and best < prevb: torch.save({'allp': [p.detach().cpu() for p in blk.allp], 'pema': [s.cpu() for s in pema] if pema is not None else None, 'step': step, 'best': best}, cfg.ckpt) print(f"step {step:4d}/{cfg.steps} | val CE {v:.4f}{etag} (best {best:.4f}) | jr={jr:.1f} res={res:.1e} | {step/(time.time()-t0):.2f} it/s", flush=True) save_state(step) # full-state checkpoint each log interval (Colab resume) print(f"[{cfg.mode}] DONE best val CE {best:.4f} (random baseline ln({vocab})={math.log(vocab):.3f})", flush=True) out_dir = Path('runs') out_dir.mkdir(exist_ok=True) json.dump({'mode': cfg.mode, 'best_val_ce': best}, open(out_dir / f'H2_{cfg.mode}.json', 'w')) if __name__ == '__main__': main()