summaryrefslogtreecommitdiff
path: root/ep_run/lt_ep_train.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/lt_ep_train.py')
-rw-r--r--ep_run/lt_ep_train.py630
1 files changed, 630 insertions, 0 deletions
diff --git a/ep_run/lt_ep_train.py b/ep_run/lt_ep_train.py
new file mode 100644
index 0000000..9974bd8
--- /dev/null
+++ b/ep_run/lt_ep_train.py
@@ -0,0 +1,630 @@
+"""option 2 / H2: train a full EP equilibrium transformer block on Shakespeare char-LM.
+
+One block = token state z relaxed to a fixed point of
+ F(z) = -(z - x_in) (input clamp; x_in = embed(idx))
+ - dE_mem/dz (Hopfield memory E_mem = -sum relu(z Wm)^2 ; CONSERVATIVE = FFN)
+ + s*(causal_attn(z) - c*z) (damped causal attention ; NON-conservative)
+Readout logits = z* Whead ; loss = next-token cross-entropy.
+
+Train modes:
+ ep : free + +/-beta nudged equilibria. Block+embed params via vector-field gradient
+ <a, dF/dtheta(z*)> with the AEP correction (clipped) on the attention part; readout
+ head via its own local gradient dCE/dWhead. NO backprop through the relaxation.
+ bptt : backprop through the unrolled relaxation (exact-gradient reference, same architecture).
+Stabilisation from G: damping c, clipped AEP correction, weight-norm caps, best-val checkpoint.
+"""
+import argparse, math, pickle, time, json, os, numpy as np, torch, torch.nn.functional as F
+from pathlib import Path
+
# Compute device: CUDA when torch reports it available, otherwise CPU.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
# Default dataset directory; main() rebinds both DD and vocab from --data.
DD = Path('/home/yurenh2/ept/ep_run/data/tinystories_bpe')
# Vocabulary size comes from the dataset's pickled metadata. The context
# manager closes the file handle instead of leaking it until GC.
with open(DD / 'meta.pkl', 'rb') as _meta_file:
    vocab = pickle.load(_meta_file)['vocab_size']
+
+
def get_batch(split, B, T):
    """Sample B random (input, target) windows of length T from a token file.

    Reads ``train.bin`` when ``split == 'train'`` and ``val.bin`` otherwise,
    both under the module-level ``DD`` directory, as a uint16 memmap.
    Targets are the inputs shifted one position to the right. Both tensors
    are int64 and are moved to the module-level device ``dev``.
    """
    fname = 'train.bin' if split == 'train' else 'val.bin'
    data = np.memmap(DD / fname, dtype=np.uint16, mode='r')
    starts = torch.randint(len(data) - T - 1, (B,))

    def window(lo):
        # One contiguous slice of T tokens, widened to int64 for indexing.
        return torch.from_numpy(data[lo:lo + T].astype(np.int64))

    inputs = torch.stack([window(i) for i in starts])
    targets = torch.stack([window(i + 1) for i in starts])
    return inputs.to(dev), targets.to(dev)
