"""Holomorphic EP (Laborieux & Zenke 2022) for the non-conservative thick block.

Plain EP estimates a = -dz*/dbeta with a 2-point real centered difference: its
bias is O(beta^2), which forces beta to be small (0.02), and the estimator noise
scales like (equilibration error)/beta.

Holomorphic EP evaluates the nudged equilibrium at N points on a CIRCLE
|beta| = r in the complex plane and reads -dz*/dbeta off the discrete
Cauchy/Fourier formula

    a = -Re[(1/(N r)) sum_k e^{-i phi_k} (z(r e^{i phi_k}) - z*)],  phi_k = 2 pi k / N.

The bias is O(r^N) instead of O(r^2), so r can be 5-10x larger at equal bias,
and the 1/beta noise amplification drops by the same factor.

This requires the force to be extended holomorphically to a complex state:
manual LayerNorm (non-conjugate variance), softmax (exp ratio), GELU (tanh form,
an entire function). The AEP correction carries over unchanged: it is linear in
(z - z*) with REAL coefficients, so it preserves holomorphy in beta; it is
applied to the real and imaginary parts separately.

NOTE: there is no g-clamp and no correction clip inside the holomorphic nudge
(clamps are non-analytic and would destroy the O(r^N) bias property); we
monitor max|z - z*| instead.
"""
import math

import torch
import torch.nn.functional as F

from lt_ep_train import EQBlock, get_batch, ep_step, bptt_step, relax

CDT = torch.complex64


def cln(z, g, b, eps=1e-5):
    """Holomorphic LayerNorm over the last dim.

    Uses the NON-conjugate variance mean((z - mu)^2), i.e. the analytic
    continuation of the real LayerNorm (it reduces to it on real input).
    """
    mu = z.mean(-1, keepdim=True)
    v = ((z - mu) ** 2).mean(-1, keepdim=True)
    return (z - mu) / torch.sqrt(v + eps) * g + b


def csoftmax_masked(a, mask):
    """Holomorphic masked softmax over the last dim, as an exp ratio.

    ``mask`` is nonzero where an entry participates and 0 where it is masked
    out (masked entries come out as exactly 0). A real constant row shift
    cancels exactly in the ratio. It is taken over the KEPT entries only:
    shifting by a masked entry's larger score would underflow every kept
    exponential to 0 and yield 0/0 = NaN.
    """
    keep = mask != 0
    c = a.real.masked_fill(~keep, float('-inf')).amax(-1, keepdim=True)
    c = torch.where(torch.isfinite(c), c, torch.zeros_like(c))  # fully-masked rows
    # Zero the exponent at masked positions so they cannot overflow (inf * 0 = NaN).
    arg = torch.where(keep, a - c, torch.zeros_like(a))
    w = torch.exp(arg) * mask  # masked entries -> exact 0
    return w / w.sum(-1, keepdim=True)


def cgelu(z):
    """Tanh-form GELU: an entire function, so it is safe on complex input."""
    return 0.5 * z * (1.0 + torch.tanh(0.7978845608028654 * (z + 0.044715 * z ** 3)))


def cforce(blk, z, xin):
    """Holomorphic extension of the thick-block force at a complex state z.

    Returns -(z - xin) + attention(z) + FFN(z) - blk.c * z, where LayerNorm,
    softmax, GELU and the optional q/k RMSNorm are replaced by their analytic
    continuations. All weights are cast to ``CDT``.
    """
    C, H, dh, T = blk.C, blk.H, blk.dh, blk.T
    B = z.size(0)
    h1 = cln(z, blk.ln1g.to(CDT), blk.ln1b.to(CDT))
    h2 = cln(z, blk.ln2g.to(CDT), blk.ln2b.to(CDT))
    q = (h1 @ blk.WQ.to(CDT)).view(B, T, H, dh).transpose(1, 2)
    k = (h1 @ blk.WK.to(CDT)).view(B, T, H, dh).transpose(1, 2)
    v = (h1 @ blk.WV.to(CDT)).view(B, T, H, dh).transpose(1, 2)
    if getattr(blk, 'qknorm', False):
        # match attn()'s q/k RMSNorm (holomorphic: non-conjugate q^2)
        q = q * (q.pow(2).mean(-1, keepdim=True) + 1e-6).pow(-0.5)
        k = k * (k.pow(2).mean(-1, keepdim=True) + 1e-6).pow(-0.5)
    a = (q @ k.transpose(-2, -1)) / math.sqrt(dh)
    p = csoftmax_masked(a, blk.cmask.to(CDT))
    att = (p @ v).transpose(1, 2).reshape(B, T, C) @ blk.WO.to(CDT)
    ff = cgelu(h2 @ blk.fc.to(CDT) + blk.fcb.to(CDT)) @ blk.pj.to(CDT) + blk.pjb.to(CDT)
    return -(z - xin) + att + ff - blk.c * z


