summaryrefslogtreecommitdiff
path: root/ep_run/holo_ep.py
blob: 31054e4125a9fec3132c3404593d60fbfe6fd2a6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
"""Holomorphic EP (Laborieux & Zenke 2022) for the non-conservative thick block.

Plain EP estimates a = -dz*/dbeta with a 2-point real centered difference: bias O(beta^2) forces
beta small (0.02), and the estimator noise scales like (equilibration error)/beta. Holomorphic EP
evaluates the nudged equilibrium at N points beta_k = r e^{i phi_k}, phi_k = 2 pi k / N, on a
CIRCLE |beta|=r in the complex plane and reads -dz*/dbeta off the discrete Cauchy/Fourier formula
a = -Re[(1/(N r)) sum_k e^{-i phi_k} z*(r e^{i phi_k})]:
bias O(r^N) instead of O(r^2)  ->  r can be 5-10x larger at equal bias  ->  the 1/beta noise
amplification drops by the same factor. Requires the force holomorphically extended to complex
state: manual LN (non-conjugate variance), softmax (exp ratio), GELU (tanh form: holomorphic
near the real axis, but not entire, since tanh has poles at i pi (k + 1/2)).
The AEP correction carries over unchanged: it is linear in (z - z*) with REAL coefficients, so it
preserves holomorphy in beta; apply it to real and imaginary parts separately.
NOTE: no g-clamp and no corr-clip inside the holomorphic nudge (clamps are non-analytic and would
destroy the O(r^N) bias property); we monitor max|z-z*| instead."""
import math, torch, torch.nn.functional as F
from lt_ep_train import EQBlock, get_batch, ep_step, bptt_step, relax

CDT = torch.complex64   # complex dtype of the holomorphic state; cforce/cgrad_ce/holo_a* cast state and block params to it


def cln(z, g, b, eps=1e-5):
    """LayerNorm over the last axis, continued analytically to complex input.

    The variance is mean((z - mean)**2) with NO complex conjugate, so the map stays holomorphic
    in z. On real input it coincides with ordinary (biased-variance) LayerNorm.
    g, b: elementwise gain and bias; eps is added to the variance before the square root.
    """
    centered = z - z.mean(-1, keepdim=True)
    var = centered.pow(2).mean(-1, keepdim=True)
    return centered / torch.sqrt(var + eps) * g + b


def csoftmax_masked(a, mask):
    """Holomorphic masked softmax over the last axis, computed as an exp ratio.

    a    : (complex) logits.
    mask : tensor broadcastable to a; nonzero = keep, zero = masked out (e.g. the causal mask
           cast to the logits' dtype, or a bool mask).
    Returns w / w.sum(-1) with w = exp(a - c) on kept entries and exactly 0 on masked ones.

    The row shift c is the max of Re(a) over the KEPT entries only. Any row-constant shift
    cancels exactly in the ratio, so holomorphy in a is unaffected. Taking the max over masked
    entries as well would let one large masked logit push every kept exp(a - c) to underflow
    (0/0 = NaN), a failure the real twin (masked_fill(-inf) + softmax) does not have.
    """
    keep = mask != 0
    c = a.real.masked_fill(~keep, float('-inf')).amax(-1, keepdim=True)
    # masked entries get exponent 0 (no overflow from a - c), then are zeroed exactly by the mask
    expo = torch.where(keep, a - c, torch.zeros_like(a))
    w = torch.exp(expo) * mask
    return w / w.sum(-1, keepdim=True)


def cgelu(z):
    """tanh-approximation GELU: 0.5 z (1 + tanh(sqrt(2/pi) (z + 0.044715 z^3))).

    Built only from a polynomial and tanh, so it continues analytically to complex z. It is NOT
    entire (tanh has poles where its argument equals i*pi*(k + 1/2)), but it is holomorphic in a
    neighbourhood of the real axis. Works unchanged on real tensors.
    """
    inner = 0.7978845608028654 * (z + 0.044715 * z ** 3)
    return 0.5 * z * (1.0 + torch.tanh(inner))