+
+
class EQBlock:
    """One equilibrium transformer block with its parameters held as raw leaf tensors.

    The block defines a vector field ``force(z, xin)`` whose fixed point is the
    block output. ``attn_mode`` selects the form of the field:

    * ``'real'``   - gradient of (input clamp + memory energy) plus damped causal attention
    * ``'energy'`` - attention folded into the energy via a log-sum-exp term
    * ``'mono'``   - adds a ``-(m I + P^T P)`` term and a skew ``Q - Q^T`` term
    * ``'thick'``  - LayerNorm + attention + untied 4x GELU FFN, with ``-c*z``

    Parameters are created on the module-level device ``dev``; the embedding
    table and readout head are sized by the module-level ``vocab``.
    """

    def __init__(self, C, H, Mm, T, s=1.0, c=1.0, attn_mode='real', gamma=0.25):
        """Create all parameters.

        Args:
            C: model width.
            H: number of attention heads; must divide ``C``.
            Mm: number of columns of the memory matrix ``Wm``.
            T: sequence length (fixes the positional table and causal mask).
            s: attention scale used by the non-'thick' modes.
            c: damping coefficient.
            attn_mode: one of 'real', 'energy', 'mono', 'thick'.
            gamma: inverse temperature of the LSE attention energy.

        Raises:
            ValueError: if ``C`` is not a multiple of ``H``.
        """
        # Fail fast: C // H would silently truncate and the first attention
        # call would die inside .view() with an opaque shape error.
        if C % H != 0:
            raise ValueError(f"C={C} must be divisible by the number of heads H={H}")
        self.C, self.H, self.dh, self.s, self.c, self.T = C, H, C // H, s, c, T
        self.attn_mode, self.gamma = attn_mode, gamma
        self.fnoise = 0.0  # optics model: mult. noise per force eval

        def g(*sh, sc):
            # Gaussian-initialised trainable leaf with standard deviation `sc`.
            return (torch.randn(*sh, device=dev) * sc).requires_grad_(True)

        def z1(n, v):
            # Constant-initialised trainable vector of length n.
            return torch.full((n,), float(v), device=dev).requires_grad_(True)

        # NOTE: creation order below fixes the RNG stream for a given seed.
        self.tok = g(vocab, C, sc=0.02)
        self.pos = g(T, C, sc=0.02)
        self.WQ = g(C, C, sc=1 / math.sqrt(C))
        self.WK = g(C, C, sc=1 / math.sqrt(C))
        self.WV = g(C, C, sc=1 / math.sqrt(C))
        self.WO = g(C, C, sc=1 / math.sqrt(C))
        self.Wm = g(C, Mm, sc=0.3 / math.sqrt(C))
        self.Wh = g(C, vocab, sc=1 / math.sqrt(C))
        self.P = g(C, C, sc=1 / math.sqrt(C))  # monDEQ monotone op
        self.Q = g(C, C, sc=1 / math.sqrt(C))
        self.mono_m = 1.0
        # LN affine
        self.ln1g = z1(C, 1)
        self.ln1b = z1(C, 0)
        self.ln2g = z1(C, 1)
        self.ln2b = z1(C, 0)
        # untied 4x FFN
        self.fc = g(C, 4 * C, sc=1 / math.sqrt(C))
        self.fcb = z1(4 * C, 0)
        self.pj = g(4 * C, C, sc=1 / math.sqrt(4 * C))
        self.pjb = z1(C, 0)
        # Lower-triangular boolean mask: position t may attend to positions <= t.
        self.cmask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=dev))
        # Parameters that appear in the force (everything except the readout head).
        self.block = [self.tok, self.pos, self.WQ, self.WK, self.WV, self.WO, self.Wm, self.P, self.Q,
                      self.ln1g, self.ln1b, self.ln2g, self.ln2b, self.fc, self.fcb, self.pj, self.pjb]
        self.allp = self.block + [self.Wh]
        # Matrices subject to a Frobenius-norm cap, keyed by id(); default cap = 3x init norm.
        self.capw = (self.WQ, self.WK, self.WV, self.WO, self.Wm, self.Wh, self.fc, self.pj)
        self.caps = {id(w): w.detach().norm().item() * 3.0 for w in self.capw}

    def embed(self, idx):
        """Return token embedding plus positional embedding for indices ``idx``."""
        return self.tok[idx] + self.pos[None]

    def attn(self, z):
        """Multi-head causal softmax attention over ``z`` viewed as (B, T, C)."""
        B = z.size(0)

        def heads(W):
            # Project, then split the channel axis into (H, dh) and move heads forward.
            return (z @ W).view(B, self.T, self.H, self.dh).transpose(1, 2)

        q, k, v = heads(self.WQ), heads(self.WK), heads(self.WV)
        if getattr(self, 'qknorm', False):  # Qwen3-style q/k RMSNorm: bounds logits, tames J
            q = q * torch.rsqrt(q.pow(2).mean(-1, keepdim=True) + 1e-6)
            k = k * torch.rsqrt(k.pow(2).mean(-1, keepdim=True) + 1e-6)
        a = (q @ k.transpose(-2, -1)) / math.sqrt(self.dh)
        a = torch.softmax(a.masked_fill(~self.cmask, float('-inf')), -1)
        return (a @ v).transpose(1, 2).reshape(B, self.T, self.C) @ self.WO

    def attn_energy(self, z):  # conservative LSE attention energy (tied value)
        """Scalar energy ``-(1/gamma) * sum logsumexp(gamma * causal q.k logits)``."""
        B = z.size(0)
        q = (z @ self.WQ).view(B, self.T, self.H, self.dh).transpose(1, 2)
        k = (z @ self.WK).view(B, self.T, self.H, self.dh).transpose(1, 2)
        a = (q @ k.transpose(-2, -1)) / math.sqrt(self.dh)
        a = a.masked_fill(~self.cmask, float('-inf'))
        return -(1.0 / self.gamma) * torch.logsumexp(self.gamma * a, dim=-1).sum()

    def Emem(self, z):
        """Memory energy ``-sum relu(z @ Wm)^2``."""
        return -(F.relu(z @ self.Wm) ** 2).sum()

    def tforce(self, z, xin):  # pure thick force (no grad machinery) -> torch.compile
        """'thick' force without noise or autograd bookkeeping."""
        h1 = F.layer_norm(z, (self.C,), self.ln1g, self.ln1b)
        h2 = F.layer_norm(z, (self.C,), self.ln2g, self.ln2b)
        ff = F.gelu(h2 @ self.fc + self.fcb, approximate='tanh') @ self.pj + self.pjb
        return -(z - xin) + self.attn(h1) + ff - self.c * z

    def _noisy(self, t):  # optics model: per-pass multiplicative noise
        """Multiply ``t`` by ``1 + fnoise * N(0, 1)`` when ``fnoise > 0``; else return it unchanged."""
        if self.fnoise > 0:
            return t * (1 + self.fnoise * torch.randn_like(t))
        return t

    def nc_force(self, z):  # non-conservative part of the force (for AEP/jacreg)
        """Attention (+ FFN in 'thick' mode) part of the force, with noise applied."""
        if self.attn_mode == 'thick':
            h1 = F.layer_norm(z, (self.C,), self.ln1g, self.ln1b)
            h2 = F.layer_norm(z, (self.C,), self.ln2g, self.ln2b)
            return self._noisy(self.attn(h1) + (F.gelu(h2 @ self.fc + self.fcb, approximate='tanh') @ self.pj + self.pjb))
        return self._noisy(self.s * self.attn(z))

    def force(self, z, xin, cg=False):
        """Evaluate the vector field at state ``z`` with input clamp ``xin``.

        Args:
            z: current state.
            xin: clamped input (the embedding).
            cg: when True, keep the autograd graph (``create_graph``) and, if
                ``z`` already requires grad, differentiate through ``z`` itself
                instead of a detached copy.

        Returns:
            The force tensor, same shape as ``z``.
        """
        # enable_grad so the energy gradient works even when called under no_grad.
        with torch.enable_grad():
            zr = z if (cg and z.requires_grad) else z.detach().requires_grad_(True)
            if self.attn_mode == 'thick':  # DEQ-transformer block: LN + untied 4x FFN + residual
                h1 = F.layer_norm(zr, (self.C,), self.ln1g, self.ln1b)
                h2 = F.layer_norm(zr, (self.C,), self.ln2g, self.ln2b)
                ff = F.gelu(h2 @ self.fc + self.fcb, approximate='tanh') @ self.pj + self.pjb
                return -(zr - xin) + self._noisy(self.attn(h1) + ff) - self.c * zr  # -c*z: initial contraction
            if self.attn_mode == 'mono':  # monDEQ: structurally-monotone contraction
                gm, = torch.autograd.grad(self.Emem(zr), zr, create_graph=cg)
                PtP = self.P.t() @ self.P  # PSD -> sym(J) = -(mI+PtP) < 0 (guaranteed)
                f = (-(self.mono_m * zr + zr @ PtP) + zr @ (self.Q - self.Q.t()).t()
                     + xin - gm + self.s * self.attn(zr))
                return f
            E = 0.5 * ((zr - xin) ** 2).sum() + self.Emem(zr)
            if self.attn_mode == 'energy':  # attention folded into the energy (conservative)
                E = E + self.attn_energy(zr) + 0.5 * self.c * (zr ** 2).sum()  # confinement -> bounded below
            gz, = torch.autograd.grad(E, zr, create_graph=cg)
            f = -gz
            if self.attn_mode == 'real':  # non-conservative attention + damping
                f = f + self.s * (self.attn(zr) - self.c * zr)
            return f