def cgrad_ce(blk, z, y):
    """Holomorphic dCE/dz = (softmax(z Wh) - onehot(y)) Wh^T / y.numel()."""
    logits = z @ blk.Wh.to(CDT)
    c = logits.real.amax(-1, keepdim=True)  # real shift cancels in the ratio
    w = torch.exp(logits - c)
    p = w / w.sum(-1, keepdim=True)
    Y = F.one_hot(y, p.size(-1)).to(CDT)
    return (p - Y) @ blk.Wh.t().to(CDT) / y.numel()


def _phases(N):
    """Return the N unit-circle phases e^{2 pi i k / N}, k = 0..N-1, as Python complex."""
    return [complex(math.cos(2 * math.pi * k / N), math.sin(2 * math.pi * k / N))
            for k in range(N)]


def _aep_corr(blk, zs, v):
    """Return the AEP correction (J v - J^T v) for a complex deviation v.

    J is the Jacobian of ``blk.nc_force`` at the real point ``zs``. The map is
    linear with real coefficients, so it is applied to Re v and Im v
    separately and recombined; this keeps the nudged dynamics holomorphic in
    beta. The imaginary solves are skipped when max|Im v| <= 1e-9 (real-axis
    phases).
    """
    vr = v.real.contiguous()
    Jv = torch.autograd.functional.jvp(blk.nc_force, zs, vr)[1] + 0j
    JTv = torch.autograd.functional.vjp(blk.nc_force, zs, vr)[1] + 0j
    if v.imag.abs().max() > 1e-9:
        vi = v.imag.contiguous()
        Jv = Jv + 1j * torch.autograd.functional.jvp(blk.nc_force, zs, vi)[1]
        JTv = JTv + 1j * torch.autograd.functional.vjp(blk.nc_force, zs, vi)[1]
    return Jv - JTv


def _phase_contrast(ph, Z, zsc, r):
    """Return the discrete Cauchy contrast -Re[(1/(N r)) sum_k conj(ph_k) (Z_k - z*)]."""
    acc = sum(p.conjugate() * (zk - zsc) for p, zk in zip(ph, Z))
    return -(acc / (len(ph) * r)).real


def holo_a(blk, zs, xin, y, N, r, T2, eps, corr_on=True):
    """Run nudged phases at beta_k = r e^{2 pi i k / N} for T2 Euler steps each.

    Returns (a, max|z - z*|), where
    a = -Re[(1/(N r)) sum_k e^{-i phi_k} (z_k - z*)] ~ -dz*/dbeta + O(r^N).
    With ``corr_on`` the AEP correction (J v - J^T v) is subtracted from the
    force at every step.
    """
    zsc, xc = zs.to(CDT), xin.to(CDT)
    acc = torch.zeros_like(zsc)
    mg = 0.0
    for ph in _phases(N):
        beta = r * ph
        z = zsc.clone()
        for _ in range(T2):
            with torch.no_grad():
                f = cforce(blk, z, xc) - beta * cgrad_ce(blk, z, y)
            if corr_on:  # AEP: J -> J^T, linear & real -> holomorphy kept
                f = f - _aep_corr(blk, zs, z - zsc)
            z = z + eps * f
        dz = z - zsc
        acc = acc + ph.conjugate() * dz
        mg = max(mg, dz.abs().max().item())
    a = -(acc / (N * r)).real
    return a.detach(), mg


