diff options
Diffstat (limited to 'scripts')
| -rw-r--r-- | scripts/aep_attention.py | 157 | ||||
| -rw-r--r-- | scripts/aep_characterize.py | 157 | ||||
| -rw-r--r-- | scripts/aep_contractive.py | 52 | ||||
| -rw-r--r-- | scripts/aep_contractive2.py | 41 | ||||
| -rw-r--r-- | scripts/aep_depth.py | 30 | ||||
| -rw-r--r-- | scripts/aep_option1.py | 115 | ||||
| -rw-r--r-- | scripts/aep_projected.py | 125 | ||||
| -rw-r--r-- | scripts/ask_fugu.py | 24 | ||||
| -rw-r--r-- | scripts/bp_transformer.py | 141 | ||||
| -rw-r--r-- | scripts/cet_aep.py | 272 | ||||
| -rw-r--r-- | scripts/cet_mvp.py | 372 | ||||
| -rw-r--r-- | scripts/plot_jr_cmp.py | 20 |
12 files changed, 1506 insertions, 0 deletions
diff --git a/scripts/aep_attention.py b/scripts/aep_attention.py new file mode 100644 index 0000000..868cb05 --- /dev/null +++ b/scripts/aep_attention.py @@ -0,0 +1,157 @@ +""" +CET + AEP: does Asymmetric EP let us train a *non-conservative* attention? + +CET's energy attention is conservative by construction (grad of a scalar LogSumExp +energy -> symmetric Jacobian -> vanilla EP exact). Real transformer attention +softmax(QK^T)V with an INDEPENDENT value V is NOT the gradient of any scalar -> +non-conservative Jacobian -> vanilla EP gives a BIASED gradient. + +AEP (Scellier et al., "EP for Non-Conservative Systems", arXiv:2602.03670) adds a +nudged-phase correction -2 A_J(x*)(x - x*), A_J = (J - J^T)/2 at the free +equilibrium x*. Linearised, this turns the nudged Jacobian J into J^T -- exactly +the adjoint that vanilla EP fails to realise when J != J^T. + +We compare three parameter-gradient estimators vs ground-truth BPTT in two regimes: + cons : F = -x + b + tanh(xS) S^T (manifestly grad of a scalar -> J symmetric) [control] + noncons : F = -x + b + W_O softmax(QK^T/sqrt d)(x W_V) (real attention) [the test] + +Vector-field param-gradient (valid for non-gradient F): + dL/dtheta = < a, dF/dtheta(x*) >, a = (x_{-b} - x_{+b}) / (2 beta). +""" +import torch, torch.nn.functional as F, math +torch.manual_seed(0) +B, N, D, H = 8, 8, 32, 4 +dh = D // H +dev = 'cuda' if torch.cuda.is_available() else 'cpu' + + +def mk_params(regime): + g = torch.Generator(device=dev).manual_seed(1) + s = 1.0 / math.sqrt(D) + if regime == 'noncons': + P = dict(WQ=torch.randn(D, D, generator=g, device=dev) * s, + WK=torch.randn(D, D, generator=g, device=dev) * s, + WV=torch.randn(D, D, generator=g, device=dev) * s, + WO=torch.randn(D, D, generator=g, device=dev) * s, + b=torch.zeros(D, device=dev)) + else: + P = dict(S=torch.randn(D, D, generator=g, device=dev) * s, + b=torch.zeros(D, device=dev)) + for v in P.values(): + v.requires_grad_(True) + return P + + +def _heads(t): + return t.view(B, N, H, dh).transpose(1, 2) + + +def F_noncons(x, P): # real (non-conservative) attention force + q, k, v = _heads(x @ P['WQ']), _heads(x @ P['WK']), _heads(x @ P['WV']) + A = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(dh), dim=-1) + o = (A @ v).transpose(1, 2).reshape(B, N, D) @ P['WO'] + return -x + P['b'] + o + + +def F_cons(x, P): # F = -grad E, E = .5|x|^2 -<b,x> -sum logcosh(xS) + return -x + P['b'] + torch.tanh(x @ P['S']) @ P['S'].t() + + +def cost(x, R, tgt): + return 0.5 * ((x.reshape(B, -1) @ R - tgt) ** 2).sum() / B + + +def dcost(x, R, tgt): + x = x.detach().requires_grad_(True) + with torch.enable_grad(): + g, = torch.autograd.grad(cost(x, R, tgt), x) + return g + + +def relax(Ffn, P, x0, steps, eps, extra=None): + x = x0.clone() + for _ in range(steps): + with torch.no_grad(): + f = Ffn(x, P) + if extra is not None: + f = f + extra(x) + x = x + eps * f + return x.detach() + + +def AJ_apply(Ffn, P, xstar, v): # 0.5 (J v - J^T v) at xstar + with torch.enable_grad(): + fx = lambda z: Ffn(z, P) + Jv = torch.autograd.functional.jvp(fx, xstar, v)[1] + JTv = torch.autograd.functional.vjp(fx, xstar, v)[1] + return 0.5 * (Jv - JTv) + + +def ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep): + xstar = relax(Ffn, P, x0, T1, eps) + def nudged(sign): + ex = (lambda x: -2.0 * AJ_apply(Ffn, P, xstar, x - xstar)) if aep else None + fn = lambda x: Ffn(x, P) - sign * beta * dcost(x, R, tgt) + return relax(fn, None, xstar, T2, eps, extra=ex) if False else _nud(Ffn, P, xstar, R, tgt, T2, eps, sign, beta, ex) + xp, xm = nudged(+1.0), nudged(-1.0) + a = ((xm - xp) / (2.0 * beta)).detach() + xs = xstar.detach() + with torch.enable_grad(): + s = (a * Ffn(xs, P)).sum() + grads = torch.autograd.grad(s, list(P.values()), allow_unused=True) + return grads + + +def _nud(Ffn, P, xstar, R, tgt, T2, eps, sign, beta, ex): + x = xstar.clone() + for _ in range(T2): + with torch.no_grad(): + f = Ffn(x, P) - sign * beta * dcost(x, R, tgt) + if ex is not None: + f = f + ex(x) + x = x + eps * f + return x.detach() + + +def bptt_grad(Ffn, P, x0, R, tgt, T1, eps): + x = x0.clone() + for _ in range(T1): + x = x + eps * Ffn(x, P) # full graph + return torch.autograd.grad(cost(x, R, tgt), list(P.values()), allow_unused=True) + + +def cosine(ga, gb): + fa = torch.cat([g.flatten() for g in ga]) + fb = torch.cat([g.flatten() for g in gb]) + return F.cosine_similarity(fa, fb, dim=0).item() + + +def run(regime, T1=120, T2=30, eps=0.2, beta=0.02): + P = mk_params(regime) + Ffn = F_cons if regime == 'cons' else F_noncons + g = torch.Generator(device=dev).manual_seed(7) + x0 = torch.randn(B, N, D, generator=g, device=dev) * 0.1 + R = torch.randn(N * D, 16, generator=g, device=dev) / math.sqrt(N * D) + tgt = torch.randn(B, 16, generator=g, device=dev) + + xs = relax(Ffn, P, x0, T1, eps) + res = ((relax(Ffn, P, xs, 1, eps) - xs).norm() / (xs.norm() + 1e-8)).item() + v = torch.randn_like(xs) + aj = AJ_apply(Ffn, P, xs, v) + jv = torch.autograd.functional.jvp(lambda z: Ffn(z, P), xs, v)[1] + asym = (aj.norm() / (jv.norm() + 1e-8)).item() + + gb = bptt_grad(Ffn, P, x0, R, tgt, T1, eps) + gn = ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep=False) + ga = ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep=True) + + print(f"\n===== regime={regime} (residual@x*={res:.1e}) =====") + print(f" Jacobian antisymmetry ||A_J v||/||J v|| = {asym:.3f} " + f"({'~conservative' if asym < 0.05 else 'NON-conservative'})") + print(f" cosine(naive_EP, BPTT) = {cosine(gn, gb):+.4f}") + print(f" cosine( AEP , BPTT) = {cosine(ga, gb):+.4f}") + + +if __name__ == '__main__': + run('cons') + run('noncons') diff --git a/scripts/aep_characterize.py b/scripts/aep_characterize.py new file mode 100644 index 0000000..3e642c1 --- /dev/null +++ b/scripts/aep_characterize.py @@ -0,0 +1,157 @@ +""" +Characterize AEP (non-conservative EP) on CET's attention, before porting to the LM. + +Controlled knob: attention scale s in force_z = -dE_rest/dz + s * RealAttn(z). + s=0 -> pure conservative reconstruction (A_J=0; EP exact) + s up -> attention dominates the force -> more non-conservative -> naive EP biased. +Metric: cosine(EP-grad, BPTT-grad) on the attention params {WQ,WK,WV,WO} (the global +cosine is diluted by the dominant conservative params, so we look at attention itself). +The AEP correction is -s*(J_A v) on z, J_A = antisym Jacobian of RealAttn at the free eq. + +Sweeps: (1) s [non-conservativeness], (2) beta [nudge size], (3) T2 [nudge steps], + (4) T1 [free-phase convergence]. Plus: free-eq identical naive vs AEP, and cost. +""" +import argparse, math, time, torch, torch.nn.functional as F +from cet_mvp import make_patch_mask, masked_cost, get_loaders +from cet_aep import CETReal + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +ATTN = ('WQ', 'WK', 'WV', 'WO') + + +def force(model, xbar, z, y, s): + z = z.requires_grad_(True); y = y.requires_grad_(True) + gz, gy = torch.autograd.grad(model.E_rest(xbar, z, y), [z, y], create_graph=True) + return -gz + s * model.real_attn(z), -gy + + +def relax_free(model, xbar, z, y, s, T1, eps): + for _ in range(T1): + with torch.enable_grad(): + fz, fy = force(model, xbar, z, y, s) + fz, fy = fz.detach(), fy.detach() + with torch.no_grad(): + z, y = z + eps * fz, y + eps * fy + return z.detach(), y.detach() + + +def relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, sign, aep): + z, y = zs.clone(), ys.clone() + for _ in range(T2): + with torch.enable_grad(): + fz, fy = force(model, xbar, z, y, s) + fz, fy = fz.detach(), fy.detach() + yy = y.detach().requires_grad_(True) + gy, = torch.autograd.grad(masked_cost(yy, X, M), yy) + fy = fy - sign * beta * gy + if aep: + v = (z - zs).detach() + Jv = torch.autograd.functional.jvp(model.real_attn, zs, v)[1] + JTv = torch.autograd.functional.vjp(model.real_attn, zs, v)[1] + fz = fz - s * (Jv - JTv) # -2 * s * 0.5 (J v - J^T v) + with torch.no_grad(): + z, y = z + eps * fz, y + eps * fy + return z.detach(), y.detach() + + +def vf_grad(model, xbar, s, T1, T2, eps, beta, aep): + zs, ys = relax_free(model, xbar, *model.init_state(xbar), s, T1, eps) + zp, yp = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, +1, aep) + zm, ym = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, -1, aep) + az, ay = ((zm - zp) / (2 * beta)).detach(), ((ym - yp) / (2 * beta)).detach() + with torch.enable_grad(): + fz, fy = force(model, xbar, zs.detach(), ys.detach(), s) + g = torch.autograd.grad((az * fz).sum() + (ay * fy).sum(), + list(model.parameters()), allow_unused=True) + return zs, g + + +def bptt_grad(model, xbar, s, T1, eps): + z, y = model.init_state(xbar); z, y = z.requires_grad_(True), y.requires_grad_(True) + for _ in range(T1): + fz, fy = force(model, xbar, z, y, s) + z, y = z + eps * fz, y + eps * fy + return torch.autograd.grad(masked_cost(y, X, M) / M.sum(), + list(model.parameters()), allow_unused=True) + + +def attn_cos(g, gb, names): + cs = [F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() + for n, a, b in zip(names, g, gb) if n in ATTN and a is not None and b is not None] + return sum(cs) / len(cs) + + +def global_cos(g, gb): + a = torch.cat([x.flatten() for x in g if x is not None]) + b = torch.cat([x.flatten() for x, y in zip(g, gb) if x is not None and y is not None]) + return F.cosine_similarity(a, b, dim=0).item() + + +def measure(model, names, s, T1, T2, eps, beta): + gb = bptt_grad(model, XBAR, s, T1, eps) + zsn, gn = vf_grad(model, XBAR, s, T1, T2, eps, beta, aep=False) + zsa, ga = vf_grad(model, XBAR, s, T1, T2, eps, beta, aep=True) + eq_id = (zsn - zsa).norm().item() / (zsn.norm().item() + 1e-9) # free eq identical? + return dict(naive=attn_cos(gn, gb, names), aep=attn_cos(ga, gb, names), + gnaive=global_cos(gn, gb), gaep=global_cos(ga, gb), eq_id=eq_id) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--dataset', default='fashionmnist') + ap.add_argument('--img', type=int, default=28); ap.add_argument('--ch', type=int, default=1) + ap.add_argument('--patch', type=int, default=7); ap.add_argument('--stride', type=int, default=7) + ap.add_argument('--batch', type=int, default=32) + cfg = ap.parse_args() + torch.manual_seed(0) + model = CETReal(cfg.img, cfg.ch, cfg.patch, cfg.stride, D=64, heads=4, dh=16, mem=128).to(dev) + names = [n for n, _ in model.named_parameters()] + trl, _ = get_loaders(cfg.batch, dataset=cfg.dataset) + global X, M, XBAR + X, _ = next(iter(trl)); X = X.to(dev) + M = make_patch_mask(X.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev) + XBAR = X * (1 - M) + + # intrinsic non-conservativeness of the attention map itself + zs, _ = relax_free(model, XBAR, *model.init_state(XBAR), 1.0, 120, 0.2) + v = torch.randn_like(zs) + Jv = torch.autograd.functional.jvp(model.real_attn, zs, v)[1] + JTv = torch.autograd.functional.vjp(model.real_attn, zs, v)[1] + print(f"intrinsic attention-map antisymmetry ||A_J v||/||J v|| = " + f"{(0.5*(Jv-JTv)).norm().item()/(Jv.norm().item()+1e-9):.3f}") + + base = dict(T1=120, T2=20, eps=0.2, beta=0.02) + print("\n[1] ATTENTION SCALE s (s=0 conservative -> larger = more non-conservative)") + print(f"{'s':>6} | {'naive(attn)':>11} {'AEP(attn)':>10} | {'naive(glob)':>11} {'AEP(glob)':>10} | free-eq id") + for s in [0.25, 0.5, 1.0, 2.0, 4.0, 8.0]: + r = measure(model, names, s, base['T1'], base['T2'], base['eps'], base['beta']) + print(f"{s:6.2f} | {r['naive']:>11.3f} {r['aep']:>10.3f} | {r['gnaive']:>11.4f} {r['gaep']:>10.4f} | {r['eq_id']:.1e}") + + print("\n[2] NUDGE STRENGTH beta (s=2, T2=20)") + print(f"{'beta':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}") + for beta in [0.005, 0.01, 0.02, 0.05, 0.1, 0.2]: + r = measure(model, names, 2.0, 120, 20, 0.2, beta) + print(f"{beta:6.3f} | {r['naive']:>11.3f} {r['aep']:>10.3f}") + + print("\n[3] NUDGE STEPS T2 (s=2, beta=0.02)") + print(f"{'T2':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}") + for T2 in [3, 5, 10, 20, 40]: + r = measure(model, names, 2.0, 120, T2, 0.2, 0.02) + print(f"{T2:6d} | {r['naive']:>11.3f} {r['aep']:>10.3f}") + + print("\n[4] FREE-PHASE STEPS T1 (s=2; AEP uses A_J at the free eq)") + print(f"{'T1':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}") + for T1 in [20, 40, 80, 120, 200]: + r = measure(model, names, 2.0, T1, 20, 0.2, 0.02) + print(f"{T1:6d} | {r['naive']:>11.3f} {r['aep']:>10.3f}") + + print("\n[5] COST (s=2, T1=120, T2=20)") + t = time.time(); [vf_grad(model, XBAR, 2.0, 120, 20, 0.2, 0.02, aep=False) for _ in range(3)] + torch.cuda.synchronize() if dev == 'cuda' else None; tn = (time.time()-t)/3 + t = time.time(); [vf_grad(model, XBAR, 2.0, 120, 20, 0.2, 0.02, aep=True) for _ in range(3)] + torch.cuda.synchronize() if dev == 'cuda' else None; ta = (time.time()-t)/3 + print(f" naive {tn*1000:.0f} ms/grad AEP {ta*1000:.0f} ms/grad overhead {ta/tn:.2f}x") + + +if __name__ == '__main__': + main() diff --git a/scripts/aep_contractive.py b/scripts/aep_contractive.py new file mode 100644 index 0000000..670feae --- /dev/null +++ b/scripts/aep_contractive.py @@ -0,0 +1,52 @@ +"""F: make REAL attention EP-able by damping it into a contraction (keep it non-conservative). + +Attention term in the force becomes s*(attn(z) - c*z). The -c*z is damping that grows with s, +pushing Re(eig(J_F)) < 0 (a stable fixed point) WITHOUT symmetrizing the Jacobian (the antisymmetric +part is unchanged, so it stays non-conservative -> AEP still needed AND now applicable). + +We sweep (s, c) and report, using the validated projected-adjoint (option 1): + fwd resid : does a stable fixed point exist? (small = yes) + adj cos : projected-adjoint gradient fidelity vs BPTT on attention params +Expected: c=0 breaks at high s (no fixed point, as before); c>=1 keeps resid small + fidelity high. +""" +import torch, aep_option1 as O +from cet_aep import CETReal +from cet_mvp import token_norm, make_patch_mask, get_loaders + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +torch.manual_seed(0) +model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev) +names = [n for n, _ in model.named_parameters()] +orig_attn = model.real_attn # original (undamped) attention + +trl, _ = get_loaders(32, dataset='fashionmnist') +X, _ = next(iter(trl)); X = X.to(dev) +M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev) +XBAR = X * (1 - M) +O.X, O.M = X, M # masked_cost in option1 uses these globals + + +def set_damp(c): + model.real_attn = orig_attn if c == 0 else (lambda z: orig_attn(z) - c * z) + + +def resid(s, T1, eps=0.2): + zs, ys = O.relax_free(model, XBAR, *model.init_state(XBAR), s, T1, eps) + with torch.enable_grad(): + zr, yr = zs.requires_grad_(True), ys.requires_grad_(True) + fz, _ = O.force(model, XBAR, zr, yr, s) + zn = token_norm(zs + eps * fz.detach()) + return ((zn - zs).norm() / (zs.norm() + 1e-9)).item() + + +print("Contractive (damped) non-conservative attention — does it restore a fixed point + EP fidelity?") +print(f"{'s':>5} {'c':>4} | {'fwd resid':>9} {'adj cos(attn)':>13} {'glob':>7}") +for s in [1.0, 2.0, 4.0, 8.0]: + for c in [0.0, 1.0, 2.0]: + set_damp(c) + r = resid(s, 250) + gb = O.bptt_grad(model, XBAR, s, 250, 0.2) + ga = O.adjoint_grad(model, XBAR, s, 250, 0.2, 250) + a, g = O.cosines(ga, gb, names) + print(f"{s:>5.1f} {c:>4.1f} | {r:>9.2e} {a:>13.3f} {g:>7.3f}") + print() diff --git a/scripts/aep_contractive2.py b/scripts/aep_contractive2.py new file mode 100644 index 0000000..f1d38a8 --- /dev/null +++ b/scripts/aep_contractive2.py @@ -0,0 +1,41 @@ +"""F (v2): make real attention EP-able via UNCONSTRAINED dynamics + damping (no projection). + +The projection (C/F-v1) fought radial damping and broke the VF. Drop it: unconstrained AEP +already has clean theory (0.99 fidelity) but diverges at high s for lack of confinement. +Add damping that scales with s: attention term = s*(attn(z) - c*z). Fixed point +z* = [s*attn(z*) + enc]/(4 + s*c) -> attention still sets the direction, but -(4+sc)z makes +it a contraction (so a stable fixed point exists). Small eps needed (the linear part is stiff). + +Reuses aep_characterize's UNCONSTRAINED, AEP-validated machinery; monkeypatches attention to the +damped version. Reports naive vs AEP attention-param cosine vs BPTT, and whether it stayed finite. +""" +import math, torch, aep_characterize as A +from cet_aep import CETReal +from cet_mvp import make_patch_mask, get_loaders + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +torch.manual_seed(0) +model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev) +names = [n for n, _ in model.named_parameters()] +orig = model.real_attn +trl, _ = get_loaders(32, dataset='fashionmnist') +X, _ = next(iter(trl)); X = X.to(dev) +M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev) +A.X, A.M, A.XBAR = X, M, X * (1 - M) + + +def setc(c): + model.real_attn = orig if c == 0 else (lambda z: orig(z) - c * z) + + +# small eps for the stiff damped linear part; more free steps to converge +EPS, T1, T2, BETA = 0.05, 400, 40, 0.02 +print(f"UNCONSTRAINED + damping, eps={EPS} T1={T1} T2={T2}") +print(f"{'s':>5} {'c':>4} | {'naive(attn)':>11} {'AEP(attn)':>10} | {'finite?':>7}") +for s in [2.0, 4.0, 8.0]: + for c in [0.0, 1.0, 2.0]: + setc(c) + r = A.measure(model, names, s, T1, T2, EPS, BETA) + fin = not (math.isnan(r['aep']) or math.isnan(r['naive'])) + print(f"{s:>5.1f} {c:>4.1f} | {r['naive']:>11.3f} {r['aep']:>10.3f} | {str(fin):>7}") + print() diff --git a/scripts/aep_depth.py b/scripts/aep_depth.py new file mode 100644 index 0000000..c202a0c --- /dev/null +++ b/scripts/aep_depth.py @@ -0,0 +1,30 @@ +"""B: does AEP gradient fidelity degrade as the non-conservative attention gets DEEPER? +Stack K residual attention sub-layers (weight-tied) inside the force; measure naive vs +AEP attention-param cosine vs BPTT, at fixed scale s.""" +import torch, aep_characterize as A +from cet_aep import CETReal +from cet_mvp import make_patch_mask, get_loaders + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +torch.manual_seed(0) +model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev) +names = [n for n, _ in model.named_parameters()] +trl, _ = get_loaders(32, dataset='fashionmnist') +X, _ = next(iter(trl)); X = X.to(dev) +M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev) +A.X, A.M, A.XBAR = X, M, X * (1 - M) + +base = model.real_attn +def deep(K): + def f(z): + h = z + for _ in range(K): + h = h + base(h) + return h - z + return f + +print(f"{'depth K':>8} | {'naive(attn)':>11} {'AEP(attn)':>10}") +for K in [1, 2, 3, 4]: + model.real_attn = deep(K) + r = A.measure(model, names, 1.0, 120, 30, 0.2, 0.02) # s=1, T2=30 (enough per [3]) + print(f"{K:>8} | {r['naive']:>11.3f} {r['aep']:>10.3f}") diff --git a/scripts/aep_option1.py b/scripts/aep_option1.py new file mode 100644 index 0000000..65583d4 --- /dev/null +++ b/scripts/aep_option1.py @@ -0,0 +1,115 @@ +"""option 1: CORRECT gradient for non-conservative attention UNDER the token-norm constraint. + +Implicit differentiation of the projected fixed-point map G(x) = Pi(x + eps F(x)): + adjoint a <- J_G^T a + g , J_G^T = (I + eps J_F^T) Pi'^T , g = dC/dx* + gradient dL/dtheta = eps * < Pi'^T a , dF/dtheta(x*) > + +Built from LOCAL pieces only (this is the projected analogue of EP's nudged adjoint): + Pi'^T : vjp(token_norm, u, .) (the LN/projection Jacobian = LayerNormProjectedSurrogate) + J_F^T : -Hess(E_rest).b (symmetric, via HVP) + s * vjp(real_attn, z*, .) (the non-conservative bit) +Validation: cosine vs BPTT-through-the-projected-relaxation (ground truth). C lost fidelity; this should recover it. +""" +import torch, torch.nn.functional as F, math +from cet_mvp import token_norm, make_patch_mask, masked_cost, get_loaders +from cet_aep import CETReal + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +ATTN = ('WQ', 'WK', 'WV', 'WO') + + +def force(model, xbar, z, y, s, cg=False): + gz, gy = torch.autograd.grad(model.E_rest(xbar, z, y), [z, y], create_graph=cg) + return -gz + s * model.real_attn(z), -gy + + +def relax_free(model, xbar, z, y, s, T1, eps): + for _ in range(T1): + with torch.enable_grad(): + zr, yr = z.requires_grad_(True), y.requires_grad_(True) + fz, fy = force(model, xbar, zr, yr, s) + fz, fy = fz.detach(), fy.detach() + with torch.no_grad(): + z, y = token_norm(z + eps * fz), y + eps * fy + return z.detach(), y.detach() + + +def adjoint_grad(model, xbar, s, T1, eps, Tadj): + zs, ys = relax_free(model, xbar, *model.init_state(xbar), s, T1, eps) + # pre-projection point u for Pi' ; cost grad g=(0, dC/dy) + zr, yr = zs.detach().requires_grad_(True), ys.detach().requires_grad_(True) + fz, fy = force(model, xbar, zr, yr, s) + uz = (zs + eps * fz).detach() + yc = ys.detach().requires_grad_(True) + gy_c, = torch.autograd.grad(masked_cost(yc, X, M) / M.sum(), yc) + gy_c = gy_c.detach() + + az, ay = torch.zeros_like(zs), gy_c.clone() # init adjoint at g (cost grad) + for _ in range(Tadj): + bz = torch.autograd.functional.vjp(token_norm, uz, az)[1] # Pi'^T a (z); y identity + by = ay + # J_F^T b = -Hess(E_rest).b + s * vjp(real_attn, zs, bz) + zr2, yr2 = zs.detach().requires_grad_(True), ys.detach().requires_grad_(True) + gz2, gy2 = torch.autograd.grad(model.E_rest(xbar, zr2, yr2), [zr2, yr2], create_graph=True) + hz, hy = torch.autograd.grad((gz2 * bz).sum() + (gy2 * by).sum(), [zr2, yr2]) + av = torch.autograd.functional.vjp(model.real_attn, zs, bz)[1] + JFt_z, JFt_y = -hz + s * av, -hy + az = (bz + eps * JFt_z + torch.zeros_like(zs)).detach() + ay = (by + eps * JFt_y + gy_c).detach() + + # gradient: eps * d/dtheta < Pi'^T a , F(x*, theta) > + bz = torch.autograd.functional.vjp(token_norm, uz, az)[1].detach() + by = ay.detach() + zr3, yr3 = zs.detach().requires_grad_(True), ys.detach().requires_grad_(True) + gz3, gy3 = torch.autograd.grad(model.E_rest(xbar, zr3, yr3), [zr3, yr3], create_graph=True) + Fz = -gz3 + s * model.real_attn(zr3) + Fy = -gy3 + contr = eps * ((bz * Fz).sum() + (by * Fy).sum()) + return torch.autograd.grad(contr, list(model.parameters()), allow_unused=True) + + +def bptt_grad(model, xbar, s, T1, eps): + z, y = model.init_state(xbar); z, y = z.requires_grad_(True), y.requires_grad_(True) + for _ in range(T1): + fz, fy = force(model, xbar, z, y, s, cg=True) + z, y = token_norm(z + eps * fz), y + eps * fy + return torch.autograd.grad(masked_cost(y, X, M) / M.sum(), + list(model.parameters()), allow_unused=True) + + +def cosines(g, gb, names): + c = lambda a, b: F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() + at = [c(a, b) for n, a, b in zip(names, g, gb) if n in ATTN and a is not None and b is not None] + A = torch.cat([x.flatten() for x in g if x is not None]) + B = torch.cat([y.flatten() for x, y in zip(g, gb) if x is not None and y is not None]) + return (sum(at) / len(at) if at else float('nan')), c(A, B) + + +def main(): + torch.manual_seed(0) + model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev) + names = [n for n, _ in model.named_parameters()] + trl, _ = get_loaders(32, dataset='fashionmnist') + global X, M, XBAR + X, _ = next(iter(trl)); X = X.to(dev) + M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev) + XBAR = X * (1 - M) + def resid(s, T1, eps=0.2): + zs, ys = relax_free(model, XBAR, *model.init_state(XBAR), s, T1, eps) + with torch.enable_grad(): + zr, yr = zs.requires_grad_(True), ys.requires_grad_(True) + fz, fy = force(model, XBAR, zr, yr, s) + zn = token_norm(zs + eps * fz.detach()) + return ((zn - zs).norm() / (zs.norm() + 1e-9)).item() + + print("PROJECTED-ADJOINT (option 1) vs BPTT — is the s>=2 break convergence or no-fixed-point?") + print(f"{'s':>5} {'T1=Tadj':>8} | {'attn cos':>9} {'glob cos':>9} | {'fwd resid':>9}") + for s in [0.5, 1.0, 2.0]: + for it in [120, 400]: + gb = bptt_grad(model, XBAR, s, it, 0.2) + ga = adjoint_grad(model, XBAR, s, it, 0.2, it) + a, g = cosines(ga, gb, names) + print(f"{s:5.1f} {it:>8} | {a:>9.3f} {g:>9.3f} | {resid(s, it):>9.2e}") + + +if __name__ == '__main__': + main() diff --git a/scripts/aep_projected.py b/scripts/aep_projected.py new file mode 100644 index 0000000..af8891e --- /dev/null +++ b/scripts/aep_projected.py @@ -0,0 +1,125 @@ +"""C / option 1: PROJECTED AEP — non-conservative EP on the token-norm constraint manifold. + +Two fixes over the unconstrained version: + (1) STABILITY: relax with the token-norm projection z <- Pi(z + eps F) (bounds z; + this is what made plain CET stable). Lets large-s / deep attention stop diverging. + (2) CORRECT GRADIENT under the constraint: the VF contraction must be projected onto the + TANGENT space of the manifold. The tangent projector at a normalized token z is + P_z(v) = v - mean(v) - mean(v*z) * z + (exactly the local-transformer's LayerNormProjectedSurrogate). Without it the VF + estimator picks up the normal force and collapses (energy-mode cosine ~0.002). + +Param-gradient: dL/dtheta = <a_z, P_z*( dF_z/dtheta )> + <a_y, dF_y/dtheta>, + a = (state_-b - state_+b)/(2 beta). +AEP correction (nudged phase, on z): -s (J v - J^T v) of RealAttn, then projected. +""" +import argparse, math, torch, torch.nn.functional as F +from cet_mvp import token_norm, make_patch_mask, masked_cost, get_loaders +from cet_aep import CETReal + +dev = 'cuda' if torch.cuda.is_available() else 'cpu' +ATTN = ('WQ', 'WK', 'WV', 'WO') + + +def P_tan(z, v): # tangent projection at normalized token z + v = v - v.mean(-1, keepdim=True) + zz = (z * z).mean(-1, keepdim=True).clamp_min(1e-6) + return v - ((v * z).mean(-1, keepdim=True) / zz) * z + + +def force(model, xbar, z, y, s): + z = z.requires_grad_(True); y = y.requires_grad_(True) + gz, gy = torch.autograd.grad(model.E_rest(xbar, z, y), [z, y], create_graph=True) + return -gz + s * model.real_attn(z), -gy + + +def relax_free(model, xbar, z, y, s, T1, eps): + for _ in range(T1): + with torch.enable_grad(): + fz, fy = force(model, xbar, z, y, s); fz, fy = fz.detach(), fy.detach() + with torch.no_grad(): + z = token_norm(z + eps * fz); y = y + eps * fy + return z.detach(), y.detach() + + +def relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, sign, aep): + z, y = zs.clone(), ys.clone() + for _ in range(T2): + with torch.enable_grad(): + fz, fy = force(model, xbar, z, y, s); fz, fy = fz.detach(), fy.detach() + yy = y.detach().requires_grad_(True) + gy, = torch.autograd.grad(masked_cost(yy, X, M), yy) + fy = fy - sign * beta * gy + if aep: + v = (z - zs).detach() + Jv = torch.autograd.functional.jvp(model.real_attn, zs, v)[1] + JTv = torch.autograd.functional.vjp(model.real_attn, zs, v)[1] + fz = fz - s * (Jv - JTv) + with torch.no_grad(): + z = token_norm(z + eps * fz); y = y + eps * fy + return z.detach(), y.detach() + + +def vf_grad(model, xbar, s, T1, T2, eps, beta, aep): + zs, ys = relax_free(model, xbar, *model.init_state(xbar), s, T1, eps) + zp, yp = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, +1, aep) + zm, ym = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, -1, aep) + az = P_tan(zs, ((zm - zp) / (2 * beta))).detach() # adjoint in tangent space + ay = ((ym - yp) / (2 * beta)).detach() + with torch.enable_grad(): + fz, fy = force(model, xbar, zs.detach(), ys.detach(), s) + s_ = (az * P_tan(zs, fz)).sum() + (ay * fy).sum() # projected contraction + g = torch.autograd.grad(s_, list(model.parameters()), allow_unused=True) + return zs, g + + +def bptt_grad(model, xbar, s, T1, eps): + z, y = model.init_state(xbar); z, y = z.requires_grad_(True), y.requires_grad_(True) + for _ in range(T1): + fz, fy = force(model, xbar, z, y, s) + z = token_norm(z + eps * fz); y = y + eps * fy + return torch.autograd.grad(masked_cost(y, X, M) / M.sum(), + list(model.parameters()), allow_unused=True) + + +def cosines(g, gb, names): + def c(a, b): return F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() + at = [c(a, b) for n, a, b in zip(names, g, gb) if n in ATTN and a is not None and b is not None] + A = torch.cat([x.flatten() for x in g if x is not None]) + B = torch.cat([y.flatten() for x, y in zip(g, gb) if x is not None and y is not None]) + return (sum(at) / len(at) if at else float('nan')), c(A, B) + + +def measure(model, names, s, T1, T2, eps, beta): + gb = bptt_grad(model, XBAR, s, T1, eps) + _, gn = vf_grad(model, XBAR, s, T1, T2, eps, beta, False) + zs, ga = vf_grad(model, XBAR, s, T1, T2, eps, beta, True) + an, gng = cosines(gn, gb, names) + aa, gag = cosines(ga, gb, names) + fin = torch.isfinite(zs).all().item() + return an, aa, gng, gag, fin + + +def main(): + torch.manual_seed(0) + model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev) + names = [n for n, _ in model.named_parameters()] + trl, _ = get_loaders(32, dataset='fashionmnist') + global X, M, XBAR + X, _ = next(iter(trl)); X = X.to(dev) + M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev) + XBAR = X * (1 - M) + + print("SANITY s=0 (pure conservative): projected-VF global cosine should be ~1") + _, _, gnaive, _, _ = measure(model, names, 0.0, 120, 20, 0.2, 0.02) + print(f" s=0 global cosine = {gnaive:.4f}\n") + + print("PROJECTED AEP across attention scale s (T1=120 T2=30 beta=0.02)") + print(f"{'s':>6} | {'naive(attn)':>11} {'AEP(attn)':>10} | {'finite?':>7} (unproj. broke at s>=4)") + for s in [0.5, 1.0, 2.0, 4.0, 8.0, 16.0]: + an, aa, gn, ga, fin = measure(model, names, s, 120, 30, 0.2, 0.02) + print(f"{s:6.2f} | {an:>11.3f} {aa:>10.3f} | {str(bool(fin)):>7}") + + +if __name__ == '__main__': + main() diff --git a/scripts/ask_fugu.py b/scripts/ask_fugu.py new file mode 100644 index 0000000..b3ffa99 --- /dev/null +++ b/scripts/ask_fugu.py @@ -0,0 +1,24 @@ +import json, urllib.request, urllib.error, os, sys +key = open(os.path.expanduser("~/.codex/sakana.key")).read().strip() +brief = open("/home/yurenh2/ept/PHYSICS_QUESTIONS_FOR_DEEP_REASONING.md").read() +prompt = brief + "\n\n---\nAnswer Q1 through Q7. For each: (a) is it a FUNDAMENTAL obstruction (cite/sketch the no-go) or an ENGINEERING gap (sketch the construction); (b) the physical-realizability verdict (local, forward-only, no backward pass?); (c) the cheapest experiment on our simulator that would falsify your proposed mechanism. Be rigorous, specific, and decisive." +payload = {"model": "fugu-ultra", "input": prompt, "reasoning": {"effort": "xhigh"}} +req = urllib.request.Request("https://api.sakana.ai/v1/responses", + data=json.dumps(payload).encode(), + headers={"Content-Type": "application/json", "Authorization": "Bearer " + key}) +try: + data = json.load(urllib.request.urlopen(req, timeout=3000)) +except urllib.error.HTTPError as e: + print("HTTPError", e.code); print(e.read().decode()[:3000]); sys.exit(1) +except Exception as e: + print("ERR", repr(e)); sys.exit(1) +texts = [] +for item in data.get("output", []): + if item.get("type") == "message": + for c in item.get("content", []): + if c.get("type") in ("output_text", "text"): texts.append(c.get("text", "")) +out = "\n".join(t for t in texts if t) or data.get("output_text", "") or ("RAW:\n" + json.dumps(data)[:4000]) +open("/home/yurenh2/ept/FUGU_PHYSICS_ANSWER.md", "w").write(out) +print("=== USAGE ===", json.dumps(data.get("usage", {}))) +print("=== FUGU-ULTRA ANSWER (saved to FUGU_PHYSICS_ANSWER.md) ===") +print(out) diff --git a/scripts/bp_transformer.py b/scripts/bp_transformer.py new file mode 100644 index 0000000..7c9b543 --- /dev/null +++ b/scripts/bp_transformer.py @@ -0,0 +1,141 @@ +""" +Vanilla backprop Transformer baseline for the SAME masked-image-completion task, +so we can compare against CET-EP and CET-TBPTE on the identical metric +(masked-patch pixel MSE on CIFAR-10, images in [-1,1]). + +Standard recipe: conv patch-embed + learned pos-embed, MAE-style learned mask +token on occluded patches, N standard pre-LN transformer blocks (MHA + FFN), +linear pixel head, MSE loss on masked patches only. Trained with normal Adam/BP. +""" +import argparse, os, time, json, math +import torch, torch.nn as nn, torch.nn.functional as F +from cet_mvp import get_loaders, make_patch_mask # reuse data + masking + + +class Block(nn.Module): + def __init__(self, D, heads, mlp_ratio): + super().__init__() + self.ln1 = nn.LayerNorm(D) + self.attn = nn.MultiheadAttention(D, heads, batch_first=True) + self.ln2 = nn.LayerNorm(D) + self.mlp = nn.Sequential(nn.Linear(D, int(mlp_ratio * D)), nn.GELU(), + nn.Linear(int(mlp_ratio * D), D)) + + def forward(self, x): + h = self.ln1(x) + x = x + self.attn(h, h, h, need_weights=False)[0] + x = x + self.mlp(self.ln2(x)) + return x + + +class BPTransformer(nn.Module): + def __init__(self, img=32, ch=3, patch=8, stride=8, D=128, heads=4, + depth=1, mlp_ratio=2.0): + super().__init__() + self.ch, self.patch, self.stride = ch, patch, stride + gh = (img - patch) // stride + 1 + self.gh, self.N, self.pdim = gh, gh * gh, ch * patch * patch + self.embed = nn.Conv2d(ch, D, patch, stride=stride) + self.pos = nn.Parameter(torch.zeros(1, self.N, D)); nn.init.normal_(self.pos, std=0.02) + self.mask_token = nn.Parameter(torch.zeros(1, 1, D)); nn.init.normal_(self.mask_token, std=0.02) + self.blocks = nn.ModuleList([Block(D, heads, mlp_ratio) for _ in range(depth)]) + self.ln = nn.LayerNorm(D) + self.head = nn.Linear(D, self.pdim) + + def patchify(self, x): # (B,C,H,W)->(B,N,pdim) + p, s = self.patch, self.stride + u = x.unfold(2, p, s).unfold(3, p, s) # B,C,gh,gw,p,p + return u.permute(0, 2, 3, 1, 4, 5).reshape(x.size(0), self.N, self.pdim) + + def forward(self, xbar, pm): # pm: (B,N) 1=masked + t = self.embed(xbar).flatten(2).transpose(1, 2) # (B,N,D) + t = torch.where(pm.unsqueeze(-1).bool(), self.mask_token, t) + self.pos + for b in self.blocks: + t = b(t) + return self.head(self.ln(t)) # (B,N,pdim) + + +def patch_mask_bool(B, gh, ratio, device, gen=None): + npatch = gh * gh; nmask = int(round(ratio * npatch)) + idx = torch.rand(B, npatch, device=device, generator=gen).argsort(1) + pm = torch.zeros(B, npatch, device=device) + pm.scatter_(1, idx[:, :nmask], 1.0) + return pm # (B,N) + + +def masked_patch_mse(pred, true, pm): + m = pm.unsqueeze(-1) + return ((pred - true) ** 2 * m).sum() / (m.sum() * pred.size(-1)).clamp_min(1.0) + + +@torch.no_grad() +def evaluate(model, loader, cfg, device, max_batches=100): + model.eval(); tot, n = 0.0, 0 + gen = torch.Generator(device=device).manual_seed(0) + for i, (x, _) in enumerate(loader): + if i >= max_batches: + break + x = x.to(device) + pm = patch_mask_bool(x.size(0), model.gh, cfg.mask_ratio, device, gen) + M = pm.view(-1, model.gh, model.gh).repeat_interleave(cfg.patch, 1).repeat_interleave(cfg.patch, 2).unsqueeze(1) + xbar = x * (1 - M) + pred = model(xbar, pm) + tot += masked_patch_mse(pred, model.patchify(x), pm).item() * x.size(0); n += x.size(0) + model.train(); return tot / n + + +def train(cfg): + device = cfg.device; torch.manual_seed(cfg.seed) + model = BPTransformer(cfg.img, cfg.ch, cfg.patch, cfg.stride, cfg.D, cfg.heads, + cfg.depth, cfg.mlp_ratio).to(device) + opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr_min) + trl, tel = get_loaders(cfg.batch, dataset=cfg.dataset) + print(f"[bp] params={sum(p.numel() for p in model.parameters())/1e3:.1f}K " + f"depth={cfg.depth} D={cfg.D} mlp={cfg.mlp_ratio}", flush=True) + step, t0, run = 0, time.time(), 0.0 + while step < cfg.steps: + for x, _ in trl: + if step >= cfg.steps: + break + x = x.to(device, non_blocking=True) + pm = patch_mask_bool(x.size(0), model.gh, cfg.mask_ratio, device) + M = pm.view(-1, model.gh, model.gh).repeat_interleave(cfg.patch, 1).repeat_interleave(cfg.patch, 2).unsqueeze(1) + xbar = x * (1 - M) + pred = model(xbar, pm) + loss = masked_patch_mse(pred, model.patchify(x), pm) + opt.zero_grad(set_to_none=True); loss.backward(); opt.step(); sched.step() + run += loss.item(); step += 1 + if step % cfg.log_every == 0: + print(f"step {step:5d}/{cfg.steps} | train masked-MSE {run/cfg.log_every:.5f} " + f"| {step/(time.time()-t0):.1f} it/s", flush=True); run = 0.0 + if step % cfg.eval_every == 0 or step == cfg.steps: + print(f" >> [eval] step {step} test masked-MSE {evaluate(model, tel, cfg, device, 20):.5f}", flush=True) + final = evaluate(model, tel, cfg, device, 100) + os.makedirs(cfg.out, exist_ok=True) + json.dump({'mode': 'bp_transformer', 'final_test_masked_mse': final, 'steps': cfg.steps, + 'params_K': sum(p.numel() for p in model.parameters()) / 1e3}, + open(os.path.join(cfg.out, 'result_bp_transformer.json'), 'w'), indent=2) + print(f"[bp] DONE final test masked-MSE = {final:.5f}", flush=True) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--dataset', choices=['cifar10', 'fashionmnist'], default='cifar10') + p.add_argument('--steps', type=int, default=3000); p.add_argument('--batch', type=int, default=128) + p.add_argument('--img', type=int, default=32); p.add_argument('--ch', type=int, default=3) + p.add_argument('--patch', type=int, default=8); p.add_argument('--stride', type=int, default=8) + p.add_argument('--D', type=int, default=128); p.add_argument('--heads', type=int, default=4) + p.add_argument('--depth', type=int, default=1); p.add_argument('--mlp_ratio', type=float, default=2.0) + p.add_argument('--mask_ratio', type=float, default=0.5) + p.add_argument('--lr', type=float, default=4e-4); p.add_argument('--lr_min', type=float, default=1e-6) + p.add_argument('--wd', type=float, default=3e-5) + p.add_argument('--log_every', type=int, default=100); p.add_argument('--eval_every', type=int, default=500) + p.add_argument('--seed', type=int, default=0); p.add_argument('--out', type=str, default='/home/yurenh2/ept/runs') + p.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') + cfg = p.parse_args() + print('config:', vars(cfg), flush=True); train(cfg) + + +if __name__ == '__main__': + main() diff --git a/scripts/cet_aep.py b/scripts/cet_aep.py new file mode 100644 index 0000000..44e8922 --- /dev/null +++ b/scripts/cet_aep.py @@ -0,0 +1,272 @@ +""" +AEP applied to CET's attention: replace CET's conservative energy-attention E^att +with a REAL (non-conservative) transformer attention inside the CET, and use the +AEP correction so EP still recovers the true gradient. + +CET state = (tokens z, reconstruction y). The conservative part + E_rest = E_enc + E_pos + E_mem + E_dec (scalar -> symmetric Jacobian) +keeps its energy-gradient force. The token force gets its attention term from: + energy mode : -dE^att/dz (conservative; tied value; this is plain CET) + real mode : RealAttn(z) = WO softmax(QK^T/sqrt dh)(z WV) (non-conservative) + +Because only RealAttn is non-conservative, the full-force antisymmetric Jacobian +A_J reduces to the antisymmetric part of dRealAttn/dz alone -> AEP correction is + force_z += -(J_A z~ - J_A^T z~) , z~=z-z* , J_A = dRealAttn/dz at z* +(clean jvp/vjp on RealAttn, no nested autograd). + +We compare parameter-gradient quality vs ground-truth BPTT for: + energy / naive-EP (conservative CET; sanity, should be ~BPTT) + real / naive-EP (non-conservative; expected biased) + real / AEP (non-conservative + correction; expected ~BPTT) +""" +import argparse, math, time, json, os, torch, torch.nn as nn, torch.nn.functional as F +from cet_mvp import token_norm, make_patch_mask, masked_cost, masked_mse, get_loaders + + +class CETReal(nn.Module): + def __init__(self, img=32, ch=3, patch=8, stride=8, D=64, heads=4, dh=16, + mem=128, gamma=0.25): + super().__init__() + self.ch, self.patch, self.stride, self.D = ch, patch, stride, D + self.heads, self.dh, self.gamma = heads, dh, gamma + gh = (img - patch) // stride + 1 + self.gh, self.N = gh, gh * gh + self.damp = 0.0 # contraction damping c: real_attn returns attn(z) - c*z + self.Wenc = nn.Parameter(torch.empty(D, ch, patch, patch)) + self.benc = nn.Parameter(torch.zeros(D)) + self.bpos = nn.Parameter(torch.zeros(self.N, D)) + self.Wdec = nn.Parameter(torch.empty(D, ch, patch, patch)) + self.bdec = nn.Parameter(torch.zeros(ch)) + self.Wmem = nn.Parameter(torch.empty(D, mem)) + # attention: WQ/WK used by both; WV/WO only by the real (non-conservative) path + self.WQ = nn.Parameter(torch.empty(heads, dh, D)) + self.WK = nn.Parameter(torch.empty(heads, dh, D)) + self.WV = nn.Parameter(torch.empty(heads, dh, D)) + self.WO = nn.Parameter(torch.empty(D, heads * dh)) + nn.init.kaiming_normal_(self.Wenc); self.Wenc.data *= 0.5 + nn.init.kaiming_normal_(self.Wdec); self.Wdec.data *= 0.5 + for w in (self.WQ, self.WK, self.WV): + nn.init.normal_(w, std=1.0 / math.sqrt(D)) + nn.init.normal_(self.Wmem, std=0.3 / math.sqrt(D)) # small: keep energy bounded-below + nn.init.normal_(self.WO, std=1.0 / math.sqrt(heads * dh)) + + def encode(self, xbar): + return F.conv2d(xbar, self.Wenc, stride=self.stride).flatten(2).transpose(1, 2) + + def decode_conv(self, y): + return F.conv2d(y, self.Wdec, stride=self.stride).flatten(2).transpose(1, 2) + + def E_rest(self, xbar, z, y): # conservative scalar (no attention) + enc = self.encode(xbar) + E = 2.0 * (z ** 2).sum() - (enc * z).sum() - (z * self.benc).sum() - (z * self.bpos).sum() + proj = torch.einsum('bnd,dm->bnm', z, self.Wmem) + E = E - (F.relu(proj) ** 2).sum() + dc = self.decode_conv(y) + E = E + 0.5 * (y ** 2).sum() - (dc * z).sum() - (y * self.bdec[None, :, None, None]).sum() + return E + + def E_att(self, z): # conservative LogSumExp energy (tied value) + Q = torch.einsum('bnd,hjd->bhnj', z, self.WQ) + K = torch.einsum('bnd,hjd->bhnj', z, self.WK) + A = torch.einsum('bhmj,bhnj->bhmn', Q, K) + return -(1.0 / self.gamma) * torch.logsumexp(self.gamma * A, dim=-1).sum() + + def real_attn(self, z): # NON-conservative real attention force + B = z.size(0) + q = torch.einsum('bnd,hjd->bhnj', z, self.WQ) + k = torch.einsum('bnd,hjd->bhnj', z, self.WK) + v = torch.einsum('bnd,hjd->bhnj', z, self.WV) + A = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(self.dh), dim=-1) + o = (A @ v).transpose(1, 2).reshape(B, self.N, self.heads * self.dh) + return o @ self.WO.t() - self.damp * z # -c*z: symmetric -> contraction, A_J unchanged + + def force(self, xbar, z, y, mode): + """Return (force_z, force_y). force = -dE/dstate (+ real attention if mode='real').""" + z = z.requires_grad_(True); y = y.requires_grad_(True) + if mode == 'energy': + E = self.E_rest(xbar, z, y) + self.E_att(z) + gz, gy = torch.autograd.grad(E, [z, y], create_graph=True) + return -gz, -gy + else: + E = self.E_rest(xbar, z, y) + gz, gy = torch.autograd.grad(E, [z, y], create_graph=True) + return -gz + self.real_attn(z), -gy + + def init_state(self, xbar): + return token_norm(self.encode(xbar)).detach(), xbar.clone().detach() + + +def relax(model, xbar, z, y, steps, eps, mode, x=None, M=None, beta=0.0, aep=False, zstar=None): + for _ in range(steps): + with torch.enable_grad(): + fz, fy = model.force(xbar, z, y, mode) + fz, fy = fz.detach(), fy.detach() + if beta != 0.0: # nudge on the output y + yy = y.detach().requires_grad_(True) + gy, = torch.autograd.grad(masked_cost(yy, x, M), yy) + fy = fy - beta * gy + if aep: # AEP correction on z (attention block only) + v = (z - zstar).detach() + fa = lambda zz: model.real_attn(zz) + Jv = torch.autograd.functional.jvp(fa, zstar, v)[1] + JTv = torch.autograd.functional.vjp(fa, zstar, v)[1] + corr = Jv - JTv # = 2 * 0.5 (J v - J^T v) + cn, fn = corr.norm(), fz.norm() + 1e-8 # clip so correction can't dominate -> no blow-up + if cn > fn: + corr = corr * (fn / cn) + fz = fz - corr + with torch.no_grad(): + z = z + eps * fz # unconstrained (0.5||z||^2 in E_rest keeps it bounded) + y = y + eps * fy + return z.detach(), y.detach() + + +def vf_param_grad(model, xbar, x, M, mode, T1, T2, eps, beta, aep): + z0, y0 = model.init_state(xbar) + zs, ys = relax(model, xbar, z0, y0, T1, eps, mode) + zp, yp = relax(model, xbar, zs.clone(), ys.clone(), T2, eps, mode, x, M, +beta, aep, zs) + zm, ym = relax(model, xbar, zs.clone(), ys.clone(), T2, eps, mode, x, M, -beta, aep, zs) + az, ay = ((zm - zp) / (2 * beta)).detach(), ((ym - yp) / (2 * beta)).detach() + with torch.enable_grad(): + fz, fy = model.force(xbar, zs.detach(), ys.detach(), mode) + s = (az * fz).sum() + (ay * fy).sum() + grads = torch.autograd.grad(s, list(model.parameters()), allow_unused=True, retain_graph=False) + return grads + + +def bptt_param_grad(model, xbar, x, M, mode, T1, eps): + z, y = model.init_state(xbar) + z, y = z.requires_grad_(True), y.requires_grad_(True) + for _ in range(T1): + fz, fy = model.force(xbar, z, y, mode) + z = z + eps * fz + y = y + eps * fy + L = masked_cost(y, x, M) / M.sum() + return torch.autograd.grad(L, list(model.parameters()), allow_unused=True) + + +def cos(ga, gb, names): + fa, fb = [], [] + per = {} + for n, a, b in zip(names, ga, gb): + if a is None or b is None: + continue + fa.append(a.flatten()); fb.append(b.flatten()) + per[n] = F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() + g = F.cosine_similarity(torch.cat(fa), torch.cat(fb), dim=0).item() + return g, per + + +def evaluate(model, loader, cfg, dev, mode='real', max_batches=40): + tot, n = 0.0, 0 + gen = torch.Generator(device=dev).manual_seed(0) + for i, (x, _) in enumerate(loader): + if i >= max_batches: + break + x = x.to(dev) + M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev, gen) + xbar = x * (1 - M) + z, y = relax(model, xbar, *model.init_state(xbar), cfg.T1, cfg.eps, mode) + tot += masked_mse(y, x, M) * x.size(0); n += x.size(0) + return tot / n + + +def fidelity(cfg, model, dev): + names = [n for n, _ in model.named_parameters()] + trl, _ = get_loaders(cfg.batch, dataset=cfg.dataset) + x, _ = next(iter(trl)); x = x.to(dev) + M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev) + xbar = x * (1 - M) + zs, ys = relax(model, xbar, *model.init_state(xbar), cfg.T1, cfg.eps, 'real') + v = torch.randn_like(zs) + Jv = torch.autograd.functional.jvp(lambda z: model.real_attn(z), zs, v)[1] + JTv = torch.autograd.functional.vjp(lambda z: model.real_attn(z), zs, v)[1] + asym = (0.5 * (Jv - JTv)).norm().item() / (Jv.norm().item() + 1e-8) + print(f"real-attention Jacobian antisymmetry = {asym:.3f}\n") + for mode, aep, label in [('energy', False, 'energy/naive (sanity)'), + ('real', False, 'real/naive (biased)'), + ('real', True, 'real/AEP (fixed)')]: + gb = bptt_param_grad(model, xbar, x, M, mode, cfg.T1, cfg.eps) + gv = vf_param_grad(model, xbar, x, M, mode, cfg.T1, cfg.T2, cfg.eps, cfg.beta, aep) + g, per = cos(gv, gb, names) + att = " ".join(f"{k}={per[k]:+.3f}" for k in ('WQ', 'WK', 'WV', 'WO') if k in per) + print(f"[{label}] global={g:+.4f} attn: {att}") + + +def train(cfg, model, dev): + tag = 'aep' if cfg.aep else 'naive' + opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr * 0.01) + trl, tel = get_loaders(cfg.batch, dataset=cfg.dataset) + print(f"[real-attn EP, {tag}] params={sum(p.numel() for p in model.parameters())/1e3:.1f}K " + f"T1={cfg.T1} T2={cfg.T2} eps={cfg.eps} beta={cfg.beta}", flush=True) + # stay in the stable+faithful regime: cap weight norms (Wmem for bounded-below energy, + # attention WV/WO/WQ/WK so the non-conservative force can't grow into the unstable s>=4 regime) + caps = {n: p.detach().norm().item() * 1.5 for n, p in model.named_parameters() + if n in ('Wmem', 'WQ', 'WK', 'WV', 'WO')} + cap_params = {n: p for n, p in model.named_parameters() if n in caps} + step, t0, best = 0, time.time(), float('inf') + while step < cfg.steps: + for x, _ in trl: + if step >= cfg.steps: + break + x = x.to(dev, non_blocking=True) + M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev) + xbar = x * (1 - M) + grads = vf_param_grad(model, xbar, x, M, 'real', cfg.T1, cfg.T2, cfg.eps, cfg.beta, cfg.aep) + opt.zero_grad(set_to_none=True) + bad = False + for p, g in zip(model.parameters(), grads): + if g is None or not torch.isfinite(g).all(): + bad = True; break + p.grad = g + if bad: + print(f" step {step}: non-finite grad, skip", flush=True); step += 1; continue + torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0) + opt.step(); sched.step() + with torch.no_grad(): # stay in stable+faithful regime + for n, p in cap_params.items(): + pn = p.norm() + if pn > caps[n]: + p.mul_(caps[n] / pn) + step += 1 + if step % cfg.log_every == 0: + te = evaluate(model, tel, cfg, dev, 'real', 15) + best = min(best, te) + print(f"step {step:4d}/{cfg.steps} | test masked-MSE {te:.5f} (best {best:.5f}) " + f"| {step/(time.time()-t0):.2f} it/s", flush=True) + final = evaluate(model, tel, cfg, dev, 'real', 60) + best = min(best, final) + os.makedirs(cfg.out, exist_ok=True) + json.dump({'tag': tag, 'final_test_masked_mse': final, 'best_test_masked_mse': best, + 'steps': cfg.steps}, open(os.path.join(cfg.out, f'aep_train_{tag}.json'), 'w'), indent=2) + print(f"[real-attn EP, {tag}] DONE final={final:.5f} best={best:.5f}", flush=True) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--cmd', choices=['fidelity', 'train'], default='fidelity') + ap.add_argument('--aep', action='store_true') + ap.add_argument('--damp', type=float, default=0.0) + ap.add_argument('--dataset', default='fashionmnist') + ap.add_argument('--img', type=int, default=28); ap.add_argument('--ch', type=int, default=1) + ap.add_argument('--patch', type=int, default=7); ap.add_argument('--stride', type=int, default=7) + ap.add_argument('--D', type=int, default=64); ap.add_argument('--heads', type=int, default=4) + ap.add_argument('--dh', type=int, default=16); ap.add_argument('--mem', type=int, default=128) + ap.add_argument('--T1', type=int, default=100); ap.add_argument('--T2', type=int, default=15) + ap.add_argument('--eps', type=float, default=0.2); ap.add_argument('--beta', type=float, default=0.02) + ap.add_argument('--batch', type=int, default=64); ap.add_argument('--steps', type=int, default=1500) + ap.add_argument('--lr', type=float, default=4e-4); ap.add_argument('--wd', type=float, default=1e-4) + ap.add_argument('--log_every', type=int, default=100) + ap.add_argument('--out', default='/home/yurenh2/ept/runs') + cfg = ap.parse_args() + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + torch.manual_seed(0) + model = CETReal(cfg.img, cfg.ch, cfg.patch, cfg.stride, cfg.D, cfg.heads, cfg.dh, cfg.mem).to(dev) + model.damp = cfg.damp + print('config:', vars(cfg), flush=True) + (train if cfg.cmd == 'train' else fidelity)(cfg, model, dev) + + +if __name__ == '__main__': + main() diff --git a/scripts/cet_mvp.py b/scripts/cet_mvp.py new file mode 100644 index 0000000..deb07d9 --- /dev/null +++ b/scripts/cet_mvp.py @@ -0,0 +1,372 @@ +""" +Convergent Energy Transformer (CET) trained with Equilibrium Propagation. +MVP reproduction of Hoier, Kerjan & Scellier, "Training a Convergent Energy +Transformer with Equilibrium Propagation" (ICLR 2026 AM workshop). + +Energy terms follow the paper's Appendix B exactly: + E = E_enc (eq10, conv-Hopfield) boundary: masked image -> tokens + + E_pos (eq12, per-token bias) + + E_att (eq14, ET LogSumExp attention) <- attention, trained WITHOUT BP under EP + + E_mem (eq15, modern Hopfield memory) <- plays the role of the FFN/MLP + + E_dec (eq16, conv decoder) tokens <-> reconstruction y +Tokens z are normalised (mean 0, std 1) via projection after each PGD step. + +Training modes: + ep : free phase (T1) + two nudged phases (+/-beta, T2). Parameter gradient is the + centered-EP estimator (1/2beta)(dE/dtheta|_{+b} - dE/dtheta|_{-b}). + NO backpropagation through the relaxation dynamics; attention & memory + weights are updated purely from the two equilibria. + tbpte : same model, gradient via truncated backprop through the last T2 relaxation + steps (the paper's BP baseline; BPTE = "backprop through equilibration"). +""" +import argparse, os, time, json, math +import torch, torch.nn as nn, torch.nn.functional as F +import torchvision as tv +from torchvision import transforms + + +# --------------------------------------------------------------------------- # +# Model +# --------------------------------------------------------------------------- # +def token_norm(z, eps=1e-5): + """Project tokens onto the constraint set C: per-token mean 0, std 1 (over D_T).""" + return (z - z.mean(-1, keepdim=True)) / (z.std(-1, unbiased=False, keepdim=True) + eps) + + +class CET(nn.Module): + def __init__(self, img=32, ch=3, patch=8, stride=8, D=128, heads=4, dh=32, + mem=256, gamma=0.25): + super().__init__() + self.ch, self.patch, self.stride, self.D = ch, patch, stride, D + self.heads, self.dh, self.gamma = heads, dh, gamma + gh = (img - patch) // stride + 1 + self.gh = gh + self.N = gh * gh # number of tokens / patches + + # Encoder (eq 10): conv kernel mapping patches -> token dim + self.Wenc = nn.Parameter(torch.empty(D, ch, patch, patch)) + self.benc = nn.Parameter(torch.zeros(D)) + # Positional bias (eq 12): per (token, dim), NOT shared across tokens + self.bpos = nn.Parameter(torch.zeros(self.N, D)) + # Decoder (eq 16): conv kernel mapping reconstruction -> token dim + self.Wdec = nn.Parameter(torch.empty(D, ch, patch, patch)) + self.bdec = nn.Parameter(torch.zeros(ch)) + # Attention (eq 13/14): key/query projections, no value tensor (as in ET) + self.WQ = nn.Parameter(torch.empty(heads, dh, D)) + self.WK = nn.Parameter(torch.empty(heads, dh, D)) + # Memory (eq 15): modern Hopfield memory bank (role of the MLP) + self.Wmem = nn.Parameter(torch.empty(D, mem)) + + nn.init.kaiming_normal_(self.Wenc); self.Wenc.data *= 0.5 + nn.init.kaiming_normal_(self.Wdec); self.Wdec.data *= 0.5 + nn.init.normal_(self.WQ, std=1.0 / math.sqrt(D)) + nn.init.normal_(self.WK, std=1.0 / math.sqrt(D)) + nn.init.normal_(self.Wmem, std=1.0 / math.sqrt(D)) + + # -- patch <-> token conv helpers -------------------------------------- # + def encode(self, xbar): # (B,C,H,W) -> (B,N,D) + e = F.conv2d(xbar, self.Wenc, stride=self.stride) + return e.flatten(2).transpose(1, 2) + + def decode_conv(self, y): # (B,C,H,W) -> (B,N,D) + d = F.conv2d(y, self.Wdec, stride=self.stride) + return d.flatten(2).transpose(1, 2) + + # -- energy (per-sample, shape (B,)) ----------------------------------- # + def energy(self, xbar, z, y): + enc = self.encode(xbar) # (B,N,D) + E = 0.5 * (z ** 2).sum((1, 2)) - (enc * z).sum((1, 2)) - (z * self.benc).sum((1, 2)) + E = E - (z * self.bpos).sum((1, 2)) # E_pos (eq12) + # E_att (eq14): per head, score(query m, key n) = <Q_m, K_n>; energy = -1/g sum_m lse_n + Q = torch.einsum('bnd,hjd->bhnj', z, self.WQ) # (B,H,N,dh) + K = torch.einsum('bnd,hjd->bhnj', z, self.WK) + A = torch.einsum('bhmj,bhnj->bhmn', Q, K) # (B,H,N,N) + lse = torch.logsumexp(self.gamma * A, dim=-1) # (B,H,N) + E = E - (1.0 / self.gamma) * lse.sum((1, 2)) + # E_mem (eq15): -sum_token sum_mem relu(Wmem^T z)^2 + proj = torch.einsum('bnd,dm->bnm', z, self.Wmem) # (B,N,M) + E = E - (F.relu(proj) ** 2).sum((1, 2)) + # E_dec (eq16): 1/2 y^2 - <conv(y),z> - <y,bdec> + dc = self.decode_conv(y) + E = (E + 0.5 * (y ** 2).sum((1, 2, 3)) - (dc * z).sum((1, 2)) + - (y * self.bdec[None, :, None, None]).sum((1, 2, 3))) + return E + + def init_state(self, xbar): + z = token_norm(self.encode(xbar)).detach() + y = xbar.clone().detach() + return z, y + + # -- one PGD step on F = E + beta*Cost --------------------------------- # + def _grad_step(self, xbar, z, y, eps, x=None, mask=None, beta=0.0, create_graph=False): + z = z.requires_grad_(True) + y = y.requires_grad_(True) + Etot = self.energy(xbar, z, y).sum() + if beta != 0.0: + Etot = Etot + beta * masked_cost(y, x, mask) + gz, gy = torch.autograd.grad(Etot, [z, y], create_graph=create_graph) + z = token_norm(z - eps * gz) + y = y - eps * gy + return z, y + + @torch.no_grad() + def relax(self, xbar, z, y, steps, eps, x=None, mask=None, beta=0.0): + for _ in range(steps): + with torch.enable_grad(): + z, y = self._grad_step(xbar, z, y, eps, x, mask, beta) + z, y = z.detach(), y.detach() + return z, y + + +# --------------------------------------------------------------------------- # +# Cost / masking +# --------------------------------------------------------------------------- # +def masked_cost(y, x, mask): + """0.5 * sum over masked pixels of (y-x)^2, summed over batch (energy units).""" + return 0.5 * (((y - x) ** 2) * mask).sum() + + +def masked_mse(y, x, mask): + """Mean squared error over masked pixels only (reporting metric).""" + num = (((y - x) ** 2) * mask).sum() + den = mask.sum().clamp_min(1.0) + return (num / den).item() + + +def make_patch_mask(B, gh, patch, stride, H, W, ratio, device, gen=None): + """Random per-sample patch mask (1 = masked/occluded). Assumes stride==patch.""" + npatch = gh * gh + nmask = int(round(ratio * npatch)) + noise = torch.rand(B, npatch, device=device, generator=gen) + idx = noise.argsort(dim=1) + pm = torch.zeros(B, npatch, device=device) + pm.scatter_(1, idx[:, :nmask], 1.0) + pm = pm.view(B, gh, gh) + M = pm.repeat_interleave(patch, 1).repeat_interleave(patch, 2) # (B,H,W) + return M.unsqueeze(1) # (B,1,H,W) + + +# --------------------------------------------------------------------------- # +# Gradient estimators +# --------------------------------------------------------------------------- # +def ep_param_grads(model, xbar, x, mask, T1, T2, eps, beta): + """Centered EP. Returns (grads list, free-phase masked MSE for monitoring).""" + z0, y0 = model.init_state(xbar) + z0, y0 = model.relax(xbar, z0, y0, T1, eps) # free phase, beta=0 + free_mse = masked_mse(y0, x, mask) + zp, yp = model.relax(xbar, z0.clone(), y0.clone(), T2, eps, x, mask, beta=+beta) + zm, ym = model.relax(xbar, z0.clone(), y0.clone(), T2, eps, x, mask, beta=-beta) + params = [p for p in model.parameters()] + Ep = model.energy(xbar, zp, yp).sum() + gp = torch.autograd.grad(Ep, params) + Em = model.energy(xbar, zm, ym).sum() + gm = torch.autograd.grad(Em, params) + grads = [(a - b) / (2.0 * beta) for a, b in zip(gp, gm)] + return grads, free_mse + + +def tbpte_loss(model, xbar, x, mask, T1, T2, eps): + """Free relaxation (detached) then backprop through last T2 steps. Returns loss.""" + z, y = model.init_state(xbar) + z, y = model.relax(xbar, z, y, T1, eps) # detached + z = z.detach(); y = y.detach() + for _ in range(T2): # last T2 steps WITH graph + z, y = model._grad_step(xbar, z, y, eps, create_graph=True) + return masked_cost(y, x, mask) / mask.sum().clamp_min(1.0), y + + +def bptt_param_grads(model, xbar, x, mask, T1, eps): + """Full backprop through ALL T1 relaxation steps (smoke-test reference only).""" + z, y = model.init_state(xbar) + for _ in range(T1): + z, y = model._grad_step(xbar, z, y, eps, create_graph=True) + loss = masked_cost(y, x, mask) / mask.sum().clamp_min(1.0) + return torch.autograd.grad(loss, [p for p in model.parameters()]) + + +# --------------------------------------------------------------------------- # +# Data +# --------------------------------------------------------------------------- # +def get_loaders(batch, root='/tmp/cet_mvp/data', workers=4, dataset='cifar10'): + if dataset == 'cifar10': + tf = transforms.Compose([transforms.ToTensor(), + transforms.Normalize([0.5] * 3, [0.5] * 3)]) # -> [-1,1] + tr = tv.datasets.CIFAR10(root, train=True, download=True, transform=tf) + te = tv.datasets.CIFAR10(root, train=False, download=True, transform=tf) + elif dataset == 'fashionmnist': + tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) + tr = tv.datasets.FashionMNIST(root, train=True, download=True, transform=tf) + te = tv.datasets.FashionMNIST(root, train=False, download=True, transform=tf) + else: + raise ValueError(dataset) + trl = torch.utils.data.DataLoader(tr, batch, shuffle=True, num_workers=workers, + drop_last=True, pin_memory=True) + tel = torch.utils.data.DataLoader(te, batch, shuffle=False, num_workers=workers, + pin_memory=True) + return trl, tel + + +@torch.no_grad() +def evaluate(model, loader, cfg, device, max_batches=20): + model.eval() + tot, n = 0.0, 0 + gen = torch.Generator(device=device).manual_seed(0) + for i, (x, _) in enumerate(loader): + if i >= max_batches: + break + x = x.to(device) + M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride, + x.size(2), x.size(3), cfg.mask_ratio, device, gen) + xbar = x * (1 - M) + z, y = model.init_state(xbar) + z, y = model.relax(xbar, z, y, cfg.T1, cfg.eps) + tot += masked_mse(y, x, M) * x.size(0); n += x.size(0) + model.train() + return tot / n + + +# --------------------------------------------------------------------------- # +# Train +# --------------------------------------------------------------------------- # +def train(cfg): + device = cfg.device + torch.manual_seed(cfg.seed) + model = CET(cfg.img, cfg.ch, cfg.patch, cfg.stride, cfg.D, cfg.heads, cfg.dh, + cfg.mem, cfg.gamma).to(device) + opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd) + sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr_min) + trl, tel = get_loaders(cfg.batch, dataset=cfg.dataset) + print(f"[{cfg.mode}] model params={sum(p.numel() for p in model.parameters())/1e3:.1f}K " + f"N_tokens={model.N} D={cfg.D} | T1={cfg.T1} T2={cfg.T2} eps={cfg.eps} beta={cfg.beta}", + flush=True) + + step, t0, run_loss = 0, time.time(), 0.0 + while step < cfg.steps: + for x, _ in trl: + if step >= cfg.steps: + break + x = x.to(device, non_blocking=True) + M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride, + x.size(2), x.size(3), cfg.mask_ratio, device) + xbar = x * (1 - M) + + opt.zero_grad(set_to_none=True) + if cfg.mode == 'ep': + grads, tr_mse = ep_param_grads(model, xbar, x, M, cfg.T1, cfg.T2, + cfg.eps, cfg.beta) + for p, g in zip(model.parameters(), grads): + p.grad = g + else: # tbpte + loss, _ = tbpte_loss(model, xbar, x, M, cfg.T1, cfg.T2, cfg.eps) + loss.backward() + tr_mse = loss.item() + torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip) + opt.step(); sched.step() + run_loss += tr_mse; step += 1 + + if step % cfg.log_every == 0: + avg = run_loss / cfg.log_every; run_loss = 0.0 + sps = step / (time.time() - t0) + print(f"step {step:5d}/{cfg.steps} | train masked-MSE {avg:.5f} " + f"| lr {sched.get_last_lr()[0]:.2e} | {sps:.1f} it/s", flush=True) + if step % cfg.eval_every == 0 or step == cfg.steps: + te_mse = evaluate(model, tel, cfg, device) + print(f" >> [eval] step {step} test masked-MSE {te_mse:.5f}", flush=True) + + final = evaluate(model, tel, cfg, device, max_batches=100) + os.makedirs(cfg.out, exist_ok=True) + res = {'mode': cfg.mode, 'final_test_masked_mse': final, 'steps': cfg.steps, + 'config': {k: getattr(cfg, k) for k in + ['T1', 'T2', 'eps', 'beta', 'D', 'heads', 'dh', 'mem', + 'patch', 'stride', 'mask_ratio', 'batch', 'lr']}} + with open(os.path.join(cfg.out, f'result_{cfg.mode}.json'), 'w') as f: + json.dump(res, f, indent=2) + torch.save(model.state_dict(), os.path.join(cfg.out, f'cet_{cfg.mode}.pt')) + print(f"[{cfg.mode}] DONE final test masked-MSE = {final:.5f}", flush=True) + return final + + +# --------------------------------------------------------------------------- # +# Smoke test +# --------------------------------------------------------------------------- # +def _residual(model, xbar, z, y, eps): + """Norm of the PGD update (proxy for ||grad E|| at the constrained equilibrium).""" + with torch.enable_grad(): + zn, yn = model._grad_step(xbar, z.clone(), y.clone(), eps) + return ((zn - z).norm() / (z.norm() + 1e-8)).item(), ((yn - y).norm() / (y.norm() + 1e-8)).item() + + +def smoke(cfg): + device = cfg.device + torch.manual_seed(0) + model = CET(cfg.img, cfg.ch, cfg.patch, cfg.stride, D=32, heads=2, dh=16, + mem=32, gamma=cfg.gamma).to(device) + x = torch.randn(16, cfg.ch, cfg.img, cfg.img, device=device).clamp(-1, 1) + M = make_patch_mask(16, model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, device) + xbar = x * (1 - M) + T1 = cfg.T1 + print(f"[smoke] T1={T1} T2={cfg.T2} eps={cfg.eps} beta={cfg.beta}") + + # (a) energy decreases during relaxation + z, y = model.init_state(xbar) + print("energy trajectory (free phase):") + es = [] + for t in range(T1 + 1): + e = model.energy(xbar, z, y).mean().item(); es.append(e) + if t % max(1, T1 // 6) == 0: + rz, ry = _residual(model, xbar, z, y, cfg.eps) + print(f" step {t:3d} E={e:12.4f} masked-MSE={masked_mse(y,x,M):.4f}" + f" rel-step |dz|={rz:.2e} |dy|={ry:.2e}") + with torch.enable_grad(): + z, y = model._grad_step(xbar, z, y, cfg.eps) + z, y = z.detach(), y.detach() + mono = all(es[i+1] <= es[i] + 1e-3 for i in range(len(es)-1)) + print(f" monotonic non-increasing: {mono} (start {es[0]:.2f} -> end {es[-1]:.2f})") + print(f" NaN in state: {torch.isnan(z).any().item() or torch.isnan(y).any().item()}") + + # (b) EP gradient vs full-BPTT gradient (key correctness gate) + g_ep, _ = ep_param_grads(model, xbar, x, M, T1, cfg.T2, cfg.eps, beta=cfg.beta) + g_bp = bptt_param_grads(model, xbar, x, M, T1, cfg.eps) + fe = torch.cat([g.flatten() for g in g_ep]) + fb = torch.cat([g.flatten() for g in g_bp]) + cos = F.cosine_similarity(fe, fb, dim=0).item() + names = [n for n, _ in model.named_parameters()] + print(f"\nEP-vs-BPTT gradient cosine (global): {cos:.4f}") + for n, a, b in zip(names, g_ep, g_bp): + c = F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() + print(f" {n:6s} cos={c:+.3f} |ep|={a.norm():.3e} |bptt|={b.norm():.3e}") + print(f"\nSMOKE {'PASS' if (mono and cos > 0.6) else 'CHECK'} " + f"(want energy monotone & global cos>0.6)") + + +# --------------------------------------------------------------------------- # +def main(): + p = argparse.ArgumentParser() + p.add_argument('--mode', choices=['ep', 'tbpte', 'smoke'], default='smoke') + p.add_argument('--dataset', choices=['cifar10', 'fashionmnist'], default='cifar10') + p.add_argument('--steps', type=int, default=4000) + p.add_argument('--batch', type=int, default=128) + p.add_argument('--img', type=int, default=32); p.add_argument('--ch', type=int, default=3) + p.add_argument('--patch', type=int, default=8); p.add_argument('--stride', type=int, default=8) + p.add_argument('--D', type=int, default=128); p.add_argument('--heads', type=int, default=4) + p.add_argument('--dh', type=int, default=32); p.add_argument('--mem', type=int, default=256) + p.add_argument('--gamma', type=float, default=0.25) + p.add_argument('--T1', type=int, default=30); p.add_argument('--T2', type=int, default=5) + p.add_argument('--eps', type=float, default=0.5); p.add_argument('--beta', type=float, default=0.1) + p.add_argument('--mask_ratio', type=float, default=0.5) + p.add_argument('--lr', type=float, default=4e-4); p.add_argument('--lr_min', type=float, default=1e-6) + p.add_argument('--wd', type=float, default=3e-5); p.add_argument('--clip', type=float, default=10.0) + p.add_argument('--log_every', type=int, default=100); p.add_argument('--eval_every', type=int, default=1000) + p.add_argument('--seed', type=int, default=0) + p.add_argument('--out', type=str, default='/home/yurenh2/ept/runs') + p.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') + cfg = p.parse_args() + print('config:', vars(cfg), flush=True) + if cfg.mode == 'smoke': + smoke(cfg) + else: + train(cfg) + + +if __name__ == '__main__': + main() diff --git a/scripts/plot_jr_cmp.py b/scripts/plot_jr_cmp.py new file mode 100644 index 0000000..2f8af95 --- /dev/null +++ b/scripts/plot_jr_cmp.py @@ -0,0 +1,20 @@ +import re, matplotlib; matplotlib.use('Agg'); import matplotlib.pyplot as plt +def parse(p): + s,c,e=[],[],[] + for ln in open(p): + m=re.search(r'step\s+(\d+)/\d+ \| val CE ([\d.]+) ema=([\d.]+)', ln) + if m: s.append(int(m.group(1))); c.append(float(m.group(2))); e.append(float(m.group(3))) + return s,c,e +fz=parse('ep_run/runs/ep_resreg_warm.log'); ad=parse('ep_run/runs/ep_rr_ajr.log') +fig,ax=plt.subplots(figsize=(11,5.5),dpi=140) +for (s,c,e),col,lab in [(fz,'#4363d8','frozen jr=0.1'),(ad,'#e6194B','adaptive jacreg (jr_max=16)')]: + ax.plot(s,c,color=col,lw=0.7,alpha=0.25) + ax.plot(s,e,color=col,lw=2.3,label=lab+f' (last ema {e[-1]:.4f})') +ax.axhline(2.0,color='gray',ls='--',lw=1,alpha=0.7,label='val CE = 2.0') +ax.set_xlabel('training step'); ax.set_ylabel('val CE'); ax.set_ylim(1.92,2.6) +ax.set_title('Frozen jr vs adaptive jacreg — C512 EP, same s2000 warm-start (only jr differs)\nbold = ema, faint = raw per-log val CE', fontsize=10.5) +ax.legend(fontsize=9.5, loc='upper right'); ax.grid(alpha=0.2) +fig.tight_layout(); fig.savefig('frozen_vs_adaptive.png',dpi=140,bbox_inches='tight') +print(f"frozen: {len(fz[0])} pts, steps {fz[0][0]}-{fz[0][-1]}, best ema {min(fz[2]):.4f}") +print(f"adaptive: {len(ad[0])} pts, steps {ad[0][0]}-{ad[0][-1]}, best ema {min(ad[2]):.4f}") +print("saved") |