+
+
def relax(blk, z, xin, steps, eps):
    """Run `steps` explicit-Euler updates ``z <- z + eps * force(z, xin)``.

    If the block carries a compiled step function in ``blk._cstep`` and force
    noise is off (``blk.fnoise == 0.0``), that function is applied instead; it
    receives only ``(z, xin)``, so ``eps`` is not used on that path.
    No autograd graph is recorded and the result is detached.
    """
    compiled = getattr(blk, '_cstep', None)
    fast_path = compiled is not None and blk.fnoise == 0.0
    with torch.no_grad():
        if fast_path:
            # compiled pure-thick free-phase fast path
            for _ in range(steps):
                z = compiled(z, xin)
        else:
            # blk.force re-enables grad internally where it needs it.
            for _ in range(steps):
                delta = blk.force(z, xin).detach()
                z = z + eps * delta
    return z.detach()
+
+
def ce(blk, z, y):
    """Mean cross-entropy of the readout ``z @ blk.Wh`` against targets ``y``.

    The number of classes is taken from the logits themselves rather than from
    the module-level ``vocab``: that global is rebound by main(), so it can
    disagree with the head of the block actually passed in.
    """
    logits = z @ blk.Wh
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
+
+
def _jacreg_grads(blk, z, jacreg):
    """Gradients w.r.t. ``blk.block`` of the soft Jacobian-norm penalty at ``z``.

    Uses one random probe ``er`` and a forward-mode product ``J_nc @ er`` to
    form ``jacreg * |J er|^2 / |er|^2`` (Hutchinson est of ||J_nc||_F^2).
    Returns a tuple aligned with ``blk.block``; entries are None for unused
    parameters.
    """
    er = torch.randn_like(z)
    with torch.enable_grad():
        Jv = torch.autograd.functional.jvp(blk.nc_force, z.detach(), er, create_graph=True)[1]
        R = jacreg * (Jv ** 2).sum() / (er ** 2).sum()
        return torch.autograd.grad(R, blk.block, allow_unused=True)


def _add_grad(grads, key, g):
    """Accumulate ``g`` into ``grads[key]``, treating a missing/None entry as zero."""
    prev = grads.get(key)
    grads[key] = g if prev is None else prev + g


def ep_step(blk, idx, y, T1, T2, eps, beta, jacreg=0.0, holo=0, hr=0.02, t1max=0, res_est=1e-4, t2sel=0,
            corr_every=1, res_gate=0.0, resreg=0.0, eigreg=0.0, eig_margin=1.0):
    """Compute one equilibrium-propagation gradient estimate (no optimizer step).

    Phases: (1) free relaxation for ``T1`` steps, optionally extended up to
    ``t1max`` until the relative one-step residual drops below ``res_est``;
    (2) an adjoint-like state ``a`` from +/-beta nudged relaxations of ``T2``
    steps, or from a ``holo_ep`` routine when ``holo > 0``; (3) block/embedding
    gradients as ``d<a, F>/dtheta`` at the free state, plus a local
    cross-entropy gradient for the readout head; (4) optional regulariser
    gradients (``jacreg``, ``resreg``, ``eigreg``) added on top.

    Args:
        blk: the equilibrium block.
        idx, y: input token indices and next-token targets.
        T1, T2: free-phase and nudged-phase step counts.
        eps: Euler step size.
        beta: nudge strength for the two-sided finite difference.
        jacreg: weight of the Jacobian-norm penalty (0 disables).
        holo, hr, t2sel, corr_every: forwarded to the ``holo_ep`` routines.
        t1max, res_est: adaptive free-phase extension cap and stop residual.
        res_gate: if > 0 and the (possibly refined) residual exceeds it, skip
            the task gradient and return only the ``jacreg`` gradient.
        resreg: weight of the T1-residual penalty (active when the T1 residual
            exceeds 7e-4).
        eigreg, eig_margin: forwarded to ``eig_control.eig_penalty``.

    Returns:
        ``(grads, res)``: ``grads`` maps ``id(param)`` to a gradient tensor or
        None; ``res`` is the relative one-step residual measured at ``T1``.
    """
    def rel_residual(state):
        # Relative size of one further relaxation step taken from `state`.
        return (relax(blk, state, xin0, 1, eps) - state).norm().item() / (state.norm().item() + 1e-9)

    xin0 = blk.embed(idx).detach()
    zs = relax(blk, xin0.clone(), xin0, T1, eps)
    res = rel_residual(zs)
    res_used = res
    zT, resT1 = zs, res  # the T1 free-phase state (what eval/BPTT use), BEFORE refinement
    if t1max > T1:  # estimator refinement: relax further until tight
        rnow, t = res, T1  # (controller signal `res` stays measured at T1)
        while t < t1max and rnow > res_est:
            zs = relax(blk, zs, xin0, 50, eps)
            t += 50
            rnow = rel_residual(zs)
        res_used = rnow

    if res_gate > 0 and res_used > res_gate:
        # validity gate: off-equilibrium the EP update is undefined -> apply ONLY
        # the homeostat (jacreg) and skip the nudge entirely (fast recovery steps)
        grads = {}
        if jacreg > 0:
            gr = _jacreg_grads(blk, zs, jacreg)
            grads = {id(p): g for p, g in zip(blk.block, gr) if g is not None}
        return grads, res

    def nudge(sign):
        # Relax from the free state under force - sign*beta*dCE/dz for T2 steps.
        z = zs.clone()
        for _ in range(T2):
            with torch.enable_grad():
                zz = z.detach().requires_grad_(True)
                g, = torch.autograd.grad(ce(blk, zz, y), zz)
            g = g.clamp(-2.0, 2.0)  # clip nudge so it can't blow up relax
            with torch.no_grad():
                f = blk.force(z, xin0).detach() - sign * beta * g
                if blk.attn_mode in ('real', 'thick'):  # AEP correction (full non-conservative part)
                    v = (z - zs).detach()
                    Jv = torch.autograd.functional.jvp(blk.nc_force, zs, v)[1]
                    JTv = torch.autograd.functional.vjp(blk.nc_force, zs, v)[1]
                    corr = Jv - JTv
                    cn, fn = corr.norm(), f.norm() + 1e-8
                    # Correction is rescaled so its norm never exceeds the force norm.
                    f = f - (corr * (fn / cn) if cn > fn else corr)
                z = z + eps * f
        return z.detach()

    if holo == 2 and t2sel > 0:  # adaptive-T2, phase-batched fast path (validated ==)
        from holo_ep import holo_a_select2, holo_a_track
        K = max(1, getattr(blk, 'navg', 1))  # restart-averaging: noise / sqrt(K)
        acc = None
        for _ in range(K):
            if getattr(blk, 'track', False):  # common-mode-tracking AEP (loose-tolerant)
                ai, _ = holo_a_track(blk, zs, xin0, y, hr, t2sel, eps)
            else:
                ai, _ = holo_a_select2(blk, zs, xin0, y, hr, t2sel, eps, li=getattr(blk, 'li_avg', 0))
            acc = ai if acc is None else acc + ai
        a = acc / K
    elif holo > 0 and t2sel > 0:  # adaptive-T2 via hindsight snapshot selection
        from holo_ep import holo_a_select
        a, _ = holo_a_select(blk, zs, xin0, y, holo, hr, t2sel, eps, corr_every=corr_every)
    elif holo > 0:  # holomorphic nudge (clamp-free, Cauchy readout)
        from holo_ep import holo_a
        a, _ = holo_a(blk, zs, xin0, y, holo, hr, T2, eps)
    else:
        zp, zm = nudge(+1), nudge(-1)
        a = ((zm - zp) / (2 * beta)).detach()

    grads = {}
    with torch.enable_grad():
        xin = blk.embed(idx)  # live (for tok/pos grad through clamp)
        f = blk.force(zs.detach(), xin, cg=True)
        gblk = torch.autograd.grad((a * f).sum(), blk.block, allow_unused=True)
    for p, gv in zip(blk.block, gblk):
        grads[id(p)] = gv  # may be None for parameters the active mode does not use
    with torch.enable_grad():
        gh, = torch.autograd.grad(ce(blk, zs.detach(), y), blk.Wh)  # readout local gradient
    grads[id(blk.Wh)] = gh

    if jacreg > 0:  # soft Lyapunov: penalize non-conservative Jacobian norm
        for p, g in zip(blk.block, _jacreg_grads(blk, zs, jacreg)):
            if g is not None:
                _add_grad(grads, id(p), g)

    if resreg > 0 and resT1 > 7e-4:  # defend z_T1 (BPTT gets this implicitly; EP at z* doesn't)
        with torch.enable_grad():
            Fz = blk.tforce(zT, xin0)  # deterministic thick force at z_T1 (params live, zT/xin0 detached)
            Rr = (eps * Fz).pow(2).sum() / (zT.pow(2).sum() + 1e-9)  # ~ (T1 residual)^2
            grr = torch.autograd.grad(Rr, blk.block, allow_unused=True)
        # Linear in the T1 residual, saturating at `resreg` once the residual reaches 2e-2.
        ratio = resreg * min(1.0, resT1 / 2e-2)
        gtask = math.sqrt(sum(float((grads[id(p)] ** 2).sum())
                              for p in blk.block if grads.get(id(p)) is not None) + 1e-20)
        gres = math.sqrt(sum(float((g ** 2).sum()) for g in grr if g is not None) + 1e-20)
        lam = ratio * gtask / gres  # scale penalty to `ratio` of the task-grad norm
        for p, g in zip(blk.block, grr):
            if g is not None:
                _add_grad(grads, id(p), lam * g)

    if eigreg > 0:  # #2: leading-abscissa control (surgical, one-sided; alt to jacreg)
        from eig_control import eig_penalty
        ge, _om = eig_penalty(blk, zs, eigreg, eig_margin, blk.__dict__.setdefault('_eigcache', {}))
        for pid, g in ge.items():
            _add_grad(grads, pid, g)
    return grads, res
