diff options
Diffstat (limited to 'ep_run/holo_ep.py')
| -rw-r--r-- | ep_run/holo_ep.py | 332 |
1 files changed, 332 insertions, 0 deletions
diff --git a/ep_run/holo_ep.py b/ep_run/holo_ep.py new file mode 100644 index 0000000..31054e4 --- /dev/null +++ b/ep_run/holo_ep.py @@ -0,0 +1,332 @@ +"""Holomorphic EP (Laborieux & Zenke 2022) for the non-conservative thick block. + +Plain EP estimates a = -dz*/dbeta with a 2-point real centered difference: bias O(beta^2) forces +beta small (0.02), and the estimator noise scales like (equilibration error)/beta. Holomorphic EP +evaluates the nudged equilibrium at N points on a CIRCLE |beta|=r in the complex plane and reads +-dz*/dbeta off the discrete Cauchy/Fourier formula a = -Re[(1/(N r)) sum_k e^{-i phi_k} z*(r e^{i phi_k})]: +bias O(r^N) instead of O(r^2) -> r can be 5-10x larger at equal bias -> the 1/beta noise +amplification drops by the same factor. Requires the force holomorphically extended to complex +state: manual LN (non-conjugate variance), softmax (exp ratio), GELU (tanh form, entire). +The AEP correction carries over unchanged: it is linear in (z - z*) with REAL coefficients, so it +preserves holomorphy in beta; apply it to real and imaginary parts separately. +NOTE: no g-clamp and no corr-clip inside the holomorphic nudge (clamps are non-analytic and would +destroy the O(r^N) bias property); we monitor max|z-z*| instead.""" +import math, torch, torch.nn.functional as F +from lt_ep_train import EQBlock, get_batch, ep_step, bptt_step, relax + +CDT = torch.complex64 + + +def cln(z, g, b, eps=1e-5): # holomorphic LayerNorm: NON-conjugate variance + mu = z.mean(-1, keepdim=True) + v = ((z - mu) ** 2).mean(-1, keepdim=True) # analytic continuation of the real LN + return (z - mu) / torch.sqrt(v + eps) * g + b + + +def csoftmax_masked(a, mask): # holomorphic causal softmax via exp ratio + c = a.real.amax(-1, keepdim=True) # constant row shift cancels exactly in the ratio + w = torch.exp(a - c) * mask # masked entries -> exact 0 + return w / w.sum(-1, keepdim=True) + + +def cgelu(z): # tanh-form GELU: entire function + return 0.5 * z * (1.0 + torch.tanh(0.7978845608028654 * (z + 0.044715 * z ** 3))) + + +def cforce(blk, z, xin): # holomorphic extension of the thick force + C, H, dh, T = blk.C, blk.H, blk.dh, blk.T + B = z.size(0) + h1 = cln(z, blk.ln1g.to(CDT), blk.ln1b.to(CDT)) + h2 = cln(z, blk.ln2g.to(CDT), blk.ln2b.to(CDT)) + q = (h1 @ blk.WQ.to(CDT)).view(B, T, H, dh).transpose(1, 2) + k = (h1 @ blk.WK.to(CDT)).view(B, T, H, dh).transpose(1, 2) + v = (h1 @ blk.WV.to(CDT)).view(B, T, H, dh).transpose(1, 2) + if getattr(blk, 'qknorm', False): # match attn()'s q/k RMSNorm (holomorphic: non-conjugate q^2) + q = q * (q.pow(2).mean(-1, keepdim=True) + 1e-6).pow(-0.5) + k = k * (k.pow(2).mean(-1, keepdim=True) + 1e-6).pow(-0.5) + a = (q @ k.transpose(-2, -1)) / math.sqrt(dh) + p = csoftmax_masked(a, blk.cmask.to(CDT)) + att = (p @ v).transpose(1, 2).reshape(B, T, C) @ blk.WO.to(CDT) + ff = cgelu(h2 @ blk.fc.to(CDT) + blk.fcb.to(CDT)) @ blk.pj.to(CDT) + blk.pjb.to(CDT) + return -(z - xin) + att + ff - blk.c * z + + +def cgrad_ce(blk, z, y): # holomorphic dCE/dz = (softmax(z Wh) - Y) Wh^T / NT + logits = z @ blk.Wh.to(CDT) + c = logits.real.amax(-1, keepdim=True) + w = torch.exp(logits - c) + p = w / w.sum(-1, keepdim=True) + Y = F.one_hot(y, p.size(-1)).to(CDT) + return (p - Y) @ blk.Wh.t().to(CDT) / y.numel() + + +def holo_a(blk, zs, xin, y, N, r, T2, eps, corr_on=True): + """Nudged phases at beta_k = r e^{2 pi i k / N}; returns (a, max|z - z*|) with + a = -Re[(1/(N r)) sum_k e^{-i phi_k} (z_k - z*)] ~ -dz*/dbeta + O(r^N).""" + zsc, xc = zs.to(CDT), xin.to(CDT) + acc = torch.zeros_like(zsc) + mg = 0.0 + for kk in range(N): + ph = complex(math.cos(2 * math.pi * kk / N), math.sin(2 * math.pi * kk / N)) + beta = r * ph + z = zsc.clone() + for _ in range(T2): + with torch.no_grad(): + f = cforce(blk, z, xc) - beta * cgrad_ce(blk, z, y) + if corr_on: # AEP: J -> J^T, linear & real -> holomorphy kept + v = z - zsc + Jv = torch.autograd.functional.jvp(blk.nc_force, zs, v.real.contiguous())[1] + 0j + JTv = torch.autograd.functional.vjp(blk.nc_force, zs, v.real.contiguous())[1] + 0j + if v.imag.abs().max() > 1e-9: # real-axis phases skip the imag solves + Jv = Jv + 1j * torch.autograd.functional.jvp(blk.nc_force, zs, v.imag.contiguous())[1] + JTv = JTv + 1j * torch.autograd.functional.vjp(blk.nc_force, zs, v.imag.contiguous())[1] + f = f - (Jv - JTv) + z = z + eps * f + acc = acc + torch.conj(torch.tensor(ph, device=z.device)) * (z - zsc) + mg = max(mg, (z - zsc).abs().max().item()) + a = -(acc / (N * r)).real + return a.detach(), mg + + +def holo_a_select(blk, zs, xin, y, N, r, T2max, eps, K=10, exit_mult=5.0, corr_every=1): + """Adaptive-T2 by hindsight selection: run nudged phases in lockstep to T2max, snapshot the + contrast a_t every K steps, return the snapshot with the smallest increment (most settled). + Never worse than short fixed T2 (the settled snapshot exists early too); captures the long-T2 + win (cos up to ~0.99) when the nudged dynamics are stable; early-exits only on clear blowup — + judging by increments of the QUANTITY OF INTEREST, not step sizes, so non-normal transient + growth cannot trigger a premature stop.""" + zsc, xc = zs.to(CDT), xin.to(CDT) + ph = [complex(math.cos(2 * math.pi * k / N), math.sin(2 * math.pi * k / N)) for k in range(N)] + Z = [zsc.clone() for _ in range(N)] + corr = [None] * N + a_prev = a_best = None + inc_min, t_best = float('inf'), 0 + for t in range(1, T2max + 1): + for k in range(N): + with torch.no_grad(): + f = cforce(blk, Z[k], xc) - (r * ph[k]) * cgrad_ce(blk, Z[k], y) + if corr[k] is None or (t - 1) % corr_every == 0: # v moves ~eps/step: stale corr is cheap + v = Z[k] - zsc + Jv = torch.autograd.functional.jvp(blk.nc_force, zs, v.real.contiguous())[1] + 0j + JTv = torch.autograd.functional.vjp(blk.nc_force, zs, v.real.contiguous())[1] + 0j + if v.imag.abs().max() > 1e-9: + Jv = Jv + 1j * torch.autograd.functional.jvp(blk.nc_force, zs, v.imag.contiguous())[1] + JTv = JTv + 1j * torch.autograd.functional.vjp(blk.nc_force, zs, v.imag.contiguous())[1] + corr[k] = Jv - JTv + Z[k] = Z[k] + eps * (f - corr[k]) + if t % K == 0 or t == T2max: + acc = sum(torch.conj(torch.tensor(p, device=zs.device)) * (zk - zsc) for p, zk in zip(ph, Z)) + a_t = -(acc / (N * r)).real + if not torch.isfinite(a_t).all(): + break + if a_prev is not None: + inc = (a_t - a_prev).norm().item() + if inc < inc_min: + inc_min, a_best, t_best = inc, a_t, t + elif inc > exit_mult * inc_min and t >= 3 * K: + break + a_prev = a_t + if a_best is None: + a_best, t_best = a_prev, T2max + return a_best.detach(), t_best + + +def rforce(blk, z, xin): # real-axis twin of cforce (tanh-gelu, clamp-free) + C, H, dh, T = blk.C, blk.H, blk.dh, blk.T + B = z.size(0) + h1 = F.layer_norm(z, (C,), blk.ln1g, blk.ln1b) + h2 = F.layer_norm(z, (C,), blk.ln2g, blk.ln2b) + q = (h1 @ blk.WQ).view(B, T, H, dh).transpose(1, 2) + k = (h1 @ blk.WK).view(B, T, H, dh).transpose(1, 2) + v = (h1 @ blk.WV).view(B, T, H, dh).transpose(1, 2) + if getattr(blk, 'qknorm', False): # match attn()'s q/k RMSNorm in the nudge force + q = q * torch.rsqrt(q.pow(2).mean(-1, keepdim=True) + 1e-6) + k = k * torch.rsqrt(k.pow(2).mean(-1, keepdim=True) + 1e-6) + a = (q @ k.transpose(-2, -1)) / math.sqrt(dh) + p = torch.softmax(a.masked_fill(~blk.cmask, float('-inf')), -1) + att = (p @ v).transpose(1, 2).reshape(B, T, C) @ blk.WO + ff = cgelu(h2 @ blk.fc + blk.fcb) @ blk.pj + blk.pjb + nc = att + ff + if getattr(blk, 'fnoise', 0.0) > 0: + nc = nc * (1 + blk.fnoise * torch.randn_like(nc)) + return -(z - xin) + nc - blk.c * z + + +def rgrad_ce(blk, z, y, denom=None): + p = torch.softmax(z @ blk.Wh, -1) + return (p - F.one_hot(y, p.size(-1)).to(z.dtype)) @ blk.Wh.t() / (denom or y.numel()) + + +def holo_a_select2(blk, zs, xin, y, r, T2max, eps, K=10, exit_mult=5.0, li=0): + """li>0 enables LOCK-IN INTEGRATION mode for noisy (hardware) physics: run the full T2max, + EMA the contrast a_t every step with time-constant li — the homodyne integrator that divides + persistent per-pass noise by sqrt(window). The hindsight-selection mode (li=0) is for clean + physics, where a single most-settled snapshot is optimal.""" + """N=2 production fast path — mathematically identical to holo_a_select(N=2): both phases are + real, so run them PHASE-BATCHED (stack +r/-r along batch) with real tensors and torch.func + forward-mode jvp. Halves autograd calls and skips complex arithmetic.""" + import torch.func as tf + B = zs.size(0) + Z = torch.cat([zs, zs], 0) # [+r phase | -r phase] + X2 = torch.cat([xin, xin], 0) + y2 = torch.cat([y, y], 0) + sg = torch.cat([torch.full((B, 1, 1), r, device=zs.device), + torch.full((B, 1, 1), -r, device=zs.device)], 0) + zs2 = torch.cat([zs, zs], 0) + fnc = lambda zz: blk.nc_force(zz) + a_prev = a_best = a_ema = None + inc_min, t_best = float('inf'), 0 + for t in range(1, T2max + 1): + with torch.no_grad(): + f = rforce(blk, Z, X2) - sg * rgrad_ce(blk, Z, y2, denom=y.numel()) # CE mean over the ORIGINAL batch + v = (Z - zs2).contiguous() + _, Jv = tf.jvp(fnc, (zs2,), (v,)) + JTv = tf.vjp(fnc, zs2)[1](v)[0] + Z = Z + eps * (f - (Jv - JTv)) + if li > 0: # lock-in integration (noisy physics) + a_t = (Z[B:] - Z[:B]) / (2 * r) + if not torch.isfinite(a_t).all(): + break + if t > T2max // 3: # let phases develop, then integrate + a_ema = a_t if a_ema is None else a_ema + (a_t - a_ema) / li + continue + if t % K == 0 or t == T2max: + a_t = (Z[B:] - Z[:B]) / (2 * r) # (z_- - z_+)/2r + if not torch.isfinite(a_t).all(): + break + if a_prev is not None: + inc = (a_t - a_prev).norm().item() + if inc < inc_min: + inc_min, a_best, t_best = inc, a_t, t + elif inc > exit_mult * inc_min and t >= 3 * K: + break + a_prev = a_t + if li > 0: + if a_ema is None: + a_ema = (Z[B:] - Z[:B]) / (2 * r) + return a_ema.detach(), T2max + if a_best is None: + a_best = a_prev if a_prev is not None else (Z[B:] - Z[:B]) / (2 * r) + t_best = T2max + return a_best.detach(), t_best + + +def holo_a_track(blk, zs, xin, y, r, T2max, eps, K=10, exit_mult=5.0): + """Common-mode-tracking AEP: linearize the antisymmetric correction at the instantaneous + common mode of the two phases — exact transposed differential dynamics, loose-tolerant, + no compounding linearization error.""" + import torch.func as tf + B = zs.size(0) + Z = torch.cat([zs, zs], 0) + X2 = torch.cat([xin, xin], 0) + y2 = torch.cat([y, y], 0) + sg = torch.cat([torch.full((B,1,1), r, device=zs.device), torch.full((B,1,1), -r, device=zs.device)], 0) + fnc = lambda zz: blk.nc_force(zz) + a_prev = a_best = None + inc_min, t_best = float('inf'), 0 + zs2a = torch.cat([zs, zs], 0) + kappa = getattr(blk, 'nbrake', 0.0) + for t in range(1, T2max + 1): + with torch.no_grad(): + zbar = 0.5 * (Z[:B] + Z[B:]) + zb2 = torch.cat([zbar, zbar], 0) + f = rforce(blk, Z, X2) - sg * rgrad_ce(blk, Z, y2, denom=y.numel()) + if kappa > 0: # measurement brake: Tikhonov-regularized adjoint + f = f - kappa * (Z - zs2a) + v = (Z - zb2).contiguous() + _, Jv = tf.jvp(fnc, (zb2,), (v,)) + JTv = tf.vjp(fnc, zb2)[1](v)[0] + Z = Z + eps * (f - (Jv - JTv)) + if t % K == 0 or t == T2max: + a_t = (Z[B:] - Z[:B]) / (2 * r) + if not torch.isfinite(a_t).all(): + break + if a_prev is not None: + inc = (a_t - a_prev).norm().item() + if inc < inc_min: + inc_min, a_best, t_best = inc, a_t, t + elif inc > exit_mult * inc_min and t >= 3 * K: + break + a_prev = a_t + if a_best is None: + a_best = a_prev if a_prev is not None else (Z[B:] - Z[:B]) / (2 * r) + t_best = T2max + return a_best.detach(), t_best + + +def holo_a_lockin(blk, zs, xin, y, r, P, ncyc, eps): + """True oscillatory EP / lock-in estimator (Laborieux–Zenke taken literally) — the + noisy-physics form: ONE trajectory, sinusoidal nudge beta(t)=r·sin(2πt/P), in-phase + demodulation over ncyc periods (first period discarded as transient). Single-trajectory => + common-mode noise cancels in the quadrature; v=z−z* stays O(r·response) so the AEP + linearization never leaves its window; noise admitted only in the demodulation band.""" + z = zs.clone() + accI = torch.zeros_like(zs) + sI = 0.0 + T = P * (ncyc + 1) + for t in range(1, T + 1): + s = math.sin(2 * math.pi * t / P) + with torch.no_grad(): + f = rforce(blk, z, xin) - (r * s) * rgrad_ce(blk, z, y, denom=y.numel()) + v = (z - zs).contiguous() + Jv = torch.autograd.functional.jvp(blk.nc_force, zs, v)[1] + JTv = torch.autograd.functional.vjp(blk.nc_force, zs, v)[1] + z = z + eps * (f - (Jv - JTv)) + if not torch.isfinite(z).all(): + return None, t + if t > P: # demodulate after the transient period + accI = accI + z * s + sI += s * s + return (-(accI / (sI + 1e-12)) / r).detach(), T + """Full holomorphic-EP gradient for block params (same VF readout as ep_step).""" + xin0 = blk.embed(idx).detach() + zs = relax(blk, xin0.clone(), xin0, T1, eps) + res = (relax(blk, zs, xin0, 1, eps) - zs).norm().item() / (zs.norm().item() + 1e-9) + a, mg = holo_a(blk, zs, xin0, y, N, r, T2, eps) + with torch.enable_grad(): + xin = blk.embed(idx) + f = blk.force(zs.detach(), xin, cg=True) + gblk = torch.autograd.grad((a * f).sum(), blk.block, allow_unused=True) + return {id(p): g for p, g in zip(blk.block, gblk)}, res, mg + + +if __name__ == '__main__': + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + torch.manual_seed(0) + B, T, C, H = 16, 64, 128, 4 + blk = EQBlock(C, H, 256, T, attn_mode='thick') + for p, w in zip(blk.allp, torch.load('/tmp/lt_ep/probe_w.pt')): + with torch.no_grad(): + p.copy_(w.to(dev)) + print("loaded probe weights (300-step BPTT, thick, c=1)", flush=True) + + groups = {'all': blk.block, + 'attn': [blk.WQ, blk.WK, blk.WV, blk.WO], + 'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb], + 'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b], + 'emb': [blk.tok, blk.pos]} + + def cos(ga, gb, ps): + keep = [p for p in ps if ga.get(id(p)) is not None and gb.get(id(p)) is not None] + if not keep: + return float('nan') + va = torch.cat([ga[id(p)].reshape(-1) for p in keep]) + vb = torch.cat([gb[id(p)].reshape(-1) for p in keep]) + return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item() + + hdr = f"{'estimator':>22} {'res':>9} {'max|dz|':>8} " + " ".join(f"{k:>6}" for k in groups) + for bi in range(3): + idx, y = get_batch('train', B, T) + ref = bptt_step(blk, idx, y, 400, 0.1) + print(("\n" if bi else "") + hdr, flush=True) + for T1 in (150, 400): + gep, res = ep_step(blk, idx, y, T1, 20, 0.1, 0.02, 0.0) + print(f"{f'plain ep b=.02 T1={T1}':>22} {res:>9.1e} {'--':>8} " + + " ".join(f"{cos(gep, ref, ps):>6.3f}" for ps in groups.values()), flush=True) + gep2, _ = ep_step(blk, idx, y, T1, 20, 0.1, 0.1, 0.0) + print(f"{f'plain ep b=.10 T1={T1}':>22} {res:>9.1e} {'--':>8} " + + " ".join(f"{cos(gep2, ref, ps):>6.3f}" for ps in groups.values()), flush=True) + for (N, r) in ((2, 0.02), (4, 0.05), (4, 0.1), (4, 0.2), (8, 0.2)): + gh, res2, mg = holo_grads(blk, idx, y, T1, 20, 0.1, N, r) + print(f"{f'holo N={N} r={r} T1={T1}':>22} {res2:>9.1e} {mg:>8.2f} " + + " ".join(f"{cos(gh, ref, ps):>6.3f}" for ps in groups.values()), flush=True) |