def cforce(blk, z, xin):
    """Holomorphic (complex-state) extension of the thick block's force.

    Same structure as the real-axis rforce (which additionally supports multiplicative fnoise),
    with each nonlinearity replaced by its analytic continuation: cln for LayerNorm,
    csoftmax_masked for the masked softmax, cgelu for the MLP. Block parameters are cast to CDT.
    Returns -(z - xin) + attention + MLP - blk.c * z, shaped like z.
    """
    C, H, dh, T = blk.C, blk.H, blk.dh, blk.T
    nb = z.size(0)

    def cx(t):                                        # cast a (real) block tensor to CDT
        return t.to(CDT)

    def heads(x, W):                                  # project, then split into H heads of dh
        return (x @ cx(W)).view(nb, T, H, dh).transpose(1, 2)

    ln_attn = cln(z, cx(blk.ln1g), cx(blk.ln1b))
    ln_mlp = cln(z, cx(blk.ln2g), cx(blk.ln2b))
    q = heads(ln_attn, blk.WQ)
    k = heads(ln_attn, blk.WK)
    v = heads(ln_attn, blk.WV)
    if getattr(blk, 'qknorm', False):
        # match attn()'s q/k RMSNorm; q**2 (not |q|**2) keeps the map holomorphic
        q = q * (q.pow(2).mean(-1, keepdim=True) + 1e-6).pow(-0.5)
        k = k * (k.pow(2).mean(-1, keepdim=True) + 1e-6).pow(-0.5)
    scores = q @ k.transpose(-2, -1) / math.sqrt(dh)
    probs = csoftmax_masked(scores, cx(blk.cmask))
    attn_out = (probs @ v).transpose(1, 2).reshape(nb, T, C) @ cx(blk.WO)
    mlp_out = cgelu(ln_mlp @ cx(blk.fc) + cx(blk.fcb)) @ cx(blk.pj) + cx(blk.pjb)
    return -(z - xin) + attn_out + mlp_out - blk.c * z


def cgrad_ce(blk, z, y):
    """Holomorphic gradient of the mean cross-entropy with respect to the state z.

    Returns (softmax(z @ Wh) - onehot(y)) @ Wh^T / y.numel(). The softmax is written as an exp
    ratio: the real-part row max is subtracted only to keep exp in range and cancels exactly, so
    the result is analytic in z.
    """
    logits = z @ blk.Wh.to(CDT)
    e = torch.exp(logits - logits.real.amax(-1, keepdim=True))
    probs = e / e.sum(-1, keepdim=True)
    target = F.one_hot(y, probs.size(-1)).to(CDT)
    return (probs - target) @ blk.Wh.t().to(CDT) / y.numel()


def holo_a(blk, zs, xin, y, N, r, T2, eps, corr_on=True):
    """Holomorphic-EP contrast from N nudged phases on the circle |beta| = r.

    Phase k (beta_k = r e^{2 pi i k / N}) starts at the free equilibrium zs and takes T2 explicit
    Euler steps of size eps under  cforce - beta_k * cgrad_ce,  minus the AEP term (J - J^T) v
    with v = z - zs and J the Jacobian of blk.nc_force at zs (when corr_on). The final states give
        a = -Re[(1/(N r)) sum_k e^{-i phi_k} (z_k - z*)]  ~  -dz*/dbeta + O(r^N).

    Returns (a, mg): a is detached and real, shaped like zs; mg = max over phases of
    max|z_k - zs| (a float), for monitoring how far the nudged phases moved.
    """
    jvp = torch.autograd.functional.jvp
    vjp = torch.autograd.functional.vjp

    def antisym(v):
        # (J - J^T) v. J is real, so a complex v is handled through its real and imaginary
        # parts separately (linear, real coefficients -> holomorphy in beta kept).
        re = v.real.contiguous()
        Jv = jvp(blk.nc_force, zs, re)[1] + 0j
        JTv = vjp(blk.nc_force, zs, re)[1] + 0j
        if v.imag.abs().max() > 1e-9:                 # real-axis phases skip the imag solves
            im = v.imag.contiguous()
            Jv = Jv + 1j * jvp(blk.nc_force, zs, im)[1]
            JTv = JTv + 1j * vjp(blk.nc_force, zs, im)[1]
        return Jv - JTv

    zs_c, x_c = zs.to(CDT), xin.to(CDT)
    weighted = torch.zeros_like(zs_c)
    max_dev = 0.0
    for k in range(N):
        theta = 2 * math.pi * k / N
        phase = complex(math.cos(theta), math.sin(theta))
        beta = r * phase
        z = zs_c.clone()
        for _ in range(T2):
            with torch.no_grad():
                drift = cforce(blk, z, x_c) - beta * cgrad_ce(blk, z, y)
                if corr_on:
                    drift = drift - antisym(z - zs_c)
                z = z + eps * drift
        weighted = weighted + torch.conj(torch.tensor(phase, device=z.device)) * (z - zs_c)
        max_dev = max(max_dev, (z - zs_c).abs().max().item())
    a = -(weighted / (N * r)).real
    return a.detach(), max_dev