+
+
class Lion(torch.optim.Optimizer):
    """Chen et al. 2023. Analog-hardware rationale: sign updates = fixed-amplitude pulses
    (kills device write-nonlinearity), magnitude-noise immune, one momentum cap per weight.

    Per parameter: ``u = sign(b1*m + (1-b1)*grad)``; decoupled weight decay
    ``p *= 1 - lr*wd``; ``p -= lr*u``; then ``m = b2*m + (1-b2)*grad``.
    With ``lars=True`` the step size of each tensor is multiplied by
    ``|p| / |u|``.
    """

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0, lars=False):
        super().__init__(params, dict(lr=lr, betas=betas, weight_decay=weight_decay, lars=lars))

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one update to every parameter that has a gradient.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss, as in the standard ``Optimizer.step`` contract.

        Returns:
            The closure's loss, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd enabled.
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            b1, b2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                st = self.state.setdefault(p, {})
                if 'm' not in st:
                    st['m'] = torch.zeros_like(p)
                u = (b1 * st['m'] + (1 - b1) * p.grad).sign()
                lr = group['lr']
                if group['lars']:  # per-tensor trust ratio: one gain line per array
                    lr = lr * (p.norm() / (u.norm() + 1e-12)).item()
                p.mul_(1 - lr * group['weight_decay'])
                p.add_(u, alpha=-lr)
                # Momentum is refreshed after the update, with the slower rate b2.
                st['m'].mul_(b2).add_(p.grad, alpha=1 - b2)
        return loss
+
+
def bptt_step(blk, idx, y, T1, eps, jacreg=0.0):
    """Exact-gradient reference: backprop through ``T1`` unrolled Euler steps.

    Returns a dict mapping ``id(param)`` to the gradient of the cross-entropy
    at the final state (None for unused parameters). When ``jacreg > 0`` the
    gradient of the Jacobian-norm penalty, evaluated at the detached final
    state, is added for the block parameters.
    """
    xin = blk.embed(idx)
    # init = embedding (keeps graph to emb); the zero-weighted detached leaf
    # contributes no value, only the live `xin` term carries the graph.
    z = xin.detach().requires_grad_(True) * 0 + xin
    for _ in range(T1):
        z = z + eps * blk.force(z, xin, cg=True)

    task_grads = torch.autograd.grad(ce(blk, z, y), blk.allp, allow_unused=True)
    out = dict(zip(map(id, blk.allp), task_grads))
    if not jacreg > 0:
        return out

    # same soft Lyapunov penalty as ep mode (fair control)
    probe = torch.randn_like(z)
    with torch.enable_grad():
        Jv = torch.autograd.functional.jvp(blk.nc_force, z.detach(), probe, create_graph=True)[1]
        penalty = jacreg * (Jv ** 2).sum() / (probe ** 2).sum()
        pen_grads = torch.autograd.grad(penalty, blk.block, allow_unused=True)
    for param, pg in zip(blk.block, pen_grads):
        if pg is None:
            continue
        key = id(param)
        current = out.get(key)
        out[key] = pg if current is None else current + pg
    return out