def holo_a_select(blk, zs, xin, y, N, r, T2max, eps, K=10, exit_mult=5.0, corr_every=1):
    """Adaptive-T2 by hindsight selection.

    Run the nudged phases in lockstep up to T2max and snapshot the contrast a_t
    every K steps (and at T2max). Return the snapshot with the smallest
    increment, i.e. the most settled one, together with its step index.

    This is never worse than a short fixed T2, because the settled snapshot
    also exists early. It captures the long-T2 win (cos up to ~0.99) when the
    nudged dynamics are stable. It exits early only on a clear blowup. The
    blowup test uses increments of the QUANTITY OF INTEREST rather than step
    sizes, so non-normal transient growth cannot trigger a premature stop.

    The AEP correction is recomputed every ``corr_every`` steps per phase.
    Between recomputations v moves ~eps per step, so a stale correction is
    cheap. If no snapshot is accepted, the function falls back to the last
    snapshot, or to the current contrast when there is none, and reports
    T2max.
    """
    zsc, xc = zs.to(CDT), xin.to(CDT)
    ph = _phases(N)
    Z = [zsc.clone() for _ in range(N)]
    corr = [None] * N
    a_prev = a_best = None
    inc_min, t_best = float('inf'), 0
    for t in range(1, T2max + 1):
        for k in range(N):
            with torch.no_grad():
                f = cforce(blk, Z[k], xc) - (r * ph[k]) * cgrad_ce(blk, Z[k], y)
            if corr[k] is None or (t - 1) % corr_every == 0:
                corr[k] = _aep_corr(blk, zs, Z[k] - zsc)
            Z[k] = Z[k] + eps * (f - corr[k])
        if t % K == 0 or t == T2max:
            a_t = _phase_contrast(ph, Z, zsc, r)
            if not torch.isfinite(a_t).all():
                break
            if a_prev is not None:
                inc = (a_t - a_prev).norm().item()
                if inc < inc_min:
                    inc_min, a_best, t_best = inc, a_t, t
                elif inc > exit_mult * inc_min and t >= 3 * K:
                    break
            a_prev = a_t
    if a_best is None:
        # No increment was ever measured (first snapshot non-finite, a single
        # snapshot, or T2max < 1): avoid calling .detach() on None.
        a_best = a_prev if a_prev is not None else _phase_contrast(ph, Z, zsc, r)
        t_best = T2max
    return a_best.detach(), t_best


def rforce(blk, z, xin):
    """Real-axis twin of cforce: tanh-GELU, clamp-free.

    If ``blk.fnoise`` > 0, the non-conservative part (attention + FFN) is
    multiplied by (1 + fnoise * N(0, 1)) noise.
    """
    C, H, dh, T = blk.C, blk.H, blk.dh, blk.T
    B = z.size(0)
    h1 = F.layer_norm(z, (C,), blk.ln1g, blk.ln1b)
    h2 = F.layer_norm(z, (C,), blk.ln2g, blk.ln2b)
    q = (h1 @ blk.WQ).view(B, T, H, dh).transpose(1, 2)
    k = (h1 @ blk.WK).view(B, T, H, dh).transpose(1, 2)
    v = (h1 @ blk.WV).view(B, T, H, dh).transpose(1, 2)
    if getattr(blk, 'qknorm', False):  # match attn()'s q/k RMSNorm in the nudge force
        q = q * torch.rsqrt(q.pow(2).mean(-1, keepdim=True) + 1e-6)
        k = k * torch.rsqrt(k.pow(2).mean(-1, keepdim=True) + 1e-6)
    a = (q @ k.transpose(-2, -1)) / math.sqrt(dh)
    p = torch.softmax(a.masked_fill(~blk.cmask, float('-inf')), -1)
    att = (p @ v).transpose(1, 2).reshape(B, T, C) @ blk.WO
    ff = cgelu(h2 @ blk.fc + blk.fcb) @ blk.pj + blk.pjb
    nc = att + ff
    if getattr(blk, 'fnoise', 0.0) > 0:
        nc = nc * (1 + blk.fnoise * torch.randn_like(nc))
    return -(z - xin) + nc - blk.c * z


