summaryrefslogtreecommitdiff
path: root/ep_run/holo_ep.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
committerYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
commitb83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
treeb9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/holo_ep.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/holo_ep.py')
-rw-r--r--ep_run/holo_ep.py332
1 files changed, 332 insertions, 0 deletions
diff --git a/ep_run/holo_ep.py b/ep_run/holo_ep.py
new file mode 100644
index 0000000..31054e4
--- /dev/null
+++ b/ep_run/holo_ep.py
@@ -0,0 +1,332 @@
+"""Holomorphic EP (Laborieux & Zenke 2022) for the non-conservative thick block.
+
+Plain EP estimates a = -dz*/dbeta with a 2-point real centered difference: bias O(beta^2) forces
+beta small (0.02), and the estimator noise scales like (equilibration error)/beta. Holomorphic EP
+evaluates the nudged equilibrium at N points on a CIRCLE |beta|=r in the complex plane and reads
+-dz*/dbeta off the discrete Cauchy/Fourier formula a = -Re[(1/(N r)) sum_k e^{-i phi_k} z*(r e^{i phi_k})]:
+bias O(r^N) instead of O(r^2) -> r can be 5-10x larger at equal bias -> the 1/beta noise
+amplification drops by the same factor. Requires the force holomorphically extended to complex
+state: manual LN (non-conjugate variance), softmax (exp ratio), GELU (tanh form, entire).
+The AEP correction carries over unchanged: it is linear in (z - z*) with REAL coefficients, so it
+preserves holomorphy in beta; apply it to real and imaginary parts separately.
+NOTE: no g-clamp and no corr-clip inside the holomorphic nudge (clamps are non-analytic and would
+destroy the O(r^N) bias property); we monitor max|z-z*| instead."""
+import math, torch, torch.nn.functional as F
+from lt_ep_train import EQBlock, get_batch, ep_step, bptt_step, relax
+
+CDT = torch.complex64
+
+
+def cln(z, g, b, eps=1e-5): # holomorphic LayerNorm: NON-conjugate variance
+ mu = z.mean(-1, keepdim=True)
+ v = ((z - mu) ** 2).mean(-1, keepdim=True) # analytic continuation of the real LN
+ return (z - mu) / torch.sqrt(v + eps) * g + b
+
+
+def csoftmax_masked(a, mask): # holomorphic causal softmax via exp ratio
+ c = a.real.amax(-1, keepdim=True) # constant row shift cancels exactly in the ratio
+ w = torch.exp(a - c) * mask # masked entries -> exact 0
+ return w / w.sum(-1, keepdim=True)
+
+
+def cgelu(z): # tanh-form GELU: entire function
+ return 0.5 * z * (1.0 + torch.tanh(0.7978845608028654 * (z + 0.044715 * z ** 3)))
+
+
+def cforce(blk, z, xin): # holomorphic extension of the thick force
+ C, H, dh, T = blk.C, blk.H, blk.dh, blk.T
+ B = z.size(0)
+ h1 = cln(z, blk.ln1g.to(CDT), blk.ln1b.to(CDT))
+ h2 = cln(z, blk.ln2g.to(CDT), blk.ln2b.to(CDT))
+ q = (h1 @ blk.WQ.to(CDT)).view(B, T, H, dh).transpose(1, 2)
+ k = (h1 @ blk.WK.to(CDT)).view(B, T, H, dh).transpose(1, 2)
+ v = (h1 @ blk.WV.to(CDT)).view(B, T, H, dh).transpose(1, 2)
+ if getattr(blk, 'qknorm', False): # match attn()'s q/k RMSNorm (holomorphic: non-conjugate q^2)
+ q = q * (q.pow(2).mean(-1, keepdim=True) + 1e-6).pow(-0.5)
+ k = k * (k.pow(2).mean(-1, keepdim=True) + 1e-6).pow(-0.5)
+ a = (q @ k.transpose(-2, -1)) / math.sqrt(dh)
+ p = csoftmax_masked(a, blk.cmask.to(CDT))
+ att = (p @ v).transpose(1, 2).reshape(B, T, C) @ blk.WO.to(CDT)
+ ff = cgelu(h2 @ blk.fc.to(CDT) + blk.fcb.to(CDT)) @ blk.pj.to(CDT) + blk.pjb.to(CDT)
+ return -(z - xin) + att + ff - blk.c * z
+
+
+def cgrad_ce(blk, z, y): # holomorphic dCE/dz = (softmax(z Wh) - Y) Wh^T / NT
+ logits = z @ blk.Wh.to(CDT)
+ c = logits.real.amax(-1, keepdim=True)
+ w = torch.exp(logits - c)
+ p = w / w.sum(-1, keepdim=True)
+ Y = F.one_hot(y, p.size(-1)).to(CDT)
+ return (p - Y) @ blk.Wh.t().to(CDT) / y.numel()
+
+
+def holo_a(blk, zs, xin, y, N, r, T2, eps, corr_on=True):
+ """Nudged phases at beta_k = r e^{2 pi i k / N}; returns (a, max|z - z*|) with
+ a = -Re[(1/(N r)) sum_k e^{-i phi_k} (z_k - z*)] ~ -dz*/dbeta + O(r^N)."""
+ zsc, xc = zs.to(CDT), xin.to(CDT)
+ acc = torch.zeros_like(zsc)
+ mg = 0.0
+ for kk in range(N):
+ ph = complex(math.cos(2 * math.pi * kk / N), math.sin(2 * math.pi * kk / N))
+ beta = r * ph
+ z = zsc.clone()
+ for _ in range(T2):
+ with torch.no_grad():
+ f = cforce(blk, z, xc) - beta * cgrad_ce(blk, z, y)
+ if corr_on: # AEP: J -> J^T, linear & real -> holomorphy kept
+ v = z - zsc
+ Jv = torch.autograd.functional.jvp(blk.nc_force, zs, v.real.contiguous())[1] + 0j
+ JTv = torch.autograd.functional.vjp(blk.nc_force, zs, v.real.contiguous())[1] + 0j
+ if v.imag.abs().max() > 1e-9: # real-axis phases skip the imag solves
+ Jv = Jv + 1j * torch.autograd.functional.jvp(blk.nc_force, zs, v.imag.contiguous())[1]
+ JTv = JTv + 1j * torch.autograd.functional.vjp(blk.nc_force, zs, v.imag.contiguous())[1]
+ f = f - (Jv - JTv)
+ z = z + eps * f
+ acc = acc + torch.conj(torch.tensor(ph, device=z.device)) * (z - zsc)
+ mg = max(mg, (z - zsc).abs().max().item())
+ a = -(acc / (N * r)).real
+ return a.detach(), mg
+
+
+def holo_a_select(blk, zs, xin, y, N, r, T2max, eps, K=10, exit_mult=5.0, corr_every=1):
+ """Adaptive-T2 by hindsight selection: run nudged phases in lockstep to T2max, snapshot the
+ contrast a_t every K steps, return the snapshot with the smallest increment (most settled).
+ Never worse than short fixed T2 (the settled snapshot exists early too); captures the long-T2
+ win (cos up to ~0.99) when the nudged dynamics are stable; early-exits only on clear blowup —
+ judging by increments of the QUANTITY OF INTEREST, not step sizes, so non-normal transient
+ growth cannot trigger a premature stop."""
+ zsc, xc = zs.to(CDT), xin.to(CDT)
+ ph = [complex(math.cos(2 * math.pi * k / N), math.sin(2 * math.pi * k / N)) for k in range(N)]
+ Z = [zsc.clone() for _ in range(N)]
+ corr = [None] * N
+ a_prev = a_best = None
+ inc_min, t_best = float('inf'), 0
+ for t in range(1, T2max + 1):
+ for k in range(N):
+ with torch.no_grad():
+ f = cforce(blk, Z[k], xc) - (r * ph[k]) * cgrad_ce(blk, Z[k], y)
+ if corr[k] is None or (t - 1) % corr_every == 0: # v moves ~eps/step: stale corr is cheap
+ v = Z[k] - zsc
+ Jv = torch.autograd.functional.jvp(blk.nc_force, zs, v.real.contiguous())[1] + 0j
+ JTv = torch.autograd.functional.vjp(blk.nc_force, zs, v.real.contiguous())[1] + 0j
+ if v.imag.abs().max() > 1e-9:
+ Jv = Jv + 1j * torch.autograd.functional.jvp(blk.nc_force, zs, v.imag.contiguous())[1]
+ JTv = JTv + 1j * torch.autograd.functional.vjp(blk.nc_force, zs, v.imag.contiguous())[1]
+ corr[k] = Jv - JTv
+ Z[k] = Z[k] + eps * (f - corr[k])
+ if t % K == 0 or t == T2max:
+ acc = sum(torch.conj(torch.tensor(p, device=zs.device)) * (zk - zsc) for p, zk in zip(ph, Z))
+ a_t = -(acc / (N * r)).real
+ if not torch.isfinite(a_t).all():
+ break
+ if a_prev is not None:
+ inc = (a_t - a_prev).norm().item()
+ if inc < inc_min:
+ inc_min, a_best, t_best = inc, a_t, t
+ elif inc > exit_mult * inc_min and t >= 3 * K:
+ break
+ a_prev = a_t
+ if a_best is None:
+ a_best, t_best = a_prev, T2max
+ return a_best.detach(), t_best
+
+
+def rforce(blk, z, xin): # real-axis twin of cforce (tanh-gelu, clamp-free)
+ C, H, dh, T = blk.C, blk.H, blk.dh, blk.T
+ B = z.size(0)
+ h1 = F.layer_norm(z, (C,), blk.ln1g, blk.ln1b)
+ h2 = F.layer_norm(z, (C,), blk.ln2g, blk.ln2b)
+ q = (h1 @ blk.WQ).view(B, T, H, dh).transpose(1, 2)
+ k = (h1 @ blk.WK).view(B, T, H, dh).transpose(1, 2)
+ v = (h1 @ blk.WV).view(B, T, H, dh).transpose(1, 2)
+ if getattr(blk, 'qknorm', False): # match attn()'s q/k RMSNorm in the nudge force
+ q = q * torch.rsqrt(q.pow(2).mean(-1, keepdim=True) + 1e-6)
+ k = k * torch.rsqrt(k.pow(2).mean(-1, keepdim=True) + 1e-6)
+ a = (q @ k.transpose(-2, -1)) / math.sqrt(dh)
+ p = torch.softmax(a.masked_fill(~blk.cmask, float('-inf')), -1)
+ att = (p @ v).transpose(1, 2).reshape(B, T, C) @ blk.WO
+ ff = cgelu(h2 @ blk.fc + blk.fcb) @ blk.pj + blk.pjb
+ nc = att + ff
+ if getattr(blk, 'fnoise', 0.0) > 0:
+ nc = nc * (1 + blk.fnoise * torch.randn_like(nc))
+ return -(z - xin) + nc - blk.c * z
+
+
+def rgrad_ce(blk, z, y, denom=None):
+ p = torch.softmax(z @ blk.Wh, -1)
+ return (p - F.one_hot(y, p.size(-1)).to(z.dtype)) @ blk.Wh.t() / (denom or y.numel())
+
+
+def holo_a_select2(blk, zs, xin, y, r, T2max, eps, K=10, exit_mult=5.0, li=0):
+ """li>0 enables LOCK-IN INTEGRATION mode for noisy (hardware) physics: run the full T2max,
+ EMA the contrast a_t every step with time-constant li — the homodyne integrator that divides
+ persistent per-pass noise by sqrt(window). The hindsight-selection mode (li=0) is for clean
+ physics, where a single most-settled snapshot is optimal."""
+ """N=2 production fast path — mathematically identical to holo_a_select(N=2): both phases are
+ real, so run them PHASE-BATCHED (stack +r/-r along batch) with real tensors and torch.func
+ forward-mode jvp. Halves autograd calls and skips complex arithmetic."""
+ import torch.func as tf
+ B = zs.size(0)
+ Z = torch.cat([zs, zs], 0) # [+r phase | -r phase]
+ X2 = torch.cat([xin, xin], 0)
+ y2 = torch.cat([y, y], 0)
+ sg = torch.cat([torch.full((B, 1, 1), r, device=zs.device),
+ torch.full((B, 1, 1), -r, device=zs.device)], 0)
+ zs2 = torch.cat([zs, zs], 0)
+ fnc = lambda zz: blk.nc_force(zz)
+ a_prev = a_best = a_ema = None
+ inc_min, t_best = float('inf'), 0
+ for t in range(1, T2max + 1):
+ with torch.no_grad():
+ f = rforce(blk, Z, X2) - sg * rgrad_ce(blk, Z, y2, denom=y.numel()) # CE mean over the ORIGINAL batch
+ v = (Z - zs2).contiguous()
+ _, Jv = tf.jvp(fnc, (zs2,), (v,))
+ JTv = tf.vjp(fnc, zs2)[1](v)[0]
+ Z = Z + eps * (f - (Jv - JTv))
+ if li > 0: # lock-in integration (noisy physics)
+ a_t = (Z[B:] - Z[:B]) / (2 * r)
+ if not torch.isfinite(a_t).all():
+ break
+ if t > T2max // 3: # let phases develop, then integrate
+ a_ema = a_t if a_ema is None else a_ema + (a_t - a_ema) / li
+ continue
+ if t % K == 0 or t == T2max:
+ a_t = (Z[B:] - Z[:B]) / (2 * r) # (z_- - z_+)/2r
+ if not torch.isfinite(a_t).all():
+ break
+ if a_prev is not None:
+ inc = (a_t - a_prev).norm().item()
+ if inc < inc_min:
+ inc_min, a_best, t_best = inc, a_t, t
+ elif inc > exit_mult * inc_min and t >= 3 * K:
+ break
+ a_prev = a_t
+ if li > 0:
+ if a_ema is None:
+ a_ema = (Z[B:] - Z[:B]) / (2 * r)
+ return a_ema.detach(), T2max
+ if a_best is None:
+ a_best = a_prev if a_prev is not None else (Z[B:] - Z[:B]) / (2 * r)
+ t_best = T2max
+ return a_best.detach(), t_best
+
+
+def holo_a_track(blk, zs, xin, y, r, T2max, eps, K=10, exit_mult=5.0):
+ """Common-mode-tracking AEP: linearize the antisymmetric correction at the instantaneous
+ common mode of the two phases — exact transposed differential dynamics, loose-tolerant,
+ no compounding linearization error."""
+ import torch.func as tf
+ B = zs.size(0)
+ Z = torch.cat([zs, zs], 0)
+ X2 = torch.cat([xin, xin], 0)
+ y2 = torch.cat([y, y], 0)
+ sg = torch.cat([torch.full((B,1,1), r, device=zs.device), torch.full((B,1,1), -r, device=zs.device)], 0)
+ fnc = lambda zz: blk.nc_force(zz)
+ a_prev = a_best = None
+ inc_min, t_best = float('inf'), 0
+ zs2a = torch.cat([zs, zs], 0)
+ kappa = getattr(blk, 'nbrake', 0.0)
+ for t in range(1, T2max + 1):
+ with torch.no_grad():
+ zbar = 0.5 * (Z[:B] + Z[B:])
+ zb2 = torch.cat([zbar, zbar], 0)
+ f = rforce(blk, Z, X2) - sg * rgrad_ce(blk, Z, y2, denom=y.numel())
+ if kappa > 0: # measurement brake: Tikhonov-regularized adjoint
+ f = f - kappa * (Z - zs2a)
+ v = (Z - zb2).contiguous()
+ _, Jv = tf.jvp(fnc, (zb2,), (v,))
+ JTv = tf.vjp(fnc, zb2)[1](v)[0]
+ Z = Z + eps * (f - (Jv - JTv))
+ if t % K == 0 or t == T2max:
+ a_t = (Z[B:] - Z[:B]) / (2 * r)
+ if not torch.isfinite(a_t).all():
+ break
+ if a_prev is not None:
+ inc = (a_t - a_prev).norm().item()
+ if inc < inc_min:
+ inc_min, a_best, t_best = inc, a_t, t
+ elif inc > exit_mult * inc_min and t >= 3 * K:
+ break
+ a_prev = a_t
+ if a_best is None:
+ a_best = a_prev if a_prev is not None else (Z[B:] - Z[:B]) / (2 * r)
+ t_best = T2max
+ return a_best.detach(), t_best
+
+
+def holo_a_lockin(blk, zs, xin, y, r, P, ncyc, eps):
+ """True oscillatory EP / lock-in estimator (Laborieux–Zenke taken literally) — the
+ noisy-physics form: ONE trajectory, sinusoidal nudge beta(t)=r·sin(2πt/P), in-phase
+ demodulation over ncyc periods (first period discarded as transient). Single-trajectory =>
+ common-mode noise cancels in the quadrature; v=z−z* stays O(r·response) so the AEP
+ linearization never leaves its window; noise admitted only in the demodulation band."""
+ z = zs.clone()
+ accI = torch.zeros_like(zs)
+ sI = 0.0
+ T = P * (ncyc + 1)
+ for t in range(1, T + 1):
+ s = math.sin(2 * math.pi * t / P)
+ with torch.no_grad():
+ f = rforce(blk, z, xin) - (r * s) * rgrad_ce(blk, z, y, denom=y.numel())
+ v = (z - zs).contiguous()
+ Jv = torch.autograd.functional.jvp(blk.nc_force, zs, v)[1]
+ JTv = torch.autograd.functional.vjp(blk.nc_force, zs, v)[1]
+ z = z + eps * (f - (Jv - JTv))
+ if not torch.isfinite(z).all():
+ return None, t
+ if t > P: # demodulate after the transient period
+ accI = accI + z * s
+ sI += s * s
+ return (-(accI / (sI + 1e-12)) / r).detach(), T
+ """Full holomorphic-EP gradient for block params (same VF readout as ep_step)."""
+ xin0 = blk.embed(idx).detach()
+ zs = relax(blk, xin0.clone(), xin0, T1, eps)
+ res = (relax(blk, zs, xin0, 1, eps) - zs).norm().item() / (zs.norm().item() + 1e-9)
+ a, mg = holo_a(blk, zs, xin0, y, N, r, T2, eps)
+ with torch.enable_grad():
+ xin = blk.embed(idx)
+ f = blk.force(zs.detach(), xin, cg=True)
+ gblk = torch.autograd.grad((a * f).sum(), blk.block, allow_unused=True)
+ return {id(p): g for p, g in zip(blk.block, gblk)}, res, mg
+
+
+if __name__ == '__main__':
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ torch.manual_seed(0)
+ B, T, C, H = 16, 64, 128, 4
+ blk = EQBlock(C, H, 256, T, attn_mode='thick')
+ for p, w in zip(blk.allp, torch.load('/tmp/lt_ep/probe_w.pt')):
+ with torch.no_grad():
+ p.copy_(w.to(dev))
+ print("loaded probe weights (300-step BPTT, thick, c=1)", flush=True)
+
+ groups = {'all': blk.block,
+ 'attn': [blk.WQ, blk.WK, blk.WV, blk.WO],
+ 'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb],
+ 'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b],
+ 'emb': [blk.tok, blk.pos]}
+
+ def cos(ga, gb, ps):
+ keep = [p for p in ps if ga.get(id(p)) is not None and gb.get(id(p)) is not None]
+ if not keep:
+ return float('nan')
+ va = torch.cat([ga[id(p)].reshape(-1) for p in keep])
+ vb = torch.cat([gb[id(p)].reshape(-1) for p in keep])
+ return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item()
+
+ hdr = f"{'estimator':>22} {'res':>9} {'max|dz|':>8} " + " ".join(f"{k:>6}" for k in groups)
+ for bi in range(3):
+ idx, y = get_batch('train', B, T)
+ ref = bptt_step(blk, idx, y, 400, 0.1)
+ print(("\n" if bi else "") + hdr, flush=True)
+ for T1 in (150, 400):
+ gep, res = ep_step(blk, idx, y, T1, 20, 0.1, 0.02, 0.0)
+ print(f"{f'plain ep b=.02 T1={T1}':>22} {res:>9.1e} {'--':>8} "
+ + " ".join(f"{cos(gep, ref, ps):>6.3f}" for ps in groups.values()), flush=True)
+ gep2, _ = ep_step(blk, idx, y, T1, 20, 0.1, 0.1, 0.0)
+ print(f"{f'plain ep b=.10 T1={T1}':>22} {res:>9.1e} {'--':>8} "
+ + " ".join(f"{cos(gep2, ref, ps):>6.3f}" for ps in groups.values()), flush=True)
+ for (N, r) in ((2, 0.02), (4, 0.05), (4, 0.1), (4, 0.2), (8, 0.2)):
+ gh, res2, mg = holo_grads(blk, idx, y, T1, 20, 0.1, N, r)
+ print(f"{f'holo N={N} r={r} T1={T1}':>22} {res2:>9.1e} {mg:>8.2f} "
+ + " ".join(f"{cos(gh, ref, ps):>6.3f}" for ps in groups.values()), flush=True)