+
+
@torch.no_grad()
def evaluate(blk, T1, eps, nb=8, B=32):
    """Average validation cross-entropy over ``nb`` random batches of size ``B``.

    Each batch is embedded, relaxed for ``T1`` steps from the embedding, and
    scored with the readout head.
    """
    def one_batch():
        idx, y = get_batch('val', B, blk.T)
        xin = blk.embed(idx).detach()
        z_star = relax(blk, xin.clone(), xin, T1, eps)
        return ce(blk, z_star, y).item()

    running = 0.0
    for _ in range(nb):
        running += one_batch()
    return running / nb
+
+
def specnorm_weight_items(blk):
    """List the ``(name, tensor)`` weight pairs subject to spectral projection.

    Uses the first fused QKV attribute found among ``WQKV``, ``Wqkv``,
    ``W_qkv``, ``qkv``; if none exists, falls back to whichever of ``WQ``,
    ``WK``, ``WV`` are present. Then appends ``WO``, ``fc``, ``pj`` when
    present. Tensors with fewer than two dimensions are dropped.
    """
    fused = next((n for n in ('WQKV', 'Wqkv', 'W_qkv', 'qkv') if hasattr(blk, n)), None)
    if fused is not None:
        candidates = (fused,)
    else:
        candidates = tuple(n for n in ('WQ', 'WK', 'WV') if hasattr(blk, n))
    candidates += tuple(n for n in ('WO', 'fc', 'pj') if hasattr(blk, n))

    selected = []
    for name in candidates:
        weight = getattr(blk, name)
        if weight.ndim >= 2:
            selected.append((name, weight))
    return selected
+
+
@torch.no_grad()
def power_sigma(W, u, iters=2):
    """Estimate the largest singular value of ``W`` by power iteration.

    ``W`` is flattened to 2-D as ``(W.shape[0], -1)``.

    Args:
        W: weight tensor.
        u: warm-start left vector from a previous call, or None. It is
            replaced by a fresh random unit vector if its shape, device or
            dtype does not match.
        iters: number of power-iteration rounds; values below 1 are treated
            as 1 so the estimate is always defined.

    Returns:
        ``(sigma, u)``: the estimate as a 0-d tensor and the updated left
        vector to cache for the next call.
    """
    M = W.detach().reshape(W.shape[0], -1)
    if u is None or u.shape != (M.shape[0],) or u.device != W.device or u.dtype != W.dtype:
        u = F.normalize(torch.randn(M.shape[0], device=W.device, dtype=W.dtype), dim=0, eps=1e-12)
    # At least one round: with iters <= 0 the right vector `v` was never
    # assigned and the sigma computation below raised UnboundLocalError.
    for _ in range(max(1, iters)):
        v = F.normalize(M.t().mv(u), dim=0, eps=1e-12)
        u = F.normalize(M.mv(v), dim=0, eps=1e-12)
    sigma = u.dot(M.mv(v)).abs()
    return sigma, u.detach()
+
+
@torch.no_grad()
def project_specnorm_(items, cache, bound):
    """Rescale, in place, every weight whose estimated spectral norm exceeds ``bound``.

    Args:
        items: iterable of ``(name, tensor)`` pairs.
        cache: dict keyed by ``id(tensor)`` holding the power-iteration vector
            from the previous call; updated in place.
        bound: maximum allowed estimated spectral norm.

    Returns:
        ``(max_before, max_after, clamped)``: the largest estimate before
        projection, the largest value after it (a clamped weight is reported
        as exactly ``bound``), and the names of the weights that were scaled.
    """
    worst_before = 0.0
    worst_after = 0.0
    scaled_names = []
    for name, W in items:
        key = id(W)
        sigma, cache[key] = power_sigma(W, cache.get(key))
        s_before = float(sigma)
        if s_before > bound:
            W.mul_(bound / (s_before + 1e-12))
            s_after = bound
            scaled_names.append(name)
        else:
            s_after = s_before
        worst_before = max(worst_before, s_before)
        worst_after = max(worst_after, s_after)
    return worst_before, worst_after, scaled_names