def rgrad_ce(blk, z, y, denom=None):
    """Real dCE/dz = (softmax(z Wh) - onehot(y)) Wh^T / denom.

    ``denom`` defaults to y.numel() (falsy values also fall back to it).
    """
    p = torch.softmax(z @ blk.Wh, -1)
    return (p - F.one_hot(y, p.size(-1)).to(z.dtype)) @ blk.Wh.t() / (denom or y.numel())


def holo_a_select2(blk, zs, xin, y, r, T2max, eps, K=10, exit_mult=5.0, li=0):
    """N=2 production fast path, mathematically identical to holo_a_select(N=2).

    Both phases are real, so they are run PHASE-BATCHED (the +r and -r phases
    are stacked along the batch dim) with real tensors and torch.func
    forward-mode jvp. This halves the autograd calls and skips complex
    arithmetic.

    li > 0 enables LOCK-IN INTEGRATION mode for noisy (hardware) physics. It
    runs the full T2max steps and, after the first T2max // 3 steps, EMAs the
    contrast a_t every step with time-constant li. This is the homodyne
    integrator that divides persistent per-pass noise by sqrt(window).

    The hindsight-selection mode (li = 0) is for clean physics, where a single
    most-settled snapshot is optimal.
    """
    import torch.func as tf
    B = zs.size(0)
    Z = torch.cat([zs, zs], 0)  # [+r phase | -r phase]
    X2 = torch.cat([xin, xin], 0)
    y2 = torch.cat([y, y], 0)
    sg = torch.cat([torch.full((B, 1, 1), r, device=zs.device),
                    torch.full((B, 1, 1), -r, device=zs.device)], 0)
    zs2 = torch.cat([zs, zs], 0)
    fnc = blk.nc_force
    a_prev = a_best = a_ema = None
    inc_min, t_best = float('inf'), 0
    for t in range(1, T2max + 1):
        with torch.no_grad():
            # CE mean over the ORIGINAL batch
            f = rforce(blk, Z, X2) - sg * rgrad_ce(blk, Z, y2, denom=y.numel())
            v = (Z - zs2).contiguous()
            _, Jv = tf.jvp(fnc, (zs2,), (v,))
            JTv = tf.vjp(fnc, zs2)[1](v)[0]
            Z = Z + eps * (f - (Jv - JTv))
        if li > 0:  # lock-in integration (noisy physics)
            a_t = (Z[B:] - Z[:B]) / (2 * r)
            if not torch.isfinite(a_t).all():
                break
            if t > T2max // 3:  # let phases develop, then integrate
                a_ema = a_t if a_ema is None else a_ema + (a_t - a_ema) / li
            continue
        if t % K == 0 or t == T2max:
            a_t = (Z[B:] - Z[:B]) / (2 * r)  # (z_- - z_+)/2r
            if not torch.isfinite(a_t).all():
                break
            if a_prev is not None:
                inc = (a_t - a_prev).norm().item()
                if inc < inc_min:
                    inc_min, a_best, t_best = inc, a_t, t
                elif inc > exit_mult * inc_min and t >= 3 * K:
                    break
            a_prev = a_t
    if li > 0:
        if a_ema is None:
            a_ema = (Z[B:] - Z[:B]) / (2 * r)
        return a_ema.detach(), T2max
    if a_best is None:
        a_best = a_prev if a_prev is not None else (Z[B:] - Z[:B]) / (2 * r)
        t_best = T2max
    return a_best.detach(), t_best