def holo_a_select(blk, zs, xin, y, N, r, T2max, eps, K=10, exit_mult=5.0, corr_every=1):
    """Adaptive-T2 by hindsight selection: run nudged phases in lockstep to T2max, snapshot the
    contrast a_t every K steps, return the snapshot with the smallest increment (most settled).
    Never worse than short fixed T2 (the settled snapshot exists early too); captures the long-T2
    win (cos up to ~0.99) when the nudged dynamics are stable; early-exits only on clear blowup —
    judging by increments of the QUANTITY OF INTEREST, not step sizes, so non-normal transient
    growth cannot trigger a premature stop.

    Phases are beta_k = r e^{2 pi i k / N}. Each takes explicit Euler steps of size eps under
    cforce - beta_k * cgrad_ce minus the AEP term (J - J^T)(z_k - zs). That term is recomputed
    every `corr_every` steps and reused in between. Snapshots are taken every K steps and at T2max.
    The loop stops on a non-finite snapshot, or when an increment exceeds
    exit_mult * (smallest increment so far) after at least 3K steps.

    Returns (a, t): the selected contrast (detached, real, shaped like zs) and the step it was
    taken at. If no increment was ever measured (a single snapshot, or blow-up before a second
    one), the last finite snapshot is returned. If there is none, the contrast of the current
    (possibly non-finite) state is returned, as holo_a_select2 does. In both fallback cases
    t = T2max.
    """
    zsc, xc = zs.to(CDT), xin.to(CDT)
    ph = [complex(math.cos(2 * math.pi * k / N), math.sin(2 * math.pi * k / N)) for k in range(N)]
    ph_conj = [torch.conj(torch.tensor(p, device=zs.device)) for p in ph]   # constant per call
    Z = [zsc.clone() for _ in range(N)]
    corr = [None] * N

    def contrast():
        # a = -Re[(1/(N r)) sum_k e^{-i phi_k} (z_k - z*)]
        acc = sum(pc * (zk - zsc) for pc, zk in zip(ph_conj, Z))
        return -(acc / (N * r)).real

    a_prev = a_best = None
    inc_min, t_best = float('inf'), 0
    for t in range(1, T2max + 1):
        for k in range(N):
            with torch.no_grad():
                f = cforce(blk, Z[k], xc) - (r * ph[k]) * cgrad_ce(blk, Z[k], y)
                if corr[k] is None or (t - 1) % corr_every == 0:   # v moves ~eps/step: stale corr is cheap
                    v = Z[k] - zsc
                    Jv = torch.autograd.functional.jvp(blk.nc_force, zs, v.real.contiguous())[1] + 0j
                    JTv = torch.autograd.functional.vjp(blk.nc_force, zs, v.real.contiguous())[1] + 0j
                    if v.imag.abs().max() > 1e-9:                 # real-axis phases skip the imag solves
                        Jv = Jv + 1j * torch.autograd.functional.jvp(blk.nc_force, zs, v.imag.contiguous())[1]
                        JTv = JTv + 1j * torch.autograd.functional.vjp(blk.nc_force, zs, v.imag.contiguous())[1]
                    corr[k] = Jv - JTv
                Z[k] = Z[k] + eps * (f - corr[k])
        if t % K == 0 or t == T2max:
            a_t = contrast()
            if not torch.isfinite(a_t).all():
                break
            if a_prev is not None:
                inc = (a_t - a_prev).norm().item()
                if inc < inc_min:
                    inc_min, a_best, t_best = inc, a_t, t
                elif inc > exit_mult * inc_min and t >= 3 * K:
                    break
            a_prev = a_t
    if a_best is None:
        # No increment was measured. Previously a_prev could also be None here (the first
        # snapshot was non-finite), which crashed on .detach().
        a_best = a_prev if a_prev is not None else contrast()
        t_best = T2max
    return a_best.detach(), t_best