+
+
def main():
    """Command-line entry point: build the block, train it in 'ep' or 'bptt' mode, log and checkpoint.

    Side effects: rebinds the module globals ``DD`` and ``vocab`` from
    ``--data``; optionally writes ``--ckpt`` (best weights) and ``--state``
    (full resumable state); always writes ``runs/H2_<mode>.json`` at the end
    of a training run.
    """
    global DD, vocab
    ap = argparse.ArgumentParser()
    ap.add_argument('--mode', choices=['ep', 'bptt'], required=True)
    ap.add_argument('--steps', type=int, default=2000)
    ap.add_argument('--B', type=int, default=32)
    ap.add_argument('--T', type=int, default=64)
    ap.add_argument('--C', type=int, default=128)
    ap.add_argument('--H', type=int, default=4)
    ap.add_argument('--Mm', type=int, default=256)
    ap.add_argument('--T1', type=int, default=80)
    ap.add_argument('--T2', type=int, default=15)
    ap.add_argument('--eps', type=float, default=0.1)
    ap.add_argument('--beta', type=float, default=0.02)
    ap.add_argument('--lr', type=float, default=1e-3)
    ap.add_argument('--log', type=int, default=100)
    ap.add_argument('--warmup', type=int, default=0)  # linear lr warmup steps (big-model stability)
    ap.add_argument('--state', type=str, default='')  # periodic FULL-state path (weights+opt+sched+step)
    ap.add_argument('--resume', action='store_true')  # resume from --state if it exists (Colab timeouts)
    ap.add_argument('--save_every', type=int, default=0)  # full-state save cadence (0=every --log); set small on Colab
    ap.add_argument('--c', type=float, default=1.0)
    ap.add_argument('--capx', type=float, default=3.0)
    ap.add_argument('--attn_mode', choices=['real', 'energy', 'mono', 'thick'], default='real')
    ap.add_argument('--ccap', type=float, default=8.0)
    ap.add_argument('--specnorm', type=float, default=0.0)  # hard post-step spectral projection bound; 0=off
    ap.add_argument('--jacreg', type=float, default=0.0)  # Bai 2021: soft Jacobian-norm (Lyapunov) penalty
    ap.add_argument('--jr_max', type=float, default=16.0)  # adaptive jacreg ceiling (ramps up vs residual)
    ap.add_argument('--res_target', type=float, default=5e-3)  # continuous controller target residual
    ap.add_argument('--jr_floor', type=float, default=None)  # controller floor; default=--jacreg (legacy: never off)
    ap.add_argument('--res_ema', type=float, default=0.0)  # EMA on residual signal (0=off); kills controller thrash
    ap.add_argument('--jr_lrcouple', action='store_true')  # anneal the λ floor with the lr schedule (late-drift fix)
    ap.add_argument('--holo', type=int, default=0)  # holomorphic EP: N circle points (0=off)
    ap.add_argument('--hr', type=float, default=0.02)  # holomorphic nudge radius |beta|
    ap.add_argument('--pema', type=float, default=0.0)  # parameter EMA decay (0=off); tames late wander
    ap.add_argument('--t1max', type=int, default=0)  # adaptive free phase: extend up to t1max...
    ap.add_argument('--res_est', type=float, default=1e-4)  # ...until this residual (estimator validity)
    ap.add_argument('--t2sel', type=int, default=0)  # adaptive T2: snapshot-selection cap (0=off)
    ap.add_argument('--seed', type=int, default=0)
    ap.add_argument('--data', type=str, default='/tmp/lt_ep/data/shakespeare_char')
    ap.add_argument('--ckpt', type=str, default='')  # save best weights (raw+ema) here
    ap.add_argument('--corr_every', type=int, default=1)  # recompute AEP corr every k nudge steps
    ap.add_argument('--tf32', action='store_true')  # tf32 matmuls (check res floor first!)
    ap.add_argument('--abort_res', type=float, default=0.1)  # kill switch: res above this 100 steps straight
    ap.add_argument('--res_gate', type=float, default=0.0)  # validity gate: skip task grads above this res
    ap.add_argument('--wsd', type=float, default=0.0)  # WSD: hold peak lr, cosine-decay only the last wsd fraction
    ap.add_argument('--resreg', type=float, default=0.0)  # T1-residual penalty: defend z_T1 (cap ratio vs task grad); run res_gate=0
    ap.add_argument('--eigreg', type=float, default=0.0)  # #2: leading-abscissa (numerical-abscissa) control — surgical alt to jacreg
    ap.add_argument('--eig_margin', type=float, default=1.0)  # penalize omega(J_nc) above this (free-phase Hopf boundary ~ 1+c)
    ap.add_argument('--diag_cos', type=int, default=0)  # #1: every N steps, log cos(EP grad, exact BPTT grad) + res
    ap.add_argument('--fingerprint', action='store_true')  # load --init_ckpt, print (res,cos,abscissa,val) fingerprint, exit
    ap.add_argument('--opt', choices=['adamw', 'lion', 'lionlars', 'sgdm', 'sgdsai'], default='adamw')
    ap.add_argument('--wd', type=float, default=1e-4)
    ap.add_argument('--fnoise', type=float, default=0.0)  # optics/device twin: mult. noise per force eval
    ap.add_argument('--wq_bits', type=int, default=0)  # weights projected to N bits each step (0=off)
    ap.add_argument('--wmis', type=float, default=0.0)  # static per-device mismatch sigma (0=off)
    ap.add_argument('--li_avg', type=int, default=0)  # lock-in integration window (0=snapshot mode)
    ap.add_argument('--navg', type=int, default=1)  # restart-averaged contrast estimates per update
    ap.add_argument('--track', action='store_true')  # common-mode-tracking AEP correction
    ap.add_argument('--rt_final', type=float, default=0.0)  # anneal res_target to this (0=off), 25%-75% of run
    ap.add_argument('--nudge_brake', type=float, default=0.0)  # kappa: anchor spring during nudge (Tikhonov adjoint)
    ap.add_argument('--init_ckpt', type=str, default='')  # warm-start weights from a saved ckpt
    ap.add_argument('--qknorm', action='store_true')  # Qwen3-style q/k RMSNorm in attention
    ap.add_argument('--compile', action='store_true')  # torch.compile the free-phase relaxation (thick)
    ap.add_argument('--resinit', type=float, default=1.0)  # scale WO,pj at init (ReZero/Fixup: small=near-identity block)
    cfg = ap.parse_args()
    if cfg.specnorm < 0:
        raise SystemExit("--specnorm must be non-negative")

    # Point the module-level dataset globals at the requested dataset.
    DD = Path(cfg.data)
    with open(DD / 'meta.pkl', 'rb') as meta_file:  # closed promptly instead of leaked
        vocab = pickle.load(meta_file)['vocab_size']
    if cfg.tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    torch.manual_seed(cfg.seed)
    blk = EQBlock(cfg.C, cfg.H, cfg.Mm, cfg.T, s=1.0, c=cfg.c, attn_mode=cfg.attn_mode)
    for w in blk.capw:
        blk.caps[id(w)] = w.detach().norm().item() * cfg.capx

    # ---- optimizer -------------------------------------------------------
    if cfg.opt in ('lion', 'lionlars'):
        opt = Lion(blk.allp, lr=cfg.lr, weight_decay=cfg.wd, lars=(cfg.opt == 'lionlars'))
    elif cfg.opt == 'sgdm':
        opt = torch.optim.SGD(blk.allp, lr=cfg.lr, momentum=0.95, weight_decay=cfg.wd, nesterov=True)
    elif cfg.opt == 'sgdsai':
        # EP-SaI (SGD-SaI, arXiv:2412.11768, adapted to EP gradients): per-tensor lr from the
        # init-time gradient SNR, frozen — hardware: one gain line per array, set at calibration.
        gs = {id(p): [] for p in blk.allp}
        for _ in range(12):
            idx0, y0 = get_batch('train', cfg.B, cfg.T)
            g0, _ = ep_step(blk, idx0, y0, cfg.T1, cfg.T2, cfg.eps, cfg.beta, 0.0, cfg.holo, cfg.hr,
                            cfg.t1max, cfg.res_est, cfg.t2sel, cfg.corr_every, 0.0)
            for p in blk.allp:
                if g0.get(id(p)) is not None:
                    gs[id(p)].append(g0[id(p)].detach().clone())
        sc = {}
        for p in blk.allp:
            if gs[id(p)]:
                S = torch.stack(gs[id(p)])
                sc[id(p)] = (S.mean(0).norm() / (S.std(0).norm() + 1e-12)).item()
        mx = max(sc.values())
        print("[sgdsai] per-tensor lr scales: " + " ".join(f"{v/mx:.3f}" for v in sc.values()), flush=True)
        opt = torch.optim.SGD([dict(params=[p], lr=cfg.lr * sc.get(id(p), mx) / mx) for p in blk.allp],
                              momentum=0.95, weight_decay=cfg.wd, nesterov=True)
    else:
        opt = torch.optim.AdamW(blk.allp, lr=cfg.lr, weight_decay=cfg.wd)

    # ---- lr schedule -----------------------------------------------------
    if cfg.warmup > 0 or cfg.wsd > 0:  # warmup -> (WSD hold peak) -> cosine decay
        _w = cfg.warmup  # contraction before large steps kick weights out of basin

        def _lrl(s):
            if s < _w:
                return (s + 1) / _w
            _ds = int((1 - cfg.wsd) * cfg.steps) if cfg.wsd > 0 else _w  # WSD decay-start: hold peak lr until here
            if s < _ds:
                return 1.0
            p = (s - _ds) / max(1, cfg.steps - _ds)
            return 0.05 + 0.475 * (1 + math.cos(math.pi * min(1.0, p)))
        sched = torch.optim.lr_scheduler.LambdaLR(opt, _lrl)
    else:
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr * 0.05)

    if cfg.init_ckpt:
        ckl = torch.load(cfg.init_ckpt, map_location=dev)
        with torch.no_grad():
            for p, w in zip(blk.allp, ckl['allp']):
                p.copy_(w.to(dev))
        print(f"[init] warm-start from {cfg.init_ckpt} (step {ckl.get('step')}, best {ckl.get('best', float('nan')):.4f})", flush=True)

    # Startup probe: absolute one-step residual after 200 free-phase steps.
    # The 200-step state is computed once and reused; force noise is still
    # off and no compiled step is installed at this point.
    xin = blk.embed(*get_batch('train', cfg.B, cfg.T)[:1]).detach()
    z200 = relax(blk, xin.clone(), xin, 200, cfg.eps)
    r = (relax(blk, z200, xin, 1, cfg.eps) - z200).norm().item()
    print(f"[{cfg.mode}] residual~{r:.1e} | C={cfg.C} H={cfg.H} Mm={cfg.Mm} T1={cfg.T1} T2={cfg.T2}", flush=True)

    best, t0, jr, rs = 9.9, time.time(), cfg.jacreg, None
    pema = [p.detach().clone() for p in blk.allp] if cfg.pema > 0 else None
    badct = 0
    blk.fnoise = cfg.fnoise
    blk.li_avg = cfg.li_avg
    blk.navg = cfg.navg
    blk.track = cfg.track
    blk.nbrake = cfg.nudge_brake
    blk.qknorm = cfg.qknorm
    if cfg.resinit != 1.0:  # near-identity block at init (contractive) -> stable big-width start
        with torch.no_grad():
            blk.WO.mul_(cfg.resinit)
            blk.pj.mul_(cfg.resinit)
    spec_items = specnorm_weight_items(blk)
    spec_cache = {}
    if cfg.specnorm > 0:
        shapes = " ".join(f"{name}{tuple(W.shape)}" for name, W in spec_items)
        print(f"[specnorm] hard post-step projection: sigma_max <= {cfg.specnorm:g} on {shapes}", flush=True)
    blk._cstep = None
    if cfg.compile and cfg.attn_mode == 'thick':
        _ee = cfg.eps
        blk._cstep = torch.compile(lambda z, xin: z + _ee * blk.tforce(z, xin))
    mis = None
    if cfg.wmis > 0:  # fixed fabrication mismatch (same devices all run)
        gm = torch.Generator().manual_seed(1234)
        mis = [(1 + cfg.wmis * torch.randn(p.shape, generator=gm)).clamp(0.2, 5.0).to(dev) for p in blk.allp]
    hw_on = cfg.wq_bits > 0 or mis is not None

    def hw_swap():  # measure physics on the imperfect device copy;
        saved = [p.detach().clone() for p in blk.allp]  # masters stay fp32 (program-verify model)
        with torch.no_grad():
            for i, p in enumerate(blk.allp):
                w = p * mis[i] if mis is not None else p.detach().clone()
                if cfg.wq_bits > 0:
                    # Symmetric uniform quantisation with a per-tensor step size.
                    d = w.abs().max() / (2 ** (cfg.wq_bits - 1) - 1) + 1e-12
                    w = torch.round(w / d) * d
                p.copy_(w)
        return saved

    def hw_restore(saved):
        with torch.no_grad():
            for p, s in zip(blk.allp, saved):
                p.copy_(s)

    start_step = 1
    if cfg.resume and cfg.state and os.path.exists(cfg.state):  # Colab-timeout resume: full state
        st = torch.load(cfg.state, map_location=dev)
        with torch.no_grad():
            for p, w in zip(blk.allp, st['allp']):
                p.copy_(w.to(dev))
        if pema is not None and st.get('pema') is not None:
            pema = [s.to(dev) for s in st['pema']]
        opt.load_state_dict(st['opt'])
        sched.load_state_dict(st['sched'])
        start_step = st['step'] + 1
        jr = st['jr']
        rs = st['rs']
        best = st['best']
        print(f"[resume] from {cfg.state}: step {start_step}, best {best:.4f}, jr {jr:.1f}", flush=True)

    def save_state(step):
        if not cfg.state:
            return
        torch.save({'allp': [p.detach().cpu() for p in blk.allp],
                    'pema': [s.cpu() for s in pema] if pema is not None else None,
                    'opt': opt.state_dict(), 'sched': sched.state_dict(),
                    'step': step, 'jr': jr, 'rs': rs, 'best': best}, cfg.state + '.tmp')
        os.replace(cfg.state + '.tmp', cfg.state)  # atomic: survive a mid-write timeout

    if cfg.fingerprint:  # study s2000 vs other ckpts: print the operator's 4-D fingerprint
        from diag_cos import fingerprint
        fp = fingerprint(blk, cfg.T1, cfg.T2, cfg.eps, cfg.beta, cfg.holo, cfg.hr, cfg.t1max, cfg.res_est, cfg.t2sel)
        print(f"[fingerprint] ckpt={cfg.init_ckpt or 'scratch'} | res={fp['res']:.2e} cos(EP,BPTT)={fp['cos']:.4f} "
              f"num_abscissa={fp['num_abscissa']:+.4f} val={fp['val']:.4f}", flush=True)
        return

    # ---- training loop ---------------------------------------------------
    for step in range(start_step, cfg.steps + 1):
        idx, y = get_batch('train', cfg.B, cfg.T)
        if cfg.mode == 'ep':
            sw = hw_swap() if hw_on else None
            grads, res = ep_step(blk, idx, y, cfg.T1, cfg.T2, cfg.eps, cfg.beta, jr, cfg.holo, cfg.hr,
                                 cfg.t1max, cfg.res_est, cfg.t2sel, cfg.corr_every, cfg.res_gate, cfg.resreg,
                                 cfg.eigreg, cfg.eig_margin)
            if sw is not None:
                hw_restore(sw)
            if cfg.jacreg > 0:  # continuous controller: drive residual -> res_target (smooth)
                flo = cfg.jacreg if cfg.jr_floor is None else cfg.jr_floor
                if cfg.jr_lrcouple:
                    flo *= sched.get_last_lr()[0] / cfg.lr
                rtgt = cfg.res_target
                if cfg.rt_final > 0:  # stiffness anneal: tight start -> loose mid/late
                    u = min(1.0, max(0.0, (step / cfg.steps - 0.25) / 0.5))
                    rtgt = math.exp((1 - u) * math.log(cfg.res_target) + u * math.log(cfg.rt_final))
                rs = res if rs is None else cfg.res_ema * rs + (1 - cfg.res_ema) * res
                jr = min(cfg.jr_max, max(flo, jr * math.exp(0.3 * math.log((rs + 1e-9) / rtgt))))
            else:  # damping feedback (no jacreg)
                if res > 1e-3:
                    blk.c = min(cfg.ccap, blk.c * 1.3)
                elif res < 2e-4:
                    blk.c = max(0.5, blk.c * 0.97)
        else:
            grads = bptt_step(blk, idx, y, cfg.T1, cfg.eps, jr if cfg.jacreg > 0 else 0.0)
            with torch.no_grad():  # is BPTT's optimum contractive? (free-phase residual)
                xinb = blk.embed(idx).detach()
                zsb = relax(blk, xinb.clone(), xinb, cfg.T1, cfg.eps)
                res = (relax(blk, zsb, xinb, 1, cfg.eps) - zsb).norm().item() / (zsb.norm().item() + 1e-9)
            if cfg.jacreg > 0:  # same residual-driven λ controller as ep mode
                flo = cfg.jacreg if cfg.jr_floor is None else cfg.jr_floor
                if cfg.jr_lrcouple:
                    flo *= sched.get_last_lr()[0] / cfg.lr
                rs = res if rs is None else cfg.res_ema * rs + (1 - cfg.res_ema) * res
                jr = min(cfg.jr_max, max(flo, jr * math.exp(0.3 * math.log((rs + 1e-9) / cfg.res_target))))

        badct = badct + 1 if (cfg.abort_res > 0 and res > cfg.abort_res) else 0
        if badct >= 100:  # containment lost and not recovering: stop, keep best ckpt
            print(f" ABORT at step {step}: res>{cfg.abort_res} for 100 consecutive steps (best {best:.4f})", flush=True)
            break
        ok = all((g is None) or torch.isfinite(g).all() for g in grads.values())
        if not ok:
            print(f" step {step}: non-finite, skip", flush=True)
            continue

        opt.zero_grad(set_to_none=True)
        for p in blk.allp:
            p.grad = grads.get(id(p))
        torch.nn.utils.clip_grad_norm_(blk.allp, 5.0)
        opt.step()
        spec_stats = None
        with torch.no_grad():
            if cfg.specnorm > 0:
                spec_stats = project_specnorm_(spec_items, spec_cache, cfg.specnorm)
            else:
                # Frobenius-norm cap relative to the init-time norm.
                for p in blk.capw:
                    pn = p.norm()
                    cap = blk.caps[id(p)]
                    if pn > cap:
                        p.mul_(cap / pn)
        sched.step()
        if spec_stats is not None and (step == start_step or step % cfg.log == 0):
            sb, sa, names = spec_stats
            cname = ",".join(names) if names else "none"
            print(f" specnorm step {step}: max sigma before={sb:.4f} after={sa:.4f} bound={cfg.specnorm:.4f} clamped={cname}", flush=True)
        if pema is not None:
            with torch.no_grad():
                for s, p in zip(pema, blk.allp):
                    s.mul_(cfg.pema).add_(p.detach(), alpha=1 - cfg.pema)
        if cfg.save_every and step % cfg.save_every == 0 and step % cfg.log != 0:
            save_state(step)  # mid-interval state save (Colab: cap worst-case loss)
        if cfg.diag_cos and step % cfg.diag_cos == 0:  # #1: gradient-alignment trajectory (scratch vs warm)
            from diag_cos import cos_ep_bptt
            _c, _r = cos_ep_bptt(blk, idx, y, cfg.T1, cfg.T2, cfg.eps, cfg.beta, cfg.holo, cfg.hr,
                                 cfg.t1max, cfg.res_est, cfg.t2sel)
            print(f" [diag] step {step}: cos(EP,BPTT)={_c:.4f} res={_r:.1e}", flush=True)
        if step % cfg.log == 0:
            prevb = best
            sw = hw_swap() if hw_on else None
            v = evaluate(blk, cfg.T1, cfg.eps)
            if sw is not None:
                hw_restore(sw)
            best = min(best, v)
            etag = ""
            if pema is not None:
                # Temporarily load the EMA weights, evaluate, then restore the raw ones.
                with torch.no_grad():
                    raw = [p.detach().clone() for p in blk.allp]
                    for p, s in zip(blk.allp, pema):
                        p.copy_(s)
                    ve = evaluate(blk, cfg.T1, cfg.eps)
                    for p, saved_w in zip(blk.allp, raw):
                        p.copy_(saved_w)
                best = min(best, ve)
                etag = f" ema={ve:.4f}"
            if cfg.ckpt and best < prevb:
                torch.save({'allp': [p.detach().cpu() for p in blk.allp],
                            'pema': [s.cpu() for s in pema] if pema is not None else None,
                            'step': step, 'best': best}, cfg.ckpt)
            # Throughput counts only the steps run by this process, so the
            # figure stays meaningful after --resume.
            rate = (step - start_step + 1) / (time.time() - t0)
            print(f"step {step:4d}/{cfg.steps} | val CE {v:.4f}{etag} (best {best:.4f}) | jr={jr:.1f} res={res:.1e} | {rate:.2f} it/s", flush=True)
            save_state(step)  # full-state checkpoint each log interval (Colab resume)

    print(f"[{cfg.mode}] DONE best val CE {best:.4f} (random baseline ln({vocab})={math.log(vocab):.3f})", flush=True)
    out_dir = Path('runs')
    out_dir.mkdir(exist_ok=True)
    with open(out_dir / f'H2_{cfg.mode}.json', 'w') as out_file:  # flushed and closed deterministically
        json.dump({'mode': cfg.mode, 'best_val_ce': best}, out_file)
+
+
+if __name__ == '__main__':
+ main()