def holo_a_track(blk, zs, xin, y, r, T2max, eps, K=10, exit_mult=5.0):
    """Common-mode-tracking AEP (N=2, phase-batched).

    The antisymmetric correction is linearized at the instantaneous common mode
    zbar = (z_+ + z_-)/2 of the two phases instead of at z*. This gives exact
    transposed differential dynamics, is tolerant of loose z*, and has no
    compounding linearization error.

    If ``blk.nbrake`` > 0, a Tikhonov "measurement brake" -kappa (z - z*) is
    added to the force. Snapshot selection is the same hindsight rule as in
    holo_a_select2 (li = 0).
    """
    import torch.func as tf
    B = zs.size(0)
    Z = torch.cat([zs, zs], 0)
    X2 = torch.cat([xin, xin], 0)
    y2 = torch.cat([y, y], 0)
    sg = torch.cat([torch.full((B, 1, 1), r, device=zs.device),
                    torch.full((B, 1, 1), -r, device=zs.device)], 0)
    fnc = blk.nc_force
    a_prev = a_best = None
    inc_min, t_best = float('inf'), 0
    zs2a = torch.cat([zs, zs], 0)
    kappa = getattr(blk, 'nbrake', 0.0)
    for t in range(1, T2max + 1):
        with torch.no_grad():
            zbar = 0.5 * (Z[:B] + Z[B:])
            zb2 = torch.cat([zbar, zbar], 0)
            f = rforce(blk, Z, X2) - sg * rgrad_ce(blk, Z, y2, denom=y.numel())
            if kappa > 0:  # measurement brake: Tikhonov-regularized adjoint
                f = f - kappa * (Z - zs2a)
            v = (Z - zb2).contiguous()
            _, Jv = tf.jvp(fnc, (zb2,), (v,))
            JTv = tf.vjp(fnc, zb2)[1](v)[0]
            Z = Z + eps * (f - (Jv - JTv))
        if t % K == 0 or t == T2max:
            a_t = (Z[B:] - Z[:B]) / (2 * r)
            if not torch.isfinite(a_t).all():
                break
            if a_prev is not None:
                inc = (a_t - a_prev).norm().item()
                if inc < inc_min:
                    inc_min, a_best, t_best = inc, a_t, t
                elif inc > exit_mult * inc_min and t >= 3 * K:
                    break
            a_prev = a_t
    if a_best is None:
        a_best = a_prev if a_prev is not None else (Z[B:] - Z[:B]) / (2 * r)
        t_best = T2max
    return a_best.detach(), t_best


def holo_a_lockin(blk, zs, xin, y, r, P, ncyc, eps):
    """True oscillatory EP / lock-in estimator (Laborieux-Zenke taken literally).

    This is the noisy-physics form. It runs ONE trajectory with a sinusoidal
    nudge beta(t) = r sin(2 pi t / P) for P * (ncyc + 1) steps. The output is
    demodulated in phase over the last ncyc periods; the first period is
    discarded as transient.

    A single trajectory means common-mode noise cancels in the quadrature.
    v = z - z* stays O(r * response), so the AEP linearization never leaves
    its window. Noise is admitted only in the demodulation band.

    Returns (a, T) with a = -(sum (z - z*) s / sum s^2) / r. Returns (None, t)
    if the state becomes non-finite at step t.
    """
    z = zs.clone()
    accI = torch.zeros_like(zs)
    sI = 0.0
    T = P * (ncyc + 1)
    for t in range(1, T + 1):
        s = math.sin(2 * math.pi * t / P)
        with torch.no_grad():
            f = rforce(blk, z, xin) - (r * s) * rgrad_ce(blk, z, y, denom=y.numel())
            v = (z - zs).contiguous()
            Jv = torch.autograd.functional.jvp(blk.nc_force, zs, v)[1]
            JTv = torch.autograd.functional.vjp(blk.nc_force, zs, v)[1]
            z = z + eps * (f - (Jv - JTv))
        if not torch.isfinite(z).all():
            return None, t
        if t > P:  # demodulate after the transient period
            # Demodulate the deviation z - z*. Over whole periods sum(s) = 0,
            # so this equals demodulating z, but the large common-mode z* stays
            # out of the float32 accumulator.
            accI = accI + (z - zs) * s
            sI += s * s
    return (-(accI / (sI + 1e-12)) / r).detach(), T


