summaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rw-r--r--scripts/aep_attention.py157
-rw-r--r--scripts/aep_characterize.py157
-rw-r--r--scripts/aep_contractive.py52
-rw-r--r--scripts/aep_contractive2.py41
-rw-r--r--scripts/aep_depth.py30
-rw-r--r--scripts/aep_option1.py115
-rw-r--r--scripts/aep_projected.py125
-rw-r--r--scripts/ask_fugu.py24
-rw-r--r--scripts/bp_transformer.py141
-rw-r--r--scripts/cet_aep.py272
-rw-r--r--scripts/cet_mvp.py372
-rw-r--r--scripts/plot_jr_cmp.py20
12 files changed, 1506 insertions, 0 deletions
diff --git a/scripts/aep_attention.py b/scripts/aep_attention.py
new file mode 100644
index 0000000..868cb05
--- /dev/null
+++ b/scripts/aep_attention.py
@@ -0,0 +1,157 @@
+"""
+CET + AEP: does Asymmetric EP let us train a *non-conservative* attention?
+
+CET's energy attention is conservative by construction (grad of a scalar LogSumExp
+energy -> symmetric Jacobian -> vanilla EP exact). Real transformer attention
+softmax(QK^T)V with an INDEPENDENT value V is NOT the gradient of any scalar ->
+non-conservative Jacobian -> vanilla EP gives a BIASED gradient.
+
+AEP (Scellier et al., "EP for Non-Conservative Systems", arXiv:2602.03670) adds a
+nudged-phase correction -2 A_J(x*)(x - x*), A_J = (J - J^T)/2 at the free
+equilibrium x*. Linearised, this turns the nudged Jacobian J into J^T -- exactly
+the adjoint that vanilla EP fails to realise when J != J^T.
+
+We compare three parameter-gradient estimators vs ground-truth BPTT in two regimes:
+ cons : F = -x + b + tanh(xS) S^T (manifestly grad of a scalar -> J symmetric) [control]
+ noncons : F = -x + b + W_O softmax(QK^T/sqrt d)(x W_V) (real attention) [the test]
+
+Vector-field param-gradient (valid for non-gradient F):
+ dL/dtheta = < a, dF/dtheta(x*) >, a = (x_{-b} - x_{+b}) / (2 beta).
+"""
+import torch, torch.nn.functional as F, math
+torch.manual_seed(0)
+B, N, D, H = 8, 8, 32, 4
+dh = D // H
+dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+def mk_params(regime):
+ g = torch.Generator(device=dev).manual_seed(1)
+ s = 1.0 / math.sqrt(D)
+ if regime == 'noncons':
+ P = dict(WQ=torch.randn(D, D, generator=g, device=dev) * s,
+ WK=torch.randn(D, D, generator=g, device=dev) * s,
+ WV=torch.randn(D, D, generator=g, device=dev) * s,
+ WO=torch.randn(D, D, generator=g, device=dev) * s,
+ b=torch.zeros(D, device=dev))
+ else:
+ P = dict(S=torch.randn(D, D, generator=g, device=dev) * s,
+ b=torch.zeros(D, device=dev))
+ for v in P.values():
+ v.requires_grad_(True)
+ return P
+
+
+def _heads(t):
+ return t.view(B, N, H, dh).transpose(1, 2)
+
+
+def F_noncons(x, P): # real (non-conservative) attention force
+ q, k, v = _heads(x @ P['WQ']), _heads(x @ P['WK']), _heads(x @ P['WV'])
+ A = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(dh), dim=-1)
+ o = (A @ v).transpose(1, 2).reshape(B, N, D) @ P['WO']
+ return -x + P['b'] + o
+
+
+def F_cons(x, P): # F = -grad E, E = .5|x|^2 -<b,x> -sum logcosh(xS)
+ return -x + P['b'] + torch.tanh(x @ P['S']) @ P['S'].t()
+
+
+def cost(x, R, tgt):
+ return 0.5 * ((x.reshape(B, -1) @ R - tgt) ** 2).sum() / B
+
+
+def dcost(x, R, tgt):
+ x = x.detach().requires_grad_(True)
+ with torch.enable_grad():
+ g, = torch.autograd.grad(cost(x, R, tgt), x)
+ return g
+
+
+def relax(Ffn, P, x0, steps, eps, extra=None):
+ x = x0.clone()
+ for _ in range(steps):
+ with torch.no_grad():
+ f = Ffn(x, P)
+ if extra is not None:
+ f = f + extra(x)
+ x = x + eps * f
+ return x.detach()
+
+
+def AJ_apply(Ffn, P, xstar, v): # 0.5 (J v - J^T v) at xstar
+ with torch.enable_grad():
+ fx = lambda z: Ffn(z, P)
+ Jv = torch.autograd.functional.jvp(fx, xstar, v)[1]
+ JTv = torch.autograd.functional.vjp(fx, xstar, v)[1]
+ return 0.5 * (Jv - JTv)
+
+
+def ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep):
+ xstar = relax(Ffn, P, x0, T1, eps)
+ def nudged(sign):
+ ex = (lambda x: -2.0 * AJ_apply(Ffn, P, xstar, x - xstar)) if aep else None
+ fn = lambda x: Ffn(x, P) - sign * beta * dcost(x, R, tgt)
+ return relax(fn, None, xstar, T2, eps, extra=ex) if False else _nud(Ffn, P, xstar, R, tgt, T2, eps, sign, beta, ex)
+ xp, xm = nudged(+1.0), nudged(-1.0)
+ a = ((xm - xp) / (2.0 * beta)).detach()
+ xs = xstar.detach()
+ with torch.enable_grad():
+ s = (a * Ffn(xs, P)).sum()
+ grads = torch.autograd.grad(s, list(P.values()), allow_unused=True)
+ return grads
+
+
+def _nud(Ffn, P, xstar, R, tgt, T2, eps, sign, beta, ex):
+ x = xstar.clone()
+ for _ in range(T2):
+ with torch.no_grad():
+ f = Ffn(x, P) - sign * beta * dcost(x, R, tgt)
+ if ex is not None:
+ f = f + ex(x)
+ x = x + eps * f
+ return x.detach()
+
+
+def bptt_grad(Ffn, P, x0, R, tgt, T1, eps):
+ x = x0.clone()
+ for _ in range(T1):
+ x = x + eps * Ffn(x, P) # full graph
+ return torch.autograd.grad(cost(x, R, tgt), list(P.values()), allow_unused=True)
+
+
+def cosine(ga, gb):
+ fa = torch.cat([g.flatten() for g in ga])
+ fb = torch.cat([g.flatten() for g in gb])
+ return F.cosine_similarity(fa, fb, dim=0).item()
+
+
+def run(regime, T1=120, T2=30, eps=0.2, beta=0.02):
+ P = mk_params(regime)
+ Ffn = F_cons if regime == 'cons' else F_noncons
+ g = torch.Generator(device=dev).manual_seed(7)
+ x0 = torch.randn(B, N, D, generator=g, device=dev) * 0.1
+ R = torch.randn(N * D, 16, generator=g, device=dev) / math.sqrt(N * D)
+ tgt = torch.randn(B, 16, generator=g, device=dev)
+
+ xs = relax(Ffn, P, x0, T1, eps)
+ res = ((relax(Ffn, P, xs, 1, eps) - xs).norm() / (xs.norm() + 1e-8)).item()
+ v = torch.randn_like(xs)
+ aj = AJ_apply(Ffn, P, xs, v)
+ jv = torch.autograd.functional.jvp(lambda z: Ffn(z, P), xs, v)[1]
+ asym = (aj.norm() / (jv.norm() + 1e-8)).item()
+
+ gb = bptt_grad(Ffn, P, x0, R, tgt, T1, eps)
+ gn = ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep=False)
+ ga = ep_grad(Ffn, P, x0, R, tgt, T1, T2, eps, beta, aep=True)
+
+ print(f"\n===== regime={regime} (residual@x*={res:.1e}) =====")
+ print(f" Jacobian antisymmetry ||A_J v||/||J v|| = {asym:.3f} "
+ f"({'~conservative' if asym < 0.05 else 'NON-conservative'})")
+ print(f" cosine(naive_EP, BPTT) = {cosine(gn, gb):+.4f}")
+ print(f" cosine( AEP , BPTT) = {cosine(ga, gb):+.4f}")
+
+
+if __name__ == '__main__':
+ run('cons')
+ run('noncons')
diff --git a/scripts/aep_characterize.py b/scripts/aep_characterize.py
new file mode 100644
index 0000000..3e642c1
--- /dev/null
+++ b/scripts/aep_characterize.py
@@ -0,0 +1,157 @@
+"""
+Characterize AEP (non-conservative EP) on CET's attention, before porting to the LM.
+
+Controlled knob: attention scale s in force_z = -dE_rest/dz + s * RealAttn(z).
+ s=0 -> pure conservative reconstruction (A_J=0; EP exact)
+ s up -> attention dominates the force -> more non-conservative -> naive EP biased.
+Metric: cosine(EP-grad, BPTT-grad) on the attention params {WQ,WK,WV,WO} (the global
+cosine is diluted by the dominant conservative params, so we look at attention itself).
+The AEP correction is -s*(J_A v) on z, J_A = antisym Jacobian of RealAttn at the free eq.
+
+Sweeps: (1) s [non-conservativeness], (2) beta [nudge size], (3) T2 [nudge steps],
+ (4) T1 [free-phase convergence]. Plus: free-eq identical naive vs AEP, and cost.
+"""
+import argparse, math, time, torch, torch.nn.functional as F
+from cet_mvp import make_patch_mask, masked_cost, get_loaders
+from cet_aep import CETReal
+
+dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ATTN = ('WQ', 'WK', 'WV', 'WO')
+
+
+def force(model, xbar, z, y, s):
+ z = z.requires_grad_(True); y = y.requires_grad_(True)
+ gz, gy = torch.autograd.grad(model.E_rest(xbar, z, y), [z, y], create_graph=True)
+ return -gz + s * model.real_attn(z), -gy
+
+
+def relax_free(model, xbar, z, y, s, T1, eps):
+ for _ in range(T1):
+ with torch.enable_grad():
+ fz, fy = force(model, xbar, z, y, s)
+ fz, fy = fz.detach(), fy.detach()
+ with torch.no_grad():
+ z, y = z + eps * fz, y + eps * fy
+ return z.detach(), y.detach()
+
+
+def relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, sign, aep):
+ z, y = zs.clone(), ys.clone()
+ for _ in range(T2):
+ with torch.enable_grad():
+ fz, fy = force(model, xbar, z, y, s)
+ fz, fy = fz.detach(), fy.detach()
+ yy = y.detach().requires_grad_(True)
+ gy, = torch.autograd.grad(masked_cost(yy, X, M), yy)
+ fy = fy - sign * beta * gy
+ if aep:
+ v = (z - zs).detach()
+ Jv = torch.autograd.functional.jvp(model.real_attn, zs, v)[1]
+ JTv = torch.autograd.functional.vjp(model.real_attn, zs, v)[1]
+ fz = fz - s * (Jv - JTv) # -2 * s * 0.5 (J v - J^T v)
+ with torch.no_grad():
+ z, y = z + eps * fz, y + eps * fy
+ return z.detach(), y.detach()
+
+
+def vf_grad(model, xbar, s, T1, T2, eps, beta, aep):
+ zs, ys = relax_free(model, xbar, *model.init_state(xbar), s, T1, eps)
+ zp, yp = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, +1, aep)
+ zm, ym = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, -1, aep)
+ az, ay = ((zm - zp) / (2 * beta)).detach(), ((ym - yp) / (2 * beta)).detach()
+ with torch.enable_grad():
+ fz, fy = force(model, xbar, zs.detach(), ys.detach(), s)
+ g = torch.autograd.grad((az * fz).sum() + (ay * fy).sum(),
+ list(model.parameters()), allow_unused=True)
+ return zs, g
+
+
+def bptt_grad(model, xbar, s, T1, eps):
+ z, y = model.init_state(xbar); z, y = z.requires_grad_(True), y.requires_grad_(True)
+ for _ in range(T1):
+ fz, fy = force(model, xbar, z, y, s)
+ z, y = z + eps * fz, y + eps * fy
+ return torch.autograd.grad(masked_cost(y, X, M) / M.sum(),
+ list(model.parameters()), allow_unused=True)
+
+
+def attn_cos(g, gb, names):
+ cs = [F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
+ for n, a, b in zip(names, g, gb) if n in ATTN and a is not None and b is not None]
+ return sum(cs) / len(cs)
+
+
+def global_cos(g, gb):
+ a = torch.cat([x.flatten() for x in g if x is not None])
+ b = torch.cat([x.flatten() for x, y in zip(g, gb) if x is not None and y is not None])
+ return F.cosine_similarity(a, b, dim=0).item()
+
+
+def measure(model, names, s, T1, T2, eps, beta):
+ gb = bptt_grad(model, XBAR, s, T1, eps)
+ zsn, gn = vf_grad(model, XBAR, s, T1, T2, eps, beta, aep=False)
+ zsa, ga = vf_grad(model, XBAR, s, T1, T2, eps, beta, aep=True)
+ eq_id = (zsn - zsa).norm().item() / (zsn.norm().item() + 1e-9) # free eq identical?
+ return dict(naive=attn_cos(gn, gb, names), aep=attn_cos(ga, gb, names),
+ gnaive=global_cos(gn, gb), gaep=global_cos(ga, gb), eq_id=eq_id)
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--dataset', default='fashionmnist')
+ ap.add_argument('--img', type=int, default=28); ap.add_argument('--ch', type=int, default=1)
+ ap.add_argument('--patch', type=int, default=7); ap.add_argument('--stride', type=int, default=7)
+ ap.add_argument('--batch', type=int, default=32)
+ cfg = ap.parse_args()
+ torch.manual_seed(0)
+ model = CETReal(cfg.img, cfg.ch, cfg.patch, cfg.stride, D=64, heads=4, dh=16, mem=128).to(dev)
+ names = [n for n, _ in model.named_parameters()]
+ trl, _ = get_loaders(cfg.batch, dataset=cfg.dataset)
+ global X, M, XBAR
+ X, _ = next(iter(trl)); X = X.to(dev)
+ M = make_patch_mask(X.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev)
+ XBAR = X * (1 - M)
+
+ # intrinsic non-conservativeness of the attention map itself
+ zs, _ = relax_free(model, XBAR, *model.init_state(XBAR), 1.0, 120, 0.2)
+ v = torch.randn_like(zs)
+ Jv = torch.autograd.functional.jvp(model.real_attn, zs, v)[1]
+ JTv = torch.autograd.functional.vjp(model.real_attn, zs, v)[1]
+ print(f"intrinsic attention-map antisymmetry ||A_J v||/||J v|| = "
+ f"{(0.5*(Jv-JTv)).norm().item()/(Jv.norm().item()+1e-9):.3f}")
+
+ base = dict(T1=120, T2=20, eps=0.2, beta=0.02)
+ print("\n[1] ATTENTION SCALE s (s=0 conservative -> larger = more non-conservative)")
+ print(f"{'s':>6} | {'naive(attn)':>11} {'AEP(attn)':>10} | {'naive(glob)':>11} {'AEP(glob)':>10} | free-eq id")
+ for s in [0.25, 0.5, 1.0, 2.0, 4.0, 8.0]:
+ r = measure(model, names, s, base['T1'], base['T2'], base['eps'], base['beta'])
+ print(f"{s:6.2f} | {r['naive']:>11.3f} {r['aep']:>10.3f} | {r['gnaive']:>11.4f} {r['gaep']:>10.4f} | {r['eq_id']:.1e}")
+
+ print("\n[2] NUDGE STRENGTH beta (s=2, T2=20)")
+ print(f"{'beta':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}")
+ for beta in [0.005, 0.01, 0.02, 0.05, 0.1, 0.2]:
+ r = measure(model, names, 2.0, 120, 20, 0.2, beta)
+ print(f"{beta:6.3f} | {r['naive']:>11.3f} {r['aep']:>10.3f}")
+
+ print("\n[3] NUDGE STEPS T2 (s=2, beta=0.02)")
+ print(f"{'T2':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}")
+ for T2 in [3, 5, 10, 20, 40]:
+ r = measure(model, names, 2.0, 120, T2, 0.2, 0.02)
+ print(f"{T2:6d} | {r['naive']:>11.3f} {r['aep']:>10.3f}")
+
+ print("\n[4] FREE-PHASE STEPS T1 (s=2; AEP uses A_J at the free eq)")
+ print(f"{'T1':>6} | {'naive(attn)':>11} {'AEP(attn)':>10}")
+ for T1 in [20, 40, 80, 120, 200]:
+ r = measure(model, names, 2.0, T1, 20, 0.2, 0.02)
+ print(f"{T1:6d} | {r['naive']:>11.3f} {r['aep']:>10.3f}")
+
+ print("\n[5] COST (s=2, T1=120, T2=20)")
+ t = time.time(); [vf_grad(model, XBAR, 2.0, 120, 20, 0.2, 0.02, aep=False) for _ in range(3)]
+ torch.cuda.synchronize() if dev == 'cuda' else None; tn = (time.time()-t)/3
+ t = time.time(); [vf_grad(model, XBAR, 2.0, 120, 20, 0.2, 0.02, aep=True) for _ in range(3)]
+ torch.cuda.synchronize() if dev == 'cuda' else None; ta = (time.time()-t)/3
+ print(f" naive {tn*1000:.0f} ms/grad AEP {ta*1000:.0f} ms/grad overhead {ta/tn:.2f}x")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/aep_contractive.py b/scripts/aep_contractive.py
new file mode 100644
index 0000000..670feae
--- /dev/null
+++ b/scripts/aep_contractive.py
@@ -0,0 +1,52 @@
+"""F: make REAL attention EP-able by damping it into a contraction (keep it non-conservative).
+
+Attention term in the force becomes s*(attn(z) - c*z). The -c*z is damping that grows with s,
+pushing Re(eig(J_F)) < 0 (a stable fixed point) WITHOUT symmetrizing the Jacobian (the antisymmetric
+part is unchanged, so it stays non-conservative -> AEP still needed AND now applicable).
+
+We sweep (s, c) and report, using the validated projected-adjoint (option 1):
+ fwd resid : does a stable fixed point exist? (small = yes)
+ adj cos : projected-adjoint gradient fidelity vs BPTT on attention params
+Expected: c=0 breaks at high s (no fixed point, as before); c>=1 keeps resid small + fidelity high.
+"""
+import torch, aep_option1 as O
+from cet_aep import CETReal
+from cet_mvp import token_norm, make_patch_mask, get_loaders
+
+dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+torch.manual_seed(0)
+model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev)
+names = [n for n, _ in model.named_parameters()]
+orig_attn = model.real_attn # original (undamped) attention
+
+trl, _ = get_loaders(32, dataset='fashionmnist')
+X, _ = next(iter(trl)); X = X.to(dev)
+M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev)
+XBAR = X * (1 - M)
+O.X, O.M = X, M # masked_cost in option1 uses these globals
+
+
+def set_damp(c):
+ model.real_attn = orig_attn if c == 0 else (lambda z: orig_attn(z) - c * z)
+
+
+def resid(s, T1, eps=0.2):
+ zs, ys = O.relax_free(model, XBAR, *model.init_state(XBAR), s, T1, eps)
+ with torch.enable_grad():
+ zr, yr = zs.requires_grad_(True), ys.requires_grad_(True)
+ fz, _ = O.force(model, XBAR, zr, yr, s)
+ zn = token_norm(zs + eps * fz.detach())
+ return ((zn - zs).norm() / (zs.norm() + 1e-9)).item()
+
+
+print("Contractive (damped) non-conservative attention — does it restore a fixed point + EP fidelity?")
+print(f"{'s':>5} {'c':>4} | {'fwd resid':>9} {'adj cos(attn)':>13} {'glob':>7}")
+for s in [1.0, 2.0, 4.0, 8.0]:
+ for c in [0.0, 1.0, 2.0]:
+ set_damp(c)
+ r = resid(s, 250)
+ gb = O.bptt_grad(model, XBAR, s, 250, 0.2)
+ ga = O.adjoint_grad(model, XBAR, s, 250, 0.2, 250)
+ a, g = O.cosines(ga, gb, names)
+ print(f"{s:>5.1f} {c:>4.1f} | {r:>9.2e} {a:>13.3f} {g:>7.3f}")
+ print()
diff --git a/scripts/aep_contractive2.py b/scripts/aep_contractive2.py
new file mode 100644
index 0000000..f1d38a8
--- /dev/null
+++ b/scripts/aep_contractive2.py
@@ -0,0 +1,41 @@
+"""F (v2): make real attention EP-able via UNCONSTRAINED dynamics + damping (no projection).
+
+The projection (C/F-v1) fought radial damping and broke the VF. Drop it: unconstrained AEP
+already has clean theory (0.99 fidelity) but diverges at high s for lack of confinement.
+Add damping that scales with s: attention term = s*(attn(z) - c*z). Fixed point
+z* = [s*attn(z*) + enc]/(4 + s*c) -> attention still sets the direction, but -(4+sc)z makes
+it a contraction (so a stable fixed point exists). Small eps needed (the linear part is stiff).
+
+Reuses aep_characterize's UNCONSTRAINED, AEP-validated machinery; monkeypatches attention to the
+damped version. Reports naive vs AEP attention-param cosine vs BPTT, and whether it stayed finite.
+"""
+import math, torch, aep_characterize as A
+from cet_aep import CETReal
+from cet_mvp import make_patch_mask, get_loaders
+
+dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+torch.manual_seed(0)
+model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev)
+names = [n for n, _ in model.named_parameters()]
+orig = model.real_attn
+trl, _ = get_loaders(32, dataset='fashionmnist')
+X, _ = next(iter(trl)); X = X.to(dev)
+M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev)
+A.X, A.M, A.XBAR = X, M, X * (1 - M)
+
+
+def setc(c):
+ model.real_attn = orig if c == 0 else (lambda z: orig(z) - c * z)
+
+
+# small eps for the stiff damped linear part; more free steps to converge
+EPS, T1, T2, BETA = 0.05, 400, 40, 0.02
+print(f"UNCONSTRAINED + damping, eps={EPS} T1={T1} T2={T2}")
+print(f"{'s':>5} {'c':>4} | {'naive(attn)':>11} {'AEP(attn)':>10} | {'finite?':>7}")
+for s in [2.0, 4.0, 8.0]:
+ for c in [0.0, 1.0, 2.0]:
+ setc(c)
+ r = A.measure(model, names, s, T1, T2, EPS, BETA)
+ fin = not (math.isnan(r['aep']) or math.isnan(r['naive']))
+ print(f"{s:>5.1f} {c:>4.1f} | {r['naive']:>11.3f} {r['aep']:>10.3f} | {str(fin):>7}")
+ print()
diff --git a/scripts/aep_depth.py b/scripts/aep_depth.py
new file mode 100644
index 0000000..c202a0c
--- /dev/null
+++ b/scripts/aep_depth.py
@@ -0,0 +1,30 @@
+"""B: does AEP gradient fidelity degrade as the non-conservative attention gets DEEPER?
+Stack K residual attention sub-layers (weight-tied) inside the force; measure naive vs
+AEP attention-param cosine vs BPTT, at fixed scale s."""
+import torch, aep_characterize as A
+from cet_aep import CETReal
+from cet_mvp import make_patch_mask, get_loaders
+
+dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+torch.manual_seed(0)
+model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev)
+names = [n for n, _ in model.named_parameters()]
+trl, _ = get_loaders(32, dataset='fashionmnist')
+X, _ = next(iter(trl)); X = X.to(dev)
+M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev)
+A.X, A.M, A.XBAR = X, M, X * (1 - M)
+
+base = model.real_attn
+def deep(K):
+ def f(z):
+ h = z
+ for _ in range(K):
+ h = h + base(h)
+ return h - z
+ return f
+
+print(f"{'depth K':>8} | {'naive(attn)':>11} {'AEP(attn)':>10}")
+for K in [1, 2, 3, 4]:
+ model.real_attn = deep(K)
+ r = A.measure(model, names, 1.0, 120, 30, 0.2, 0.02) # s=1, T2=30 (enough per [3])
+ print(f"{K:>8} | {r['naive']:>11.3f} {r['aep']:>10.3f}")
diff --git a/scripts/aep_option1.py b/scripts/aep_option1.py
new file mode 100644
index 0000000..65583d4
--- /dev/null
+++ b/scripts/aep_option1.py
@@ -0,0 +1,115 @@
+"""option 1: CORRECT gradient for non-conservative attention UNDER the token-norm constraint.
+
+Implicit differentiation of the projected fixed-point map G(x) = Pi(x + eps F(x)):
+ adjoint a <- J_G^T a + g , J_G^T = (I + eps J_F^T) Pi'^T , g = dC/dx*
+ gradient dL/dtheta = eps * < Pi'^T a , dF/dtheta(x*) >
+
+Built from LOCAL pieces only (this is the projected analogue of EP's nudged adjoint):
+ Pi'^T : vjp(token_norm, u, .) (the LN/projection Jacobian = LayerNormProjectedSurrogate)
+ J_F^T : -Hess(E_rest).b (symmetric, via HVP) + s * vjp(real_attn, z*, .) (the non-conservative bit)
+Validation: cosine vs BPTT-through-the-projected-relaxation (ground truth). C lost fidelity; this should recover it.
+"""
+import torch, torch.nn.functional as F, math
+from cet_mvp import token_norm, make_patch_mask, masked_cost, get_loaders
+from cet_aep import CETReal
+
+dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ATTN = ('WQ', 'WK', 'WV', 'WO')
+
+
+def force(model, xbar, z, y, s, cg=False):
+ gz, gy = torch.autograd.grad(model.E_rest(xbar, z, y), [z, y], create_graph=cg)
+ return -gz + s * model.real_attn(z), -gy
+
+
+def relax_free(model, xbar, z, y, s, T1, eps):
+ for _ in range(T1):
+ with torch.enable_grad():
+ zr, yr = z.requires_grad_(True), y.requires_grad_(True)
+ fz, fy = force(model, xbar, zr, yr, s)
+ fz, fy = fz.detach(), fy.detach()
+ with torch.no_grad():
+ z, y = token_norm(z + eps * fz), y + eps * fy
+ return z.detach(), y.detach()
+
+
+def adjoint_grad(model, xbar, s, T1, eps, Tadj):
+ zs, ys = relax_free(model, xbar, *model.init_state(xbar), s, T1, eps)
+ # pre-projection point u for Pi' ; cost grad g=(0, dC/dy)
+ zr, yr = zs.detach().requires_grad_(True), ys.detach().requires_grad_(True)
+ fz, fy = force(model, xbar, zr, yr, s)
+ uz = (zs + eps * fz).detach()
+ yc = ys.detach().requires_grad_(True)
+ gy_c, = torch.autograd.grad(masked_cost(yc, X, M) / M.sum(), yc)
+ gy_c = gy_c.detach()
+
+ az, ay = torch.zeros_like(zs), gy_c.clone() # init adjoint at g (cost grad)
+ for _ in range(Tadj):
+ bz = torch.autograd.functional.vjp(token_norm, uz, az)[1] # Pi'^T a (z); y identity
+ by = ay
+ # J_F^T b = -Hess(E_rest).b + s * vjp(real_attn, zs, bz)
+ zr2, yr2 = zs.detach().requires_grad_(True), ys.detach().requires_grad_(True)
+ gz2, gy2 = torch.autograd.grad(model.E_rest(xbar, zr2, yr2), [zr2, yr2], create_graph=True)
+ hz, hy = torch.autograd.grad((gz2 * bz).sum() + (gy2 * by).sum(), [zr2, yr2])
+ av = torch.autograd.functional.vjp(model.real_attn, zs, bz)[1]
+ JFt_z, JFt_y = -hz + s * av, -hy
+ az = (bz + eps * JFt_z + torch.zeros_like(zs)).detach()
+ ay = (by + eps * JFt_y + gy_c).detach()
+
+ # gradient: eps * d/dtheta < Pi'^T a , F(x*, theta) >
+ bz = torch.autograd.functional.vjp(token_norm, uz, az)[1].detach()
+ by = ay.detach()
+ zr3, yr3 = zs.detach().requires_grad_(True), ys.detach().requires_grad_(True)
+ gz3, gy3 = torch.autograd.grad(model.E_rest(xbar, zr3, yr3), [zr3, yr3], create_graph=True)
+ Fz = -gz3 + s * model.real_attn(zr3)
+ Fy = -gy3
+ contr = eps * ((bz * Fz).sum() + (by * Fy).sum())
+ return torch.autograd.grad(contr, list(model.parameters()), allow_unused=True)
+
+
+def bptt_grad(model, xbar, s, T1, eps):
+ z, y = model.init_state(xbar); z, y = z.requires_grad_(True), y.requires_grad_(True)
+ for _ in range(T1):
+ fz, fy = force(model, xbar, z, y, s, cg=True)
+ z, y = token_norm(z + eps * fz), y + eps * fy
+ return torch.autograd.grad(masked_cost(y, X, M) / M.sum(),
+ list(model.parameters()), allow_unused=True)
+
+
+def cosines(g, gb, names):
+ c = lambda a, b: F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
+ at = [c(a, b) for n, a, b in zip(names, g, gb) if n in ATTN and a is not None and b is not None]
+ A = torch.cat([x.flatten() for x in g if x is not None])
+ B = torch.cat([y.flatten() for x, y in zip(g, gb) if x is not None and y is not None])
+ return (sum(at) / len(at) if at else float('nan')), c(A, B)
+
+
+def main():
+ torch.manual_seed(0)
+ model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev)
+ names = [n for n, _ in model.named_parameters()]
+ trl, _ = get_loaders(32, dataset='fashionmnist')
+ global X, M, XBAR
+ X, _ = next(iter(trl)); X = X.to(dev)
+ M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev)
+ XBAR = X * (1 - M)
+ def resid(s, T1, eps=0.2):
+ zs, ys = relax_free(model, XBAR, *model.init_state(XBAR), s, T1, eps)
+ with torch.enable_grad():
+ zr, yr = zs.requires_grad_(True), ys.requires_grad_(True)
+ fz, fy = force(model, XBAR, zr, yr, s)
+ zn = token_norm(zs + eps * fz.detach())
+ return ((zn - zs).norm() / (zs.norm() + 1e-9)).item()
+
+ print("PROJECTED-ADJOINT (option 1) vs BPTT — is the s>=2 break convergence or no-fixed-point?")
+ print(f"{'s':>5} {'T1=Tadj':>8} | {'attn cos':>9} {'glob cos':>9} | {'fwd resid':>9}")
+ for s in [0.5, 1.0, 2.0]:
+ for it in [120, 400]:
+ gb = bptt_grad(model, XBAR, s, it, 0.2)
+ ga = adjoint_grad(model, XBAR, s, it, 0.2, it)
+ a, g = cosines(ga, gb, names)
+ print(f"{s:5.1f} {it:>8} | {a:>9.3f} {g:>9.3f} | {resid(s, it):>9.2e}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/aep_projected.py b/scripts/aep_projected.py
new file mode 100644
index 0000000..af8891e
--- /dev/null
+++ b/scripts/aep_projected.py
@@ -0,0 +1,125 @@
+"""C / option 1: PROJECTED AEP — non-conservative EP on the token-norm constraint manifold.
+
+Two fixes over the unconstrained version:
+ (1) STABILITY: relax with the token-norm projection z <- Pi(z + eps F) (bounds z;
+ this is what made plain CET stable). Lets large-s / deep attention stop diverging.
+ (2) CORRECT GRADIENT under the constraint: the VF contraction must be projected onto the
+ TANGENT space of the manifold. The tangent projector at a normalized token z is
+ P_z(v) = v - mean(v) - mean(v*z) * z
+ (exactly the local-transformer's LayerNormProjectedSurrogate). Without it the VF
+ estimator picks up the normal force and collapses (energy-mode cosine ~0.002).
+
+Param-gradient: dL/dtheta = <a_z, P_z*( dF_z/dtheta )> + <a_y, dF_y/dtheta>,
+ a = (state_-b - state_+b)/(2 beta).
+AEP correction (nudged phase, on z): -s (J v - J^T v) of RealAttn, then projected.
+"""
+import argparse, math, torch, torch.nn.functional as F
+from cet_mvp import token_norm, make_patch_mask, masked_cost, get_loaders
+from cet_aep import CETReal
+
+dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ATTN = ('WQ', 'WK', 'WV', 'WO')
+
+
+def P_tan(z, v): # tangent projection at normalized token z
+ v = v - v.mean(-1, keepdim=True)
+ zz = (z * z).mean(-1, keepdim=True).clamp_min(1e-6)
+ return v - ((v * z).mean(-1, keepdim=True) / zz) * z
+
+
+def force(model, xbar, z, y, s):
+ z = z.requires_grad_(True); y = y.requires_grad_(True)
+ gz, gy = torch.autograd.grad(model.E_rest(xbar, z, y), [z, y], create_graph=True)
+ return -gz + s * model.real_attn(z), -gy
+
+
+def relax_free(model, xbar, z, y, s, T1, eps):
+ for _ in range(T1):
+ with torch.enable_grad():
+ fz, fy = force(model, xbar, z, y, s); fz, fy = fz.detach(), fy.detach()
+ with torch.no_grad():
+ z = token_norm(z + eps * fz); y = y + eps * fy
+ return z.detach(), y.detach()
+
+
+def relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, sign, aep):
+ z, y = zs.clone(), ys.clone()
+ for _ in range(T2):
+ with torch.enable_grad():
+ fz, fy = force(model, xbar, z, y, s); fz, fy = fz.detach(), fy.detach()
+ yy = y.detach().requires_grad_(True)
+ gy, = torch.autograd.grad(masked_cost(yy, X, M), yy)
+ fy = fy - sign * beta * gy
+ if aep:
+ v = (z - zs).detach()
+ Jv = torch.autograd.functional.jvp(model.real_attn, zs, v)[1]
+ JTv = torch.autograd.functional.vjp(model.real_attn, zs, v)[1]
+ fz = fz - s * (Jv - JTv)
+ with torch.no_grad():
+ z = token_norm(z + eps * fz); y = y + eps * fy
+ return z.detach(), y.detach()
+
+
+def vf_grad(model, xbar, s, T1, T2, eps, beta, aep):
+ zs, ys = relax_free(model, xbar, *model.init_state(xbar), s, T1, eps)
+ zp, yp = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, +1, aep)
+ zm, ym = relax_nudged(model, xbar, zs, ys, s, T2, eps, beta, -1, aep)
+ az = P_tan(zs, ((zm - zp) / (2 * beta))).detach() # adjoint in tangent space
+ ay = ((ym - yp) / (2 * beta)).detach()
+ with torch.enable_grad():
+ fz, fy = force(model, xbar, zs.detach(), ys.detach(), s)
+ s_ = (az * P_tan(zs, fz)).sum() + (ay * fy).sum() # projected contraction
+ g = torch.autograd.grad(s_, list(model.parameters()), allow_unused=True)
+ return zs, g
+
+
+def bptt_grad(model, xbar, s, T1, eps):
+ z, y = model.init_state(xbar); z, y = z.requires_grad_(True), y.requires_grad_(True)
+ for _ in range(T1):
+ fz, fy = force(model, xbar, z, y, s)
+ z = token_norm(z + eps * fz); y = y + eps * fy
+ return torch.autograd.grad(masked_cost(y, X, M) / M.sum(),
+ list(model.parameters()), allow_unused=True)
+
+
+def cosines(g, gb, names):
+ def c(a, b): return F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
+ at = [c(a, b) for n, a, b in zip(names, g, gb) if n in ATTN and a is not None and b is not None]
+ A = torch.cat([x.flatten() for x in g if x is not None])
+ B = torch.cat([y.flatten() for x, y in zip(g, gb) if x is not None and y is not None])
+ return (sum(at) / len(at) if at else float('nan')), c(A, B)
+
+
+def measure(model, names, s, T1, T2, eps, beta):
+ gb = bptt_grad(model, XBAR, s, T1, eps)
+ _, gn = vf_grad(model, XBAR, s, T1, T2, eps, beta, False)
+ zs, ga = vf_grad(model, XBAR, s, T1, T2, eps, beta, True)
+ an, gng = cosines(gn, gb, names)
+ aa, gag = cosines(ga, gb, names)
+ fin = torch.isfinite(zs).all().item()
+ return an, aa, gng, gag, fin
+
+
+def main():
+ torch.manual_seed(0)
+ model = CETReal(28, 1, 7, 7, D=64, heads=4, dh=16, mem=128).to(dev)
+ names = [n for n, _ in model.named_parameters()]
+ trl, _ = get_loaders(32, dataset='fashionmnist')
+ global X, M, XBAR
+ X, _ = next(iter(trl)); X = X.to(dev)
+ M = make_patch_mask(X.size(0), model.gh, 7, 7, 28, 28, 0.5, dev)
+ XBAR = X * (1 - M)
+
+ print("SANITY s=0 (pure conservative): projected-VF global cosine should be ~1")
+ _, _, gnaive, _, _ = measure(model, names, 0.0, 120, 20, 0.2, 0.02)
+ print(f" s=0 global cosine = {gnaive:.4f}\n")
+
+ print("PROJECTED AEP across attention scale s (T1=120 T2=30 beta=0.02)")
+ print(f"{'s':>6} | {'naive(attn)':>11} {'AEP(attn)':>10} | {'finite?':>7} (unproj. broke at s>=4)")
+ for s in [0.5, 1.0, 2.0, 4.0, 8.0, 16.0]:
+ an, aa, gn, ga, fin = measure(model, names, s, 120, 30, 0.2, 0.02)
+ print(f"{s:6.2f} | {an:>11.3f} {aa:>10.3f} | {str(bool(fin)):>7}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/ask_fugu.py b/scripts/ask_fugu.py
new file mode 100644
index 0000000..b3ffa99
--- /dev/null
+++ b/scripts/ask_fugu.py
@@ -0,0 +1,24 @@
+import json, urllib.request, urllib.error, os, sys
+key = open(os.path.expanduser("~/.codex/sakana.key")).read().strip()
+brief = open("/home/yurenh2/ept/PHYSICS_QUESTIONS_FOR_DEEP_REASONING.md").read()
+prompt = brief + "\n\n---\nAnswer Q1 through Q7. For each: (a) is it a FUNDAMENTAL obstruction (cite/sketch the no-go) or an ENGINEERING gap (sketch the construction); (b) the physical-realizability verdict (local, forward-only, no backward pass?); (c) the cheapest experiment on our simulator that would falsify your proposed mechanism. Be rigorous, specific, and decisive."
+payload = {"model": "fugu-ultra", "input": prompt, "reasoning": {"effort": "xhigh"}}
+req = urllib.request.Request("https://api.sakana.ai/v1/responses",
+ data=json.dumps(payload).encode(),
+ headers={"Content-Type": "application/json", "Authorization": "Bearer " + key})
+try:
+ data = json.load(urllib.request.urlopen(req, timeout=3000))
+except urllib.error.HTTPError as e:
+ print("HTTPError", e.code); print(e.read().decode()[:3000]); sys.exit(1)
+except Exception as e:
+ print("ERR", repr(e)); sys.exit(1)
+texts = []
+for item in data.get("output", []):
+ if item.get("type") == "message":
+ for c in item.get("content", []):
+ if c.get("type") in ("output_text", "text"): texts.append(c.get("text", ""))
+out = "\n".join(t for t in texts if t) or data.get("output_text", "") or ("RAW:\n" + json.dumps(data)[:4000])
+open("/home/yurenh2/ept/FUGU_PHYSICS_ANSWER.md", "w").write(out)
+print("=== USAGE ===", json.dumps(data.get("usage", {})))
+print("=== FUGU-ULTRA ANSWER (saved to FUGU_PHYSICS_ANSWER.md) ===")
+print(out)
diff --git a/scripts/bp_transformer.py b/scripts/bp_transformer.py
new file mode 100644
index 0000000..7c9b543
--- /dev/null
+++ b/scripts/bp_transformer.py
@@ -0,0 +1,141 @@
+"""
+Vanilla backprop Transformer baseline for the SAME masked-image-completion task,
+so we can compare against CET-EP and CET-TBPTE on the identical metric
+(masked-patch pixel MSE on CIFAR-10, images in [-1,1]).
+
+Standard recipe: conv patch-embed + learned pos-embed, MAE-style learned mask
+token on occluded patches, N standard pre-LN transformer blocks (MHA + FFN),
+linear pixel head, MSE loss on masked patches only. Trained with normal Adam/BP.
+"""
+import argparse, os, time, json, math
+import torch, torch.nn as nn, torch.nn.functional as F
+from cet_mvp import get_loaders, make_patch_mask # reuse data + masking
+
+
+class Block(nn.Module):
+ def __init__(self, D, heads, mlp_ratio):
+ super().__init__()
+ self.ln1 = nn.LayerNorm(D)
+ self.attn = nn.MultiheadAttention(D, heads, batch_first=True)
+ self.ln2 = nn.LayerNorm(D)
+ self.mlp = nn.Sequential(nn.Linear(D, int(mlp_ratio * D)), nn.GELU(),
+ nn.Linear(int(mlp_ratio * D), D))
+
+ def forward(self, x):
+ h = self.ln1(x)
+ x = x + self.attn(h, h, h, need_weights=False)[0]
+ x = x + self.mlp(self.ln2(x))
+ return x
+
+
+class BPTransformer(nn.Module):
+ def __init__(self, img=32, ch=3, patch=8, stride=8, D=128, heads=4,
+ depth=1, mlp_ratio=2.0):
+ super().__init__()
+ self.ch, self.patch, self.stride = ch, patch, stride
+ gh = (img - patch) // stride + 1
+ self.gh, self.N, self.pdim = gh, gh * gh, ch * patch * patch
+ self.embed = nn.Conv2d(ch, D, patch, stride=stride)
+ self.pos = nn.Parameter(torch.zeros(1, self.N, D)); nn.init.normal_(self.pos, std=0.02)
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, D)); nn.init.normal_(self.mask_token, std=0.02)
+ self.blocks = nn.ModuleList([Block(D, heads, mlp_ratio) for _ in range(depth)])
+ self.ln = nn.LayerNorm(D)
+ self.head = nn.Linear(D, self.pdim)
+
+ def patchify(self, x): # (B,C,H,W)->(B,N,pdim)
+ p, s = self.patch, self.stride
+ u = x.unfold(2, p, s).unfold(3, p, s) # B,C,gh,gw,p,p
+ return u.permute(0, 2, 3, 1, 4, 5).reshape(x.size(0), self.N, self.pdim)
+
+ def forward(self, xbar, pm): # pm: (B,N) 1=masked
+ t = self.embed(xbar).flatten(2).transpose(1, 2) # (B,N,D)
+ t = torch.where(pm.unsqueeze(-1).bool(), self.mask_token, t) + self.pos
+ for b in self.blocks:
+ t = b(t)
+ return self.head(self.ln(t)) # (B,N,pdim)
+
+
+def patch_mask_bool(B, gh, ratio, device, gen=None):
+ npatch = gh * gh; nmask = int(round(ratio * npatch))
+ idx = torch.rand(B, npatch, device=device, generator=gen).argsort(1)
+ pm = torch.zeros(B, npatch, device=device)
+ pm.scatter_(1, idx[:, :nmask], 1.0)
+ return pm # (B,N)
+
+
+def masked_patch_mse(pred, true, pm):
+ m = pm.unsqueeze(-1)
+ return ((pred - true) ** 2 * m).sum() / (m.sum() * pred.size(-1)).clamp_min(1.0)
+
+
+@torch.no_grad()
+def evaluate(model, loader, cfg, device, max_batches=100):
+ model.eval(); tot, n = 0.0, 0
+ gen = torch.Generator(device=device).manual_seed(0)
+ for i, (x, _) in enumerate(loader):
+ if i >= max_batches:
+ break
+ x = x.to(device)
+ pm = patch_mask_bool(x.size(0), model.gh, cfg.mask_ratio, device, gen)
+ M = pm.view(-1, model.gh, model.gh).repeat_interleave(cfg.patch, 1).repeat_interleave(cfg.patch, 2).unsqueeze(1)
+ xbar = x * (1 - M)
+ pred = model(xbar, pm)
+ tot += masked_patch_mse(pred, model.patchify(x), pm).item() * x.size(0); n += x.size(0)
+ model.train(); return tot / n
+
+
+def train(cfg):
+ device = cfg.device; torch.manual_seed(cfg.seed)
+ model = BPTransformer(cfg.img, cfg.ch, cfg.patch, cfg.stride, cfg.D, cfg.heads,
+ cfg.depth, cfg.mlp_ratio).to(device)
+ opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr_min)
+ trl, tel = get_loaders(cfg.batch, dataset=cfg.dataset)
+ print(f"[bp] params={sum(p.numel() for p in model.parameters())/1e3:.1f}K "
+ f"depth={cfg.depth} D={cfg.D} mlp={cfg.mlp_ratio}", flush=True)
+ step, t0, run = 0, time.time(), 0.0
+ while step < cfg.steps:
+ for x, _ in trl:
+ if step >= cfg.steps:
+ break
+ x = x.to(device, non_blocking=True)
+ pm = patch_mask_bool(x.size(0), model.gh, cfg.mask_ratio, device)
+ M = pm.view(-1, model.gh, model.gh).repeat_interleave(cfg.patch, 1).repeat_interleave(cfg.patch, 2).unsqueeze(1)
+ xbar = x * (1 - M)
+ pred = model(xbar, pm)
+ loss = masked_patch_mse(pred, model.patchify(x), pm)
+ opt.zero_grad(set_to_none=True); loss.backward(); opt.step(); sched.step()
+ run += loss.item(); step += 1
+ if step % cfg.log_every == 0:
+ print(f"step {step:5d}/{cfg.steps} | train masked-MSE {run/cfg.log_every:.5f} "
+ f"| {step/(time.time()-t0):.1f} it/s", flush=True); run = 0.0
+ if step % cfg.eval_every == 0 or step == cfg.steps:
+ print(f" >> [eval] step {step} test masked-MSE {evaluate(model, tel, cfg, device, 20):.5f}", flush=True)
+ final = evaluate(model, tel, cfg, device, 100)
+ os.makedirs(cfg.out, exist_ok=True)
+ json.dump({'mode': 'bp_transformer', 'final_test_masked_mse': final, 'steps': cfg.steps,
+ 'params_K': sum(p.numel() for p in model.parameters()) / 1e3},
+ open(os.path.join(cfg.out, 'result_bp_transformer.json'), 'w'), indent=2)
+ print(f"[bp] DONE final test masked-MSE = {final:.5f}", flush=True)
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--dataset', choices=['cifar10', 'fashionmnist'], default='cifar10')
+ p.add_argument('--steps', type=int, default=3000); p.add_argument('--batch', type=int, default=128)
+ p.add_argument('--img', type=int, default=32); p.add_argument('--ch', type=int, default=3)
+ p.add_argument('--patch', type=int, default=8); p.add_argument('--stride', type=int, default=8)
+ p.add_argument('--D', type=int, default=128); p.add_argument('--heads', type=int, default=4)
+ p.add_argument('--depth', type=int, default=1); p.add_argument('--mlp_ratio', type=float, default=2.0)
+ p.add_argument('--mask_ratio', type=float, default=0.5)
+ p.add_argument('--lr', type=float, default=4e-4); p.add_argument('--lr_min', type=float, default=1e-6)
+ p.add_argument('--wd', type=float, default=3e-5)
+ p.add_argument('--log_every', type=int, default=100); p.add_argument('--eval_every', type=int, default=500)
+ p.add_argument('--seed', type=int, default=0); p.add_argument('--out', type=str, default='/home/yurenh2/ept/runs')
+ p.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ cfg = p.parse_args()
+ print('config:', vars(cfg), flush=True); train(cfg)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/cet_aep.py b/scripts/cet_aep.py
new file mode 100644
index 0000000..44e8922
--- /dev/null
+++ b/scripts/cet_aep.py
@@ -0,0 +1,272 @@
+"""
+AEP applied to CET's attention: replace CET's conservative energy-attention E^att
+with a REAL (non-conservative) transformer attention inside the CET, and use the
+AEP correction so EP still recovers the true gradient.
+
+CET state = (tokens z, reconstruction y). The conservative part
+ E_rest = E_enc + E_pos + E_mem + E_dec (scalar -> symmetric Jacobian)
+keeps its energy-gradient force. The token force gets its attention term from:
+ energy mode : -dE^att/dz (conservative; tied value; this is plain CET)
+ real mode : RealAttn(z) = WO softmax(QK^T/sqrt dh)(z WV) (non-conservative)
+
+Because only RealAttn is non-conservative, the full-force antisymmetric Jacobian
+A_J reduces to the antisymmetric part of dRealAttn/dz alone -> AEP correction is
+ force_z += -(J_A z~ - J_A^T z~) , z~=z-z* , J_A = dRealAttn/dz at z*
+(clean jvp/vjp on RealAttn, no nested autograd).
+
+We compare parameter-gradient quality vs ground-truth BPTT for:
+ energy / naive-EP (conservative CET; sanity, should be ~BPTT)
+ real / naive-EP (non-conservative; expected biased)
+ real / AEP (non-conservative + correction; expected ~BPTT)
+"""
+import argparse, math, time, json, os, torch, torch.nn as nn, torch.nn.functional as F
+from cet_mvp import token_norm, make_patch_mask, masked_cost, masked_mse, get_loaders
+
+
+class CETReal(nn.Module):
+ def __init__(self, img=32, ch=3, patch=8, stride=8, D=64, heads=4, dh=16,
+ mem=128, gamma=0.25):
+ super().__init__()
+ self.ch, self.patch, self.stride, self.D = ch, patch, stride, D
+ self.heads, self.dh, self.gamma = heads, dh, gamma
+ gh = (img - patch) // stride + 1
+ self.gh, self.N = gh, gh * gh
+ self.damp = 0.0 # contraction damping c: real_attn returns attn(z) - c*z
+ self.Wenc = nn.Parameter(torch.empty(D, ch, patch, patch))
+ self.benc = nn.Parameter(torch.zeros(D))
+ self.bpos = nn.Parameter(torch.zeros(self.N, D))
+ self.Wdec = nn.Parameter(torch.empty(D, ch, patch, patch))
+ self.bdec = nn.Parameter(torch.zeros(ch))
+ self.Wmem = nn.Parameter(torch.empty(D, mem))
+ # attention: WQ/WK used by both; WV/WO only by the real (non-conservative) path
+ self.WQ = nn.Parameter(torch.empty(heads, dh, D))
+ self.WK = nn.Parameter(torch.empty(heads, dh, D))
+ self.WV = nn.Parameter(torch.empty(heads, dh, D))
+ self.WO = nn.Parameter(torch.empty(D, heads * dh))
+ nn.init.kaiming_normal_(self.Wenc); self.Wenc.data *= 0.5
+ nn.init.kaiming_normal_(self.Wdec); self.Wdec.data *= 0.5
+ for w in (self.WQ, self.WK, self.WV):
+ nn.init.normal_(w, std=1.0 / math.sqrt(D))
+ nn.init.normal_(self.Wmem, std=0.3 / math.sqrt(D)) # small: keep energy bounded-below
+ nn.init.normal_(self.WO, std=1.0 / math.sqrt(heads * dh))
+
+ def encode(self, xbar):
+ return F.conv2d(xbar, self.Wenc, stride=self.stride).flatten(2).transpose(1, 2)
+
+ def decode_conv(self, y):
+ return F.conv2d(y, self.Wdec, stride=self.stride).flatten(2).transpose(1, 2)
+
+ def E_rest(self, xbar, z, y): # conservative scalar (no attention)
+ enc = self.encode(xbar)
+ E = 2.0 * (z ** 2).sum() - (enc * z).sum() - (z * self.benc).sum() - (z * self.bpos).sum()
+ proj = torch.einsum('bnd,dm->bnm', z, self.Wmem)
+ E = E - (F.relu(proj) ** 2).sum()
+ dc = self.decode_conv(y)
+ E = E + 0.5 * (y ** 2).sum() - (dc * z).sum() - (y * self.bdec[None, :, None, None]).sum()
+ return E
+
+ def E_att(self, z): # conservative LogSumExp energy (tied value)
+ Q = torch.einsum('bnd,hjd->bhnj', z, self.WQ)
+ K = torch.einsum('bnd,hjd->bhnj', z, self.WK)
+ A = torch.einsum('bhmj,bhnj->bhmn', Q, K)
+ return -(1.0 / self.gamma) * torch.logsumexp(self.gamma * A, dim=-1).sum()
+
+ def real_attn(self, z): # NON-conservative real attention force
+ B = z.size(0)
+ q = torch.einsum('bnd,hjd->bhnj', z, self.WQ)
+ k = torch.einsum('bnd,hjd->bhnj', z, self.WK)
+ v = torch.einsum('bnd,hjd->bhnj', z, self.WV)
+ A = torch.softmax((q @ k.transpose(-2, -1)) / math.sqrt(self.dh), dim=-1)
+ o = (A @ v).transpose(1, 2).reshape(B, self.N, self.heads * self.dh)
+ return o @ self.WO.t() - self.damp * z # -c*z: symmetric -> contraction, A_J unchanged
+
+ def force(self, xbar, z, y, mode):
+ """Return (force_z, force_y). force = -dE/dstate (+ real attention if mode='real')."""
+ z = z.requires_grad_(True); y = y.requires_grad_(True)
+ if mode == 'energy':
+ E = self.E_rest(xbar, z, y) + self.E_att(z)
+ gz, gy = torch.autograd.grad(E, [z, y], create_graph=True)
+ return -gz, -gy
+ else:
+ E = self.E_rest(xbar, z, y)
+ gz, gy = torch.autograd.grad(E, [z, y], create_graph=True)
+ return -gz + self.real_attn(z), -gy
+
+ def init_state(self, xbar):
+ return token_norm(self.encode(xbar)).detach(), xbar.clone().detach()
+
+
+def relax(model, xbar, z, y, steps, eps, mode, x=None, M=None, beta=0.0, aep=False, zstar=None):
+ for _ in range(steps):
+ with torch.enable_grad():
+ fz, fy = model.force(xbar, z, y, mode)
+ fz, fy = fz.detach(), fy.detach()
+ if beta != 0.0: # nudge on the output y
+ yy = y.detach().requires_grad_(True)
+ gy, = torch.autograd.grad(masked_cost(yy, x, M), yy)
+ fy = fy - beta * gy
+ if aep: # AEP correction on z (attention block only)
+ v = (z - zstar).detach()
+ fa = lambda zz: model.real_attn(zz)
+ Jv = torch.autograd.functional.jvp(fa, zstar, v)[1]
+ JTv = torch.autograd.functional.vjp(fa, zstar, v)[1]
+ corr = Jv - JTv # = 2 * 0.5 (J v - J^T v)
+ cn, fn = corr.norm(), fz.norm() + 1e-8 # clip so correction can't dominate -> no blow-up
+ if cn > fn:
+ corr = corr * (fn / cn)
+ fz = fz - corr
+ with torch.no_grad():
+ z = z + eps * fz # unconstrained (0.5||z||^2 in E_rest keeps it bounded)
+ y = y + eps * fy
+ return z.detach(), y.detach()
+
+
+def vf_param_grad(model, xbar, x, M, mode, T1, T2, eps, beta, aep):
+ z0, y0 = model.init_state(xbar)
+ zs, ys = relax(model, xbar, z0, y0, T1, eps, mode)
+ zp, yp = relax(model, xbar, zs.clone(), ys.clone(), T2, eps, mode, x, M, +beta, aep, zs)
+ zm, ym = relax(model, xbar, zs.clone(), ys.clone(), T2, eps, mode, x, M, -beta, aep, zs)
+ az, ay = ((zm - zp) / (2 * beta)).detach(), ((ym - yp) / (2 * beta)).detach()
+ with torch.enable_grad():
+ fz, fy = model.force(xbar, zs.detach(), ys.detach(), mode)
+ s = (az * fz).sum() + (ay * fy).sum()
+ grads = torch.autograd.grad(s, list(model.parameters()), allow_unused=True, retain_graph=False)
+ return grads
+
+
+def bptt_param_grad(model, xbar, x, M, mode, T1, eps):
+ z, y = model.init_state(xbar)
+ z, y = z.requires_grad_(True), y.requires_grad_(True)
+ for _ in range(T1):
+ fz, fy = model.force(xbar, z, y, mode)
+ z = z + eps * fz
+ y = y + eps * fy
+ L = masked_cost(y, x, M) / M.sum()
+ return torch.autograd.grad(L, list(model.parameters()), allow_unused=True)
+
+
+def cos(ga, gb, names):
+ fa, fb = [], []
+ per = {}
+ for n, a, b in zip(names, ga, gb):
+ if a is None or b is None:
+ continue
+ fa.append(a.flatten()); fb.append(b.flatten())
+ per[n] = F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
+ g = F.cosine_similarity(torch.cat(fa), torch.cat(fb), dim=0).item()
+ return g, per
+
+
+def evaluate(model, loader, cfg, dev, mode='real', max_batches=40):
+ tot, n = 0.0, 0
+ gen = torch.Generator(device=dev).manual_seed(0)
+ for i, (x, _) in enumerate(loader):
+ if i >= max_batches:
+ break
+ x = x.to(dev)
+ M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev, gen)
+ xbar = x * (1 - M)
+ z, y = relax(model, xbar, *model.init_state(xbar), cfg.T1, cfg.eps, mode)
+ tot += masked_mse(y, x, M) * x.size(0); n += x.size(0)
+ return tot / n
+
+
+def fidelity(cfg, model, dev):
+ names = [n for n, _ in model.named_parameters()]
+ trl, _ = get_loaders(cfg.batch, dataset=cfg.dataset)
+ x, _ = next(iter(trl)); x = x.to(dev)
+ M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev)
+ xbar = x * (1 - M)
+ zs, ys = relax(model, xbar, *model.init_state(xbar), cfg.T1, cfg.eps, 'real')
+ v = torch.randn_like(zs)
+ Jv = torch.autograd.functional.jvp(lambda z: model.real_attn(z), zs, v)[1]
+ JTv = torch.autograd.functional.vjp(lambda z: model.real_attn(z), zs, v)[1]
+ asym = (0.5 * (Jv - JTv)).norm().item() / (Jv.norm().item() + 1e-8)
+ print(f"real-attention Jacobian antisymmetry = {asym:.3f}\n")
+ for mode, aep, label in [('energy', False, 'energy/naive (sanity)'),
+ ('real', False, 'real/naive (biased)'),
+ ('real', True, 'real/AEP (fixed)')]:
+ gb = bptt_param_grad(model, xbar, x, M, mode, cfg.T1, cfg.eps)
+ gv = vf_param_grad(model, xbar, x, M, mode, cfg.T1, cfg.T2, cfg.eps, cfg.beta, aep)
+ g, per = cos(gv, gb, names)
+ att = " ".join(f"{k}={per[k]:+.3f}" for k in ('WQ', 'WK', 'WV', 'WO') if k in per)
+ print(f"[{label}] global={g:+.4f} attn: {att}")
+
+
+def train(cfg, model, dev):
+ tag = 'aep' if cfg.aep else 'naive'
+ opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr * 0.01)
+ trl, tel = get_loaders(cfg.batch, dataset=cfg.dataset)
+ print(f"[real-attn EP, {tag}] params={sum(p.numel() for p in model.parameters())/1e3:.1f}K "
+ f"T1={cfg.T1} T2={cfg.T2} eps={cfg.eps} beta={cfg.beta}", flush=True)
+ # stay in the stable+faithful regime: cap weight norms (Wmem for bounded-below energy,
+ # attention WV/WO/WQ/WK so the non-conservative force can't grow into the unstable s>=4 regime)
+ caps = {n: p.detach().norm().item() * 1.5 for n, p in model.named_parameters()
+ if n in ('Wmem', 'WQ', 'WK', 'WV', 'WO')}
+ cap_params = {n: p for n, p in model.named_parameters() if n in caps}
+ step, t0, best = 0, time.time(), float('inf')
+ while step < cfg.steps:
+ for x, _ in trl:
+ if step >= cfg.steps:
+ break
+ x = x.to(dev, non_blocking=True)
+ M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, dev)
+ xbar = x * (1 - M)
+ grads = vf_param_grad(model, xbar, x, M, 'real', cfg.T1, cfg.T2, cfg.eps, cfg.beta, cfg.aep)
+ opt.zero_grad(set_to_none=True)
+ bad = False
+ for p, g in zip(model.parameters(), grads):
+ if g is None or not torch.isfinite(g).all():
+ bad = True; break
+ p.grad = g
+ if bad:
+ print(f" step {step}: non-finite grad, skip", flush=True); step += 1; continue
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
+ opt.step(); sched.step()
+ with torch.no_grad(): # stay in stable+faithful regime
+ for n, p in cap_params.items():
+ pn = p.norm()
+ if pn > caps[n]:
+ p.mul_(caps[n] / pn)
+ step += 1
+ if step % cfg.log_every == 0:
+ te = evaluate(model, tel, cfg, dev, 'real', 15)
+ best = min(best, te)
+ print(f"step {step:4d}/{cfg.steps} | test masked-MSE {te:.5f} (best {best:.5f}) "
+ f"| {step/(time.time()-t0):.2f} it/s", flush=True)
+ final = evaluate(model, tel, cfg, dev, 'real', 60)
+ best = min(best, final)
+ os.makedirs(cfg.out, exist_ok=True)
+ json.dump({'tag': tag, 'final_test_masked_mse': final, 'best_test_masked_mse': best,
+ 'steps': cfg.steps}, open(os.path.join(cfg.out, f'aep_train_{tag}.json'), 'w'), indent=2)
+ print(f"[real-attn EP, {tag}] DONE final={final:.5f} best={best:.5f}", flush=True)
+
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument('--cmd', choices=['fidelity', 'train'], default='fidelity')
+ ap.add_argument('--aep', action='store_true')
+ ap.add_argument('--damp', type=float, default=0.0)
+ ap.add_argument('--dataset', default='fashionmnist')
+ ap.add_argument('--img', type=int, default=28); ap.add_argument('--ch', type=int, default=1)
+ ap.add_argument('--patch', type=int, default=7); ap.add_argument('--stride', type=int, default=7)
+ ap.add_argument('--D', type=int, default=64); ap.add_argument('--heads', type=int, default=4)
+ ap.add_argument('--dh', type=int, default=16); ap.add_argument('--mem', type=int, default=128)
+ ap.add_argument('--T1', type=int, default=100); ap.add_argument('--T2', type=int, default=15)
+ ap.add_argument('--eps', type=float, default=0.2); ap.add_argument('--beta', type=float, default=0.02)
+ ap.add_argument('--batch', type=int, default=64); ap.add_argument('--steps', type=int, default=1500)
+ ap.add_argument('--lr', type=float, default=4e-4); ap.add_argument('--wd', type=float, default=1e-4)
+ ap.add_argument('--log_every', type=int, default=100)
+ ap.add_argument('--out', default='/home/yurenh2/ept/runs')
+ cfg = ap.parse_args()
+ dev = 'cuda' if torch.cuda.is_available() else 'cpu'
+ torch.manual_seed(0)
+ model = CETReal(cfg.img, cfg.ch, cfg.patch, cfg.stride, cfg.D, cfg.heads, cfg.dh, cfg.mem).to(dev)
+ model.damp = cfg.damp
+ print('config:', vars(cfg), flush=True)
+ (train if cfg.cmd == 'train' else fidelity)(cfg, model, dev)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/cet_mvp.py b/scripts/cet_mvp.py
new file mode 100644
index 0000000..deb07d9
--- /dev/null
+++ b/scripts/cet_mvp.py
@@ -0,0 +1,372 @@
+"""
+Convergent Energy Transformer (CET) trained with Equilibrium Propagation.
+MVP reproduction of Hoier, Kerjan & Scellier, "Training a Convergent Energy
+Transformer with Equilibrium Propagation" (ICLR 2026 AM workshop).
+
+Energy terms follow the paper's Appendix B exactly:
+ E = E_enc (eq10, conv-Hopfield) boundary: masked image -> tokens
+ + E_pos (eq12, per-token bias)
+ + E_att (eq14, ET LogSumExp attention) <- attention, trained WITHOUT BP under EP
+ + E_mem (eq15, modern Hopfield memory) <- plays the role of the FFN/MLP
+ + E_dec (eq16, conv decoder) tokens <-> reconstruction y
+Tokens z are normalised (mean 0, std 1) via projection after each PGD step.
+
+Training modes:
+ ep : free phase (T1) + two nudged phases (+/-beta, T2). Parameter gradient is the
+ centered-EP estimator (1/2beta)(dE/dtheta|_{+b} - dE/dtheta|_{-b}).
+ NO backpropagation through the relaxation dynamics; attention & memory
+ weights are updated purely from the two equilibria.
+ tbpte : same model, gradient via truncated backprop through the last T2 relaxation
+ steps (the paper's BP baseline; BPTE = "backprop through equilibration").
+"""
+import argparse, os, time, json, math
+import torch, torch.nn as nn, torch.nn.functional as F
+import torchvision as tv
+from torchvision import transforms
+
+
+# --------------------------------------------------------------------------- #
+# Model
+# --------------------------------------------------------------------------- #
+def token_norm(z, eps=1e-5):
+ """Project tokens onto the constraint set C: per-token mean 0, std 1 (over D_T)."""
+ return (z - z.mean(-1, keepdim=True)) / (z.std(-1, unbiased=False, keepdim=True) + eps)
+
+
+class CET(nn.Module):
+ def __init__(self, img=32, ch=3, patch=8, stride=8, D=128, heads=4, dh=32,
+ mem=256, gamma=0.25):
+ super().__init__()
+ self.ch, self.patch, self.stride, self.D = ch, patch, stride, D
+ self.heads, self.dh, self.gamma = heads, dh, gamma
+ gh = (img - patch) // stride + 1
+ self.gh = gh
+ self.N = gh * gh # number of tokens / patches
+
+ # Encoder (eq 10): conv kernel mapping patches -> token dim
+ self.Wenc = nn.Parameter(torch.empty(D, ch, patch, patch))
+ self.benc = nn.Parameter(torch.zeros(D))
+ # Positional bias (eq 12): per (token, dim), NOT shared across tokens
+ self.bpos = nn.Parameter(torch.zeros(self.N, D))
+ # Decoder (eq 16): conv kernel mapping reconstruction -> token dim
+ self.Wdec = nn.Parameter(torch.empty(D, ch, patch, patch))
+ self.bdec = nn.Parameter(torch.zeros(ch))
+ # Attention (eq 13/14): key/query projections, no value tensor (as in ET)
+ self.WQ = nn.Parameter(torch.empty(heads, dh, D))
+ self.WK = nn.Parameter(torch.empty(heads, dh, D))
+ # Memory (eq 15): modern Hopfield memory bank (role of the MLP)
+ self.Wmem = nn.Parameter(torch.empty(D, mem))
+
+ nn.init.kaiming_normal_(self.Wenc); self.Wenc.data *= 0.5
+ nn.init.kaiming_normal_(self.Wdec); self.Wdec.data *= 0.5
+ nn.init.normal_(self.WQ, std=1.0 / math.sqrt(D))
+ nn.init.normal_(self.WK, std=1.0 / math.sqrt(D))
+ nn.init.normal_(self.Wmem, std=1.0 / math.sqrt(D))
+
+ # -- patch <-> token conv helpers -------------------------------------- #
+ def encode(self, xbar): # (B,C,H,W) -> (B,N,D)
+ e = F.conv2d(xbar, self.Wenc, stride=self.stride)
+ return e.flatten(2).transpose(1, 2)
+
+ def decode_conv(self, y): # (B,C,H,W) -> (B,N,D)
+ d = F.conv2d(y, self.Wdec, stride=self.stride)
+ return d.flatten(2).transpose(1, 2)
+
+ # -- energy (per-sample, shape (B,)) ----------------------------------- #
+ def energy(self, xbar, z, y):
+ enc = self.encode(xbar) # (B,N,D)
+ E = 0.5 * (z ** 2).sum((1, 2)) - (enc * z).sum((1, 2)) - (z * self.benc).sum((1, 2))
+ E = E - (z * self.bpos).sum((1, 2)) # E_pos (eq12)
+ # E_att (eq14): per head, score(query m, key n) = <Q_m, K_n>; energy = -1/g sum_m lse_n
+ Q = torch.einsum('bnd,hjd->bhnj', z, self.WQ) # (B,H,N,dh)
+ K = torch.einsum('bnd,hjd->bhnj', z, self.WK)
+ A = torch.einsum('bhmj,bhnj->bhmn', Q, K) # (B,H,N,N)
+ lse = torch.logsumexp(self.gamma * A, dim=-1) # (B,H,N)
+ E = E - (1.0 / self.gamma) * lse.sum((1, 2))
+ # E_mem (eq15): -sum_token sum_mem relu(Wmem^T z)^2
+ proj = torch.einsum('bnd,dm->bnm', z, self.Wmem) # (B,N,M)
+ E = E - (F.relu(proj) ** 2).sum((1, 2))
+ # E_dec (eq16): 1/2 y^2 - <conv(y),z> - <y,bdec>
+ dc = self.decode_conv(y)
+ E = (E + 0.5 * (y ** 2).sum((1, 2, 3)) - (dc * z).sum((1, 2))
+ - (y * self.bdec[None, :, None, None]).sum((1, 2, 3)))
+ return E
+
+ def init_state(self, xbar):
+ z = token_norm(self.encode(xbar)).detach()
+ y = xbar.clone().detach()
+ return z, y
+
+ # -- one PGD step on F = E + beta*Cost --------------------------------- #
+ def _grad_step(self, xbar, z, y, eps, x=None, mask=None, beta=0.0, create_graph=False):
+ z = z.requires_grad_(True)
+ y = y.requires_grad_(True)
+ Etot = self.energy(xbar, z, y).sum()
+ if beta != 0.0:
+ Etot = Etot + beta * masked_cost(y, x, mask)
+ gz, gy = torch.autograd.grad(Etot, [z, y], create_graph=create_graph)
+ z = token_norm(z - eps * gz)
+ y = y - eps * gy
+ return z, y
+
+ @torch.no_grad()
+ def relax(self, xbar, z, y, steps, eps, x=None, mask=None, beta=0.0):
+ for _ in range(steps):
+ with torch.enable_grad():
+ z, y = self._grad_step(xbar, z, y, eps, x, mask, beta)
+ z, y = z.detach(), y.detach()
+ return z, y
+
+
+# --------------------------------------------------------------------------- #
+# Cost / masking
+# --------------------------------------------------------------------------- #
+def masked_cost(y, x, mask):
+ """0.5 * sum over masked pixels of (y-x)^2, summed over batch (energy units)."""
+ return 0.5 * (((y - x) ** 2) * mask).sum()
+
+
+def masked_mse(y, x, mask):
+ """Mean squared error over masked pixels only (reporting metric)."""
+ num = (((y - x) ** 2) * mask).sum()
+ den = mask.sum().clamp_min(1.0)
+ return (num / den).item()
+
+
+def make_patch_mask(B, gh, patch, stride, H, W, ratio, device, gen=None):
+ """Random per-sample patch mask (1 = masked/occluded). Assumes stride==patch."""
+ npatch = gh * gh
+ nmask = int(round(ratio * npatch))
+ noise = torch.rand(B, npatch, device=device, generator=gen)
+ idx = noise.argsort(dim=1)
+ pm = torch.zeros(B, npatch, device=device)
+ pm.scatter_(1, idx[:, :nmask], 1.0)
+ pm = pm.view(B, gh, gh)
+ M = pm.repeat_interleave(patch, 1).repeat_interleave(patch, 2) # (B,H,W)
+ return M.unsqueeze(1) # (B,1,H,W)
+
+
+# --------------------------------------------------------------------------- #
+# Gradient estimators
+# --------------------------------------------------------------------------- #
+def ep_param_grads(model, xbar, x, mask, T1, T2, eps, beta):
+ """Centered EP. Returns (grads list, free-phase masked MSE for monitoring)."""
+ z0, y0 = model.init_state(xbar)
+ z0, y0 = model.relax(xbar, z0, y0, T1, eps) # free phase, beta=0
+ free_mse = masked_mse(y0, x, mask)
+ zp, yp = model.relax(xbar, z0.clone(), y0.clone(), T2, eps, x, mask, beta=+beta)
+ zm, ym = model.relax(xbar, z0.clone(), y0.clone(), T2, eps, x, mask, beta=-beta)
+ params = [p for p in model.parameters()]
+ Ep = model.energy(xbar, zp, yp).sum()
+ gp = torch.autograd.grad(Ep, params)
+ Em = model.energy(xbar, zm, ym).sum()
+ gm = torch.autograd.grad(Em, params)
+ grads = [(a - b) / (2.0 * beta) for a, b in zip(gp, gm)]
+ return grads, free_mse
+
+
+def tbpte_loss(model, xbar, x, mask, T1, T2, eps):
+ """Free relaxation (detached) then backprop through last T2 steps. Returns loss."""
+ z, y = model.init_state(xbar)
+ z, y = model.relax(xbar, z, y, T1, eps) # detached
+ z = z.detach(); y = y.detach()
+ for _ in range(T2): # last T2 steps WITH graph
+ z, y = model._grad_step(xbar, z, y, eps, create_graph=True)
+ return masked_cost(y, x, mask) / mask.sum().clamp_min(1.0), y
+
+
+def bptt_param_grads(model, xbar, x, mask, T1, eps):
+ """Full backprop through ALL T1 relaxation steps (smoke-test reference only)."""
+ z, y = model.init_state(xbar)
+ for _ in range(T1):
+ z, y = model._grad_step(xbar, z, y, eps, create_graph=True)
+ loss = masked_cost(y, x, mask) / mask.sum().clamp_min(1.0)
+ return torch.autograd.grad(loss, [p for p in model.parameters()])
+
+
+# --------------------------------------------------------------------------- #
+# Data
+# --------------------------------------------------------------------------- #
+def get_loaders(batch, root='/tmp/cet_mvp/data', workers=4, dataset='cifar10'):
+ if dataset == 'cifar10':
+ tf = transforms.Compose([transforms.ToTensor(),
+ transforms.Normalize([0.5] * 3, [0.5] * 3)]) # -> [-1,1]
+ tr = tv.datasets.CIFAR10(root, train=True, download=True, transform=tf)
+ te = tv.datasets.CIFAR10(root, train=False, download=True, transform=tf)
+ elif dataset == 'fashionmnist':
+ tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
+ tr = tv.datasets.FashionMNIST(root, train=True, download=True, transform=tf)
+ te = tv.datasets.FashionMNIST(root, train=False, download=True, transform=tf)
+ else:
+ raise ValueError(dataset)
+ trl = torch.utils.data.DataLoader(tr, batch, shuffle=True, num_workers=workers,
+ drop_last=True, pin_memory=True)
+ tel = torch.utils.data.DataLoader(te, batch, shuffle=False, num_workers=workers,
+ pin_memory=True)
+ return trl, tel
+
+
+@torch.no_grad()
+def evaluate(model, loader, cfg, device, max_batches=20):
+ model.eval()
+ tot, n = 0.0, 0
+ gen = torch.Generator(device=device).manual_seed(0)
+ for i, (x, _) in enumerate(loader):
+ if i >= max_batches:
+ break
+ x = x.to(device)
+ M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride,
+ x.size(2), x.size(3), cfg.mask_ratio, device, gen)
+ xbar = x * (1 - M)
+ z, y = model.init_state(xbar)
+ z, y = model.relax(xbar, z, y, cfg.T1, cfg.eps)
+ tot += masked_mse(y, x, M) * x.size(0); n += x.size(0)
+ model.train()
+ return tot / n
+
+
+# --------------------------------------------------------------------------- #
+# Train
+# --------------------------------------------------------------------------- #
+def train(cfg):
+ device = cfg.device
+ torch.manual_seed(cfg.seed)
+ model = CET(cfg.img, cfg.ch, cfg.patch, cfg.stride, cfg.D, cfg.heads, cfg.dh,
+ cfg.mem, cfg.gamma).to(device)
+ opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr_min)
+ trl, tel = get_loaders(cfg.batch, dataset=cfg.dataset)
+ print(f"[{cfg.mode}] model params={sum(p.numel() for p in model.parameters())/1e3:.1f}K "
+ f"N_tokens={model.N} D={cfg.D} | T1={cfg.T1} T2={cfg.T2} eps={cfg.eps} beta={cfg.beta}",
+ flush=True)
+
+ step, t0, run_loss = 0, time.time(), 0.0
+ while step < cfg.steps:
+ for x, _ in trl:
+ if step >= cfg.steps:
+ break
+ x = x.to(device, non_blocking=True)
+ M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride,
+ x.size(2), x.size(3), cfg.mask_ratio, device)
+ xbar = x * (1 - M)
+
+ opt.zero_grad(set_to_none=True)
+ if cfg.mode == 'ep':
+ grads, tr_mse = ep_param_grads(model, xbar, x, M, cfg.T1, cfg.T2,
+ cfg.eps, cfg.beta)
+ for p, g in zip(model.parameters(), grads):
+ p.grad = g
+ else: # tbpte
+ loss, _ = tbpte_loss(model, xbar, x, M, cfg.T1, cfg.T2, cfg.eps)
+ loss.backward()
+ tr_mse = loss.item()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip)
+ opt.step(); sched.step()
+ run_loss += tr_mse; step += 1
+
+ if step % cfg.log_every == 0:
+ avg = run_loss / cfg.log_every; run_loss = 0.0
+ sps = step / (time.time() - t0)
+ print(f"step {step:5d}/{cfg.steps} | train masked-MSE {avg:.5f} "
+ f"| lr {sched.get_last_lr()[0]:.2e} | {sps:.1f} it/s", flush=True)
+ if step % cfg.eval_every == 0 or step == cfg.steps:
+ te_mse = evaluate(model, tel, cfg, device)
+ print(f" >> [eval] step {step} test masked-MSE {te_mse:.5f}", flush=True)
+
+ final = evaluate(model, tel, cfg, device, max_batches=100)
+ os.makedirs(cfg.out, exist_ok=True)
+ res = {'mode': cfg.mode, 'final_test_masked_mse': final, 'steps': cfg.steps,
+ 'config': {k: getattr(cfg, k) for k in
+ ['T1', 'T2', 'eps', 'beta', 'D', 'heads', 'dh', 'mem',
+ 'patch', 'stride', 'mask_ratio', 'batch', 'lr']}}
+ with open(os.path.join(cfg.out, f'result_{cfg.mode}.json'), 'w') as f:
+ json.dump(res, f, indent=2)
+ torch.save(model.state_dict(), os.path.join(cfg.out, f'cet_{cfg.mode}.pt'))
+ print(f"[{cfg.mode}] DONE final test masked-MSE = {final:.5f}", flush=True)
+ return final
+
+
+# --------------------------------------------------------------------------- #
+# Smoke test
+# --------------------------------------------------------------------------- #
+def _residual(model, xbar, z, y, eps):
+ """Norm of the PGD update (proxy for ||grad E|| at the constrained equilibrium)."""
+ with torch.enable_grad():
+ zn, yn = model._grad_step(xbar, z.clone(), y.clone(), eps)
+ return ((zn - z).norm() / (z.norm() + 1e-8)).item(), ((yn - y).norm() / (y.norm() + 1e-8)).item()
+
+
+def smoke(cfg):
+ device = cfg.device
+ torch.manual_seed(0)
+ model = CET(cfg.img, cfg.ch, cfg.patch, cfg.stride, D=32, heads=2, dh=16,
+ mem=32, gamma=cfg.gamma).to(device)
+ x = torch.randn(16, cfg.ch, cfg.img, cfg.img, device=device).clamp(-1, 1)
+ M = make_patch_mask(16, model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, device)
+ xbar = x * (1 - M)
+ T1 = cfg.T1
+ print(f"[smoke] T1={T1} T2={cfg.T2} eps={cfg.eps} beta={cfg.beta}")
+
+ # (a) energy decreases during relaxation
+ z, y = model.init_state(xbar)
+ print("energy trajectory (free phase):")
+ es = []
+ for t in range(T1 + 1):
+ e = model.energy(xbar, z, y).mean().item(); es.append(e)
+ if t % max(1, T1 // 6) == 0:
+ rz, ry = _residual(model, xbar, z, y, cfg.eps)
+ print(f" step {t:3d} E={e:12.4f} masked-MSE={masked_mse(y,x,M):.4f}"
+ f" rel-step |dz|={rz:.2e} |dy|={ry:.2e}")
+ with torch.enable_grad():
+ z, y = model._grad_step(xbar, z, y, cfg.eps)
+ z, y = z.detach(), y.detach()
+ mono = all(es[i+1] <= es[i] + 1e-3 for i in range(len(es)-1))
+ print(f" monotonic non-increasing: {mono} (start {es[0]:.2f} -> end {es[-1]:.2f})")
+ print(f" NaN in state: {torch.isnan(z).any().item() or torch.isnan(y).any().item()}")
+
+ # (b) EP gradient vs full-BPTT gradient (key correctness gate)
+ g_ep, _ = ep_param_grads(model, xbar, x, M, T1, cfg.T2, cfg.eps, beta=cfg.beta)
+ g_bp = bptt_param_grads(model, xbar, x, M, T1, cfg.eps)
+ fe = torch.cat([g.flatten() for g in g_ep])
+ fb = torch.cat([g.flatten() for g in g_bp])
+ cos = F.cosine_similarity(fe, fb, dim=0).item()
+ names = [n for n, _ in model.named_parameters()]
+ print(f"\nEP-vs-BPTT gradient cosine (global): {cos:.4f}")
+ for n, a, b in zip(names, g_ep, g_bp):
+ c = F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
+ print(f" {n:6s} cos={c:+.3f} |ep|={a.norm():.3e} |bptt|={b.norm():.3e}")
+ print(f"\nSMOKE {'PASS' if (mono and cos > 0.6) else 'CHECK'} "
+ f"(want energy monotone & global cos>0.6)")
+
+
+# --------------------------------------------------------------------------- #
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--mode', choices=['ep', 'tbpte', 'smoke'], default='smoke')
+ p.add_argument('--dataset', choices=['cifar10', 'fashionmnist'], default='cifar10')
+ p.add_argument('--steps', type=int, default=4000)
+ p.add_argument('--batch', type=int, default=128)
+ p.add_argument('--img', type=int, default=32); p.add_argument('--ch', type=int, default=3)
+ p.add_argument('--patch', type=int, default=8); p.add_argument('--stride', type=int, default=8)
+ p.add_argument('--D', type=int, default=128); p.add_argument('--heads', type=int, default=4)
+ p.add_argument('--dh', type=int, default=32); p.add_argument('--mem', type=int, default=256)
+ p.add_argument('--gamma', type=float, default=0.25)
+ p.add_argument('--T1', type=int, default=30); p.add_argument('--T2', type=int, default=5)
+ p.add_argument('--eps', type=float, default=0.5); p.add_argument('--beta', type=float, default=0.1)
+ p.add_argument('--mask_ratio', type=float, default=0.5)
+ p.add_argument('--lr', type=float, default=4e-4); p.add_argument('--lr_min', type=float, default=1e-6)
+ p.add_argument('--wd', type=float, default=3e-5); p.add_argument('--clip', type=float, default=10.0)
+ p.add_argument('--log_every', type=int, default=100); p.add_argument('--eval_every', type=int, default=1000)
+ p.add_argument('--seed', type=int, default=0)
+ p.add_argument('--out', type=str, default='/home/yurenh2/ept/runs')
+ p.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ cfg = p.parse_args()
+ print('config:', vars(cfg), flush=True)
+ if cfg.mode == 'smoke':
+ smoke(cfg)
+ else:
+ train(cfg)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/plot_jr_cmp.py b/scripts/plot_jr_cmp.py
new file mode 100644
index 0000000..2f8af95
--- /dev/null
+++ b/scripts/plot_jr_cmp.py
@@ -0,0 +1,20 @@
+import re, matplotlib; matplotlib.use('Agg'); import matplotlib.pyplot as plt
+def parse(p):
+ s,c,e=[],[],[]
+ for ln in open(p):
+ m=re.search(r'step\s+(\d+)/\d+ \| val CE ([\d.]+) ema=([\d.]+)', ln)
+ if m: s.append(int(m.group(1))); c.append(float(m.group(2))); e.append(float(m.group(3)))
+ return s,c,e
+fz=parse('ep_run/runs/ep_resreg_warm.log'); ad=parse('ep_run/runs/ep_rr_ajr.log')
+fig,ax=plt.subplots(figsize=(11,5.5),dpi=140)
+for (s,c,e),col,lab in [(fz,'#4363d8','frozen jr=0.1'),(ad,'#e6194B','adaptive jacreg (jr_max=16)')]:
+ ax.plot(s,c,color=col,lw=0.7,alpha=0.25)
+ ax.plot(s,e,color=col,lw=2.3,label=lab+f' (last ema {e[-1]:.4f})')
+ax.axhline(2.0,color='gray',ls='--',lw=1,alpha=0.7,label='val CE = 2.0')
+ax.set_xlabel('training step'); ax.set_ylabel('val CE'); ax.set_ylim(1.92,2.6)
+ax.set_title('Frozen jr vs adaptive jacreg — C512 EP, same s2000 warm-start (only jr differs)\nbold = ema, faint = raw per-log val CE', fontsize=10.5)
+ax.legend(fontsize=9.5, loc='upper right'); ax.grid(alpha=0.2)
+fig.tight_layout(); fig.savefig('frozen_vs_adaptive.png',dpi=140,bbox_inches='tight')
+print(f"frozen: {len(fz[0])} pts, steps {fz[0][0]}-{fz[0][-1]}, best ema {min(fz[2]):.4f}")
+print(f"adaptive: {len(ad[0])} pts, steps {ad[0][0]}-{ad[0][-1]}, best ema {min(ad[2]):.4f}")
+print("saved")