def rforce(blk, z, xin):
    """Real-axis twin of cforce: the thick block's force with tanh-GELU and no clamps.

    Uses F.layer_norm, a softmax with masked (~blk.cmask) positions set to -inf, and cgelu. If
    blk.qknorm is set, q and k are RMS-normalized. If blk.fnoise > 0, the attention+MLP term is
    multiplied by (1 + fnoise * standard-normal noise). Returns -(z - xin) + nc - blk.c * z.
    """
    C, H, dh, T = blk.C, blk.H, blk.dh, blk.T
    nb = z.size(0)

    def heads(x, W):                                  # project, then split into H heads of dh
        return (x @ W).view(nb, T, H, dh).transpose(1, 2)

    ln_attn = F.layer_norm(z, (C,), blk.ln1g, blk.ln1b)
    ln_mlp = F.layer_norm(z, (C,), blk.ln2g, blk.ln2b)
    q = heads(ln_attn, blk.WQ)
    k = heads(ln_attn, blk.WK)
    v = heads(ln_attn, blk.WV)
    if getattr(blk, 'qknorm', False):                 # match attn()'s q/k RMSNorm in the nudge force
        q = q * torch.rsqrt(q.pow(2).mean(-1, keepdim=True) + 1e-6)
        k = k * torch.rsqrt(k.pow(2).mean(-1, keepdim=True) + 1e-6)
    scores = q @ k.transpose(-2, -1) / math.sqrt(dh)
    probs = torch.softmax(scores.masked_fill(~blk.cmask, float('-inf')), -1)
    attn_out = (probs @ v).transpose(1, 2).reshape(nb, T, C) @ blk.WO
    mlp_out = cgelu(ln_mlp @ blk.fc + blk.fcb) @ blk.pj + blk.pjb
    nc = attn_out + mlp_out
    noise = getattr(blk, 'fnoise', 0.0)
    if noise > 0:
        nc = nc * (1 + noise * torch.randn_like(nc))
    return -(z - xin) + nc - blk.c * z


def rgrad_ce(blk, z, y, denom=None):
    """Real gradient of the cross-entropy with respect to the state z.

    Returns (softmax(z @ Wh) - onehot(y)) @ Wh^T / n, where n = denom if it is truthy, else
    y.numel(). The denom argument lets a phase-batched caller keep the normalization of the
    original, un-stacked batch.
    """
    n = denom or y.numel()
    probs = torch.softmax(z @ blk.Wh, -1)
    onehot = F.one_hot(y, probs.size(-1)).to(z.dtype)
    return (probs - onehot) @ blk.Wh.t() / n