def holo_grads(blk, idx, y, T1, T2, eps, N, r):
    """Full holomorphic-EP gradient for the block params (same VF readout as ep_step).

    Relaxes the free phase for T1 steps, estimates a with holo_a (N phases on
    |beta| = r, T2 nudged steps each), then differentiates sum(a * force) with
    respect to ``blk.block``.

    Returns ({id(param): grad or None}, res, mg):
    - res is ||relax(z*, 1 step) - z*|| / ||z*||, a free-phase residual
      (assumes relax's 4th arg is a step count; TODO confirm);
    - mg is max|z - z*| over the nudged phases.
    """
    xin0 = blk.embed(idx).detach()
    zs = relax(blk, xin0.clone(), xin0, T1, eps)
    res = (relax(blk, zs, xin0, 1, eps) - zs).norm().item() / (zs.norm().item() + 1e-9)
    a, mg = holo_a(blk, zs, xin0, y, N, r, T2, eps)
    with torch.enable_grad():
        xin = blk.embed(idx)
        f = blk.force(zs.detach(), xin, cg=True)
        gblk = torch.autograd.grad((a * f).sum(), blk.block, allow_unused=True)
    return {id(p): g for p, g in zip(blk.block, gblk)}, res, mg


if __name__ == '__main__':
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.manual_seed(0)
    B, T, C, H = 16, 64, 128, 4
    blk = EQBlock(C, H, 256, T, attn_mode='thick')
    # map_location: probe weights saved on a GPU still load on a CPU-only host
    for p, w in zip(blk.allp, torch.load('/tmp/lt_ep/probe_w.pt', map_location=dev)):
        with torch.no_grad():
            p.copy_(w.to(dev))
    print("loaded probe weights (300-step BPTT, thick, c=1)", flush=True)
    groups = {'all': blk.block,
              'attn': [blk.WQ, blk.WK, blk.WV, blk.WO],
              'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb],
              'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b],
              'emb': [blk.tok, blk.pos]}

    def cos(ga, gb, ps):
        """Cosine similarity of two {id(param): grad} dicts over params ps (NaN if none shared)."""
        keep = [p for p in ps if ga.get(id(p)) is not None and gb.get(id(p)) is not None]
        if not keep:
            return float('nan')
        va = torch.cat([ga[id(p)].reshape(-1) for p in keep])
        vb = torch.cat([gb[id(p)].reshape(-1) for p in keep])
        return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item()

    hdr = f"{'estimator':>22} {'res':>9} {'max|dz|':>8} " + " ".join(f"{k:>6}" for k in groups)
    for bi in range(3):
        idx, y = get_batch('train', B, T)
        ref = bptt_step(blk, idx, y, 400, 0.1)
        print(("\n" if bi else "") + hdr, flush=True)
        for T1 in (150, 400):
            gep, res = ep_step(blk, idx, y, T1, 20, 0.1, 0.02, 0.0)
            print(f"{f'plain ep b=.02 T1={T1}':>22} {res:>9.1e} {'--':>8} "
                  + " ".join(f"{cos(gep, ref, ps):>6.3f}" for ps in groups.values()), flush=True)
            gep2, _ = ep_step(blk, idx, y, T1, 20, 0.1, 0.1, 0.0)
            print(f"{f'plain ep b=.10 T1={T1}':>22} {res:>9.1e} {'--':>8} "
                  + " ".join(f"{cos(gep2, ref, ps):>6.3f}" for ps in groups.values()), flush=True)
            for (N, r) in ((2, 0.02), (4, 0.05), (4, 0.1), (4, 0.2), (8, 0.2)):
                gh, res2, mg = holo_grads(blk, idx, y, T1, 20, 0.1, N, r)
                print(f"{f'holo N={N} r={r} T1={T1}':>22} {res2:>9.1e} {mg:>8.2f} "
                      + " ".join(f"{cos(gh, ref, ps):>6.3f}" for ps in groups.values()), flush=True)