def holo_a_select2(blk, zs, xin, y, r, T2max, eps, K=10, exit_mult=5.0, li=0):
    """N=2 production fast path — mathematically equivalent to holo_a_select(N=2), up to rforce's
    optional fnoise and the absence of corr_every. Both phases are real, so they are run
    PHASE-BATCHED (stack +r/-r along batch) with real tensors and torch.func forward-mode jvp.
    Halves autograd calls and skips complex arithmetic.

    Layout: rows [:B] are nudged with beta = +r and rows [B:] with beta = -r. Each takes explicit
    Euler steps of size eps under rforce - beta * rgrad_ce (CE normalized by the ORIGINAL batch
    size) minus the AEP term (J - J^T)(z - zs), with J the Jacobian of blk.nc_force at zs.
    The contrast is a_t = (z_- - z_+) / (2 r).

    li>0 enables LOCK-IN INTEGRATION mode for noisy (hardware) physics: run the full T2max,
    EMA the contrast a_t every step with time-constant li — the homodyne integrator that divides
    persistent per-pass noise by sqrt(window). The EMA update is a += (a_t - a) / li, so li < 1
    over-corrects. The hindsight-selection mode (li=0) is for clean physics, where a single
    most-settled snapshot is optimal.

    Returns (a, t):
      lock-in mode   — the EMA over steps t > T2max // 3 (or the current contrast if none was
                       accumulated) and t = T2max;
      selection mode — the snapshot (taken every K steps and at T2max) with the smallest
                       increment, and its step. Early exit follows the same rule as
                       holo_a_select. If no increment was measured, the last finite snapshot
                       (else the current contrast) is returned with t = T2max.
    """
    import torch.func as tf
    B = zs.size(0)
    Z = torch.cat([zs, zs], 0)                                  # [+r phase | -r phase]
    X2 = torch.cat([xin, xin], 0)
    y2 = torch.cat([y, y], 0)
    sg = torch.cat([torch.full((B, 1, 1), r, device=zs.device),
                    torch.full((B, 1, 1), -r, device=zs.device)], 0)
    zs2 = torch.cat([zs, zs], 0)
    fnc = lambda zz: blk.nc_force(zz)
    a_prev = a_best = a_ema = None
    inc_min, t_best = float('inf'), 0
    for t in range(1, T2max + 1):
        with torch.no_grad():
            f = rforce(blk, Z, X2) - sg * rgrad_ce(blk, Z, y2, denom=y.numel())   # CE mean over the ORIGINAL batch
            v = (Z - zs2).contiguous()
            _, Jv = tf.jvp(fnc, (zs2,), (v,))
            JTv = tf.vjp(fnc, zs2)[1](v)[0]
            Z = Z + eps * (f - (Jv - JTv))
        if li > 0:                                              # lock-in integration (noisy physics)
            a_t = (Z[B:] - Z[:B]) / (2 * r)
            if not torch.isfinite(a_t).all():
                break
            if t > T2max // 3:                                  # let phases develop, then integrate
                a_ema = a_t if a_ema is None else a_ema + (a_t - a_ema) / li
            continue
        if t % K == 0 or t == T2max:
            a_t = (Z[B:] - Z[:B]) / (2 * r)                     # (z_- - z_+)/2r
            if not torch.isfinite(a_t).all():
                break
            if a_prev is not None:
                inc = (a_t - a_prev).norm().item()
                if inc < inc_min:
                    inc_min, a_best, t_best = inc, a_t, t
                elif inc > exit_mult * inc_min and t >= 3 * K:
                    break
            a_prev = a_t
    if li > 0:
        if a_ema is None:
            a_ema = (Z[B:] - Z[:B]) / (2 * r)
        return a_ema.detach(), T2max
    if a_best is None:
        a_best = a_prev if a_prev is not None else (Z[B:] - Z[:B]) / (2 * r)
        t_best = T2max
    return a_best.detach(), t_best


def holo_a_track(blk, zs, xin, y, r, T2max, eps, K=10, exit_mult=5.0):
    """Common-mode-tracking AEP: linearize the antisymmetric correction at the instantaneous
    common mode of the two phases — exact transposed differential dynamics, loose-tolerant,
    no compounding linearization error.

    Phase-batched N=2 layout: rows [:B] are nudged with +r and rows [B:] with -r. Each step's AEP
    term (J - J^T) v uses J = Jacobian of blk.nc_force at zbar = (z_+ + z_-) / 2, with
    v = z - zbar. If blk.nbrake = kappa > 0, an extra -kappa (z - zs) pull is added to the force.
    Snapshots are taken every K steps and at T2max, and the most settled one is returned as
    (a, t) with a = (z_- - z_+) / (2 r). Early exit is on a non-finite snapshot, or on an
    increment above exit_mult * min after 3K steps. If no increment was measured, the last finite
    snapshot (else the current contrast) is returned with t = T2max.
    """
    import torch.func as tf

    def stack2(t):                                   # duplicate along batch: [+r phase | -r phase]
        return torch.cat([t, t], 0)

    def nc(zz):
        return blk.nc_force(zz)

    B = zs.size(0)
    Z, X2, y2 = stack2(zs), stack2(xin), stack2(y)
    sign = torch.cat([torch.full((B, 1, 1), r, device=zs.device),
                      torch.full((B, 1, 1), -r, device=zs.device)], 0)
    zs_pair = stack2(zs)
    kappa = getattr(blk, 'nbrake', 0.0)
    a_prev = a_best = None
    inc_min, t_best = float('inf'), 0

    def contrast():
        return (Z[B:] - Z[:B]) / (2 * r)

    for t in range(1, T2max + 1):
        with torch.no_grad():
            zbar = stack2(0.5 * (Z[:B] + Z[B:]))
            drift = rforce(blk, Z, X2) - sign * rgrad_ce(blk, Z, y2, denom=y.numel())
            if kappa > 0:                            # measurement brake: Tikhonov-regularized adjoint
                drift = drift - kappa * (Z - zs_pair)
            dv = (Z - zbar).contiguous()
            _, Jv = tf.jvp(nc, (zbar,), (dv,))
            JTv = tf.vjp(nc, zbar)[1](dv)[0]
            Z = Z + eps * (drift - (Jv - JTv))
        if not (t % K == 0 or t == T2max):
            continue
        a_t = contrast()
        if not torch.isfinite(a_t).all():
            break
        if a_prev is not None:
            inc = (a_t - a_prev).norm().item()
            if inc < inc_min:
                inc_min, a_best, t_best = inc, a_t, t
            elif inc > exit_mult * inc_min and t >= 3 * K:
                break
        a_prev = a_t
    if a_best is None:
        a_best = a_prev if a_prev is not None else contrast()
        t_best = T2max
    return a_best.detach(), t_best


def holo_a_lockin(blk, zs, xin, y, r, P, ncyc, eps):
    """True oscillatory EP / lock-in estimator (Laborieux–Zenke taken literally) — the
    noisy-physics form: ONE trajectory, sinusoidal nudge beta(t)=r·sin(2πt/P), in-phase
    demodulation over ncyc periods (first period discarded as transient). Single-trajectory =>
    common-mode noise cancels in the quadrature; v=z−z* stays O(r·response) so the AEP
    linearization never leaves its window; noise admitted only in the demodulation band.

    P (steps per period) and ncyc must be ints, since they size the step loop. Each step is an
    explicit Euler step of size eps under rforce - beta(t) * rgrad_ce minus the AEP term
    (J - J^T)(z - zs), with J the Jacobian of blk.nc_force at zs.

    Returns (a, T) with a = -<(z - zs) s> / (<s^2> r), detached and shaped like zs, and
    T = P * (ncyc + 1). Returns (None, t) if the state becomes non-finite at step t.
    """
    z = zs.clone()
    accI = torch.zeros_like(zs)
    sI = 0.0
    T = P * (ncyc + 1)
    for t in range(1, T + 1):
        s = math.sin(2 * math.pi * t / P)
        with torch.no_grad():
            f = rforce(blk, z, xin) - (r * s) * rgrad_ce(blk, z, y, denom=y.numel())
            v = (z - zs).contiguous()
            Jv = torch.autograd.functional.jvp(blk.nc_force, zs, v)[1]
            JTv = torch.autograd.functional.vjp(blk.nc_force, zs, v)[1]
            z = z + eps * (f - (Jv - JTv))
            if not torch.isfinite(z).all():
                return None, t
        if t > P:                                  # demodulate after the transient period
            # demodulate the deviation, not z itself: sum(s) over whole periods is zero only up
            # to rounding, which would otherwise leak the large DC term zs into the I channel
            accI = accI + (z - zs) * s
            sI += s * s
    return (-(accI / (sI + 1e-12)) / r).detach(), T


def holo_grads(blk, idx, y, T1, T2, eps, N, r):
    """Full holomorphic-EP gradient for block params (same VF readout as ep_step).

    Relaxes the free phase for T1 steps (step size eps) from the embedding of idx, then estimates
    the contrast a with holo_a (N phases of radius r, T2 nudged steps). Finally it reads out
    grad_theta sum(a * blk.force(zs; theta)) for each parameter in blk.block.

    Returns (grads, res, mg):
      grads — dict id(param) -> gradient (None for params the force does not use);
      res   — relative one-step residual |relax(zs, 1 step) - zs| / |zs| of the free phase;
      mg    — holo_a's max|z - z*| over the nudged phases.
    """
    xin0 = blk.embed(idx).detach()
    zs = relax(blk, xin0.clone(), xin0, T1, eps)
    res = (relax(blk, zs, xin0, 1, eps) - zs).norm().item() / (zs.norm().item() + 1e-9)
    a, mg = holo_a(blk, zs, xin0, y, N, r, T2, eps)
    with torch.enable_grad():
        xin = blk.embed(idx)
        f = blk.force(zs.detach(), xin, cg=True)
        gblk = torch.autograd.grad((a * f).sum(), blk.block, allow_unused=True)
    return {id(p): g for p, g in zip(blk.block, gblk)}, res, mg


if __name__ == '__main__':
    # Probe: compare plain-EP and holomorphic-EP gradient estimates against a BPTT reference,
    # reported as per-parameter-group cosine similarities.
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.manual_seed(0)
    B, T, C, H = 16, 64, 128, 4
    blk = EQBlock(C, H, 256, T, attn_mode='thick')
    with torch.no_grad():
        for p, w in zip(blk.allp, torch.load('/tmp/lt_ep/probe_w.pt')):
            p.copy_(w.to(dev))
    print("loaded probe weights (300-step BPTT, thick, c=1)", flush=True)

    groups = {'all': blk.block,
              'attn': [blk.WQ, blk.WK, blk.WV, blk.WO],
              'ffn': [blk.fc, blk.fcb, blk.pj, blk.pjb],
              'ln': [blk.ln1g, blk.ln1b, blk.ln2g, blk.ln2b],
              'emb': [blk.tok, blk.pos]}

    def cos(ga, gb, ps):
        """Cosine similarity of two id(param)->grad dicts over ps; params whose gradient is
        missing or None in either dict are skipped; nan if none remain."""
        shared = [p for p in ps if ga.get(id(p)) is not None and gb.get(id(p)) is not None]
        if not shared:
            return float('nan')
        va, vb = (torch.cat([g[id(p)].reshape(-1) for p in shared]) for g in (ga, gb))
        return (va @ vb / (va.norm() * vb.norm() + 1e-12)).item()

    def report(label, res_val, mg_txt, grads, ref):
        """Print one table row: label, residual, max|dz| text, then one cosine per group."""
        cols = " ".join(f"{cos(grads, ref, ps):>6.3f}" for ps in groups.values())
        print(f"{label:>22} {res_val:>9.1e} {mg_txt:>8} " + cols, flush=True)

    hdr = f"{'estimator':>22} {'res':>9} {'max|dz|':>8} " + " ".join(f"{k:>6}" for k in groups)
    for bi in range(3):
        idx, y = get_batch('train', B, T)
        ref = bptt_step(blk, idx, y, 400, 0.1)
        print(("\n" if bi else "") + hdr, flush=True)
        for T1 in (150, 400):
            gep, res = ep_step(blk, idx, y, T1, 20, 0.1, 0.02, 0.0)
            report(f'plain ep b=.02 T1={T1}', res, '--', gep, ref)
            gep2, _ = ep_step(blk, idx, y, T1, 20, 0.1, 0.1, 0.0)
            report(f'plain ep b=.10 T1={T1}', res, '--', gep2, ref)
            for N, r in ((2, 0.02), (4, 0.05), (4, 0.1), (4, 0.2), (8, 0.2)):
                gh, res2, mg = holo_grads(blk, idx, y, T1, 20, 0.1, N, r)
                report(f'holo N={N} r={r} T1={T1}', res2, f'{mg:.2f}', gh, ref)