"""
Convergent Energy Transformer (CET) trained with Equilibrium Propagation.

MVP reproduction of Hoier, Kerjan & Scellier, "Training a Convergent Energy
Transformer with Equilibrium Propagation" (ICLR 2026 AM workshop).

Energy terms follow the paper's Appendix B exactly:

    E =   E_enc  (eq10, conv-Hopfield)           boundary: masked image -> tokens
        + E_pos  (eq12, per-token bias)
        + E_att  (eq14, ET LogSumExp attention)  <- attention, trained WITHOUT BP under EP
        + E_mem  (eq15, modern Hopfield memory)  <- plays the role of the FFN/MLP
        + E_dec  (eq16, conv decoder)            tokens <-> reconstruction y

Tokens z are normalised (mean 0, std 1) via projection after each PGD step.

Training modes:
    ep    : free phase (T1) + two nudged phases (+/-beta, T2). Parameter gradient
            is the centered-EP estimator
                (1/2beta)(dE/dtheta|_{+b} - dE/dtheta|_{-b}).
            NO backpropagation through the relaxation dynamics; attention & memory
            weights are updated purely from the two equilibria.
    tbpte : same model, gradient via truncated backprop through the last T2
            relaxation steps (the paper's BP baseline; BPTE = "backprop through
            equilibration").
"""
import argparse
import json
import math
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import torchvision as tv
    from torchvision import transforms
except ImportError:  # only get_loaders() needs torchvision; the model and smoke test do not
    tv = None
    transforms = None


# --------------------------------------------------------------------------- #
# Model
# --------------------------------------------------------------------------- #
def token_norm(z, eps=1e-5):
    """Project tokens onto the constraint set C: per-token mean 0, std 1 (over D_T).

    Normalisation is over the last dimension; `eps` is added to the (biased)
    standard deviation to avoid division by zero.
    """
    centered = z - z.mean(-1, keepdim=True)
    return centered / (z.std(-1, unbiased=False, keepdim=True) + eps)


class CET(nn.Module):
    """Energy-based model over tokens z (B,N,D) and a reconstruction y (B,C,H,W).

    `energy` defines the scalar energy per sample; `relax` runs projected
    gradient descent on it with respect to (z, y).
    """

    def __init__(self, img=32, ch=3, patch=8, stride=8, D=128, heads=4, dh=32,
                 mem=256, gamma=0.25):
        super().__init__()
        self.ch, self.patch, self.stride, self.D = ch, patch, stride, D
        self.heads, self.dh, self.gamma = heads, dh, gamma
        gh = (img - patch) // stride + 1
        self.gh = gh
        self.N = gh * gh  # number of tokens / patches

        # Encoder (eq 10): conv kernel mapping patches -> token dim
        self.Wenc = nn.Parameter(torch.empty(D, ch, patch, patch))
        self.benc = nn.Parameter(torch.zeros(D))
        # Positional bias (eq 12): per (token, dim), NOT shared across tokens
        self.bpos = nn.Parameter(torch.zeros(self.N, D))
        # Decoder (eq 16): conv kernel mapping reconstruction -> token dim
        self.Wdec = nn.Parameter(torch.empty(D, ch, patch, patch))
        self.bdec = nn.Parameter(torch.zeros(ch))
        # Attention (eq 13/14): key/query projections, no value tensor (as in ET)
        self.WQ = nn.Parameter(torch.empty(heads, dh, D))
        self.WK = nn.Parameter(torch.empty(heads, dh, D))
        # Memory (eq 15): modern Hopfield memory bank (role of the MLP)
        self.Wmem = nn.Parameter(torch.empty(D, mem))

        # NOTE: the order of these calls fixes the RNG stream, hence the initial weights.
        nn.init.kaiming_normal_(self.Wenc)
        self.Wenc.data *= 0.5
        nn.init.kaiming_normal_(self.Wdec)
        self.Wdec.data *= 0.5
        nn.init.normal_(self.WQ, std=1.0 / math.sqrt(D))
        nn.init.normal_(self.WK, std=1.0 / math.sqrt(D))
        nn.init.normal_(self.Wmem, std=1.0 / math.sqrt(D))

    # -- patch <-> token conv helpers -------------------------------------- #
    def encode(self, xbar):  # (B,C,H,W) -> (B,N,D)
        """Apply the encoder conv kernel (no bias) and flatten the grid to tokens."""
        e = F.conv2d(xbar, self.Wenc, stride=self.stride)
        return e.flatten(2).transpose(1, 2)

    def decode_conv(self, y):  # (B,C,H,W) -> (B,N,D)
        """Apply the decoder conv kernel (no bias) and flatten the grid to tokens."""
        d = F.conv2d(y, self.Wdec, stride=self.stride)
        return d.flatten(2).transpose(1, 2)

    # -- energy (per-sample, shape (B,)) ----------------------------------- #
    def energy(self, xbar, z, y):
        """Return the total energy per sample, shape (B,)."""
        enc = self.encode(xbar)  # (B,N,D)
        # E_enc (eq10): 1/2 |z|^2 - <conv_enc(xbar), z> - <benc, z>
        E = 0.5 * (z ** 2).sum((1, 2)) - (enc * z).sum((1, 2)) - (z * self.benc).sum((1, 2))
        E = E - (z * self.bpos).sum((1, 2))  # E_pos (eq12)

        # E_att (eq14): per head, score(query m, key n) = <Q_m, K_n>;
        # energy = -1/g sum_m lse_n(g * score)
        Q = torch.einsum('bnd,hjd->bhnj', z, self.WQ)  # (B,H,N,dh)
        K = torch.einsum('bnd,hjd->bhnj', z, self.WK)
        A = torch.einsum('bhmj,bhnj->bhmn', Q, K)  # (B,H,N,N)
        lse = torch.logsumexp(self.gamma * A, dim=-1)  # (B,H,N)
        E = E - (1.0 / self.gamma) * lse.sum((1, 2))

        # E_mem (eq15): -sum_token sum_mem relu(Wmem^T z)^2
        proj = torch.einsum('bnd,dm->bnm', z, self.Wmem)  # (B,N,M)
        E = E - (F.relu(proj) ** 2).sum((1, 2))

        # E_dec (eq16): 1/2 |y|^2 - <conv_dec(y), z> - <bdec, y>
        dc = self.decode_conv(y)
        E = (E + 0.5 * (y ** 2).sum((1, 2, 3)) - (dc * z).sum((1, 2))
             - (y * self.bdec[None, :, None, None]).sum((1, 2, 3)))
        return E

    def init_state(self, xbar):
        """Return the detached initial state: z = normalised encoding, y = copy of xbar."""
        z = token_norm(self.encode(xbar)).detach()
        y = xbar.clone().detach()
        return z, y

    # -- one PGD step on F = E + beta*Cost --------------------------------- #
    def _grad_step(self, xbar, z, y, eps, x=None, mask=None, beta=0.0, create_graph=False):
        """Take one gradient step of size `eps` on E + beta * masked_cost.

        z is projected back with `token_norm`; y is left unconstrained. Note
        that `requires_grad_` is applied in place to the tensors passed in.

        Raises:
            ValueError: if `beta` is non-zero but `x` or `mask` is missing.
        """
        if beta != 0.0 and (x is None or mask is None):
            raise ValueError("a nudged step (beta != 0) requires both x and mask")
        z = z.requires_grad_(True)
        y = y.requires_grad_(True)
        Etot = self.energy(xbar, z, y).sum()
        if beta != 0.0:
            Etot = Etot + beta * masked_cost(y, x, mask)
        gz, gy = torch.autograd.grad(Etot, [z, y], create_graph=create_graph)
        z = token_norm(z - eps * gz)
        y = y - eps * gy
        return z, y

    @torch.no_grad()
    def relax(self, xbar, z, y, steps, eps, x=None, mask=None, beta=0.0):
        """Run `steps` PGD steps, detaching the state after each one."""
        for _ in range(steps):
            # Autograd is needed for d(E)/d(z, y) even though nothing is kept.
            with torch.enable_grad():
                z, y = self._grad_step(xbar, z, y, eps, x, mask, beta)
            z, y = z.detach(), y.detach()
        return z, y


# --------------------------------------------------------------------------- #
# Cost / masking
# --------------------------------------------------------------------------- #
def masked_cost(y, x, mask):
    """0.5 * sum over masked pixels of (y-x)^2, summed over batch (energy units)."""
    return 0.5 * (((y - x) ** 2) * mask).sum()


def masked_mse(y, x, mask):
    """Mean squared error over masked elements only (reporting metric).

    `mask` is broadcast against `y`, so a (B,1,H,W) mask applied to a
    (B,C,H,W) image is counted once per channel in the denominator as well
    as in the numerator. Returns a Python float.
    """
    num = (((y - x) ** 2) * mask).sum()
    den = mask.expand_as(y).sum().clamp_min(1.0)
    return (num / den).item()


def make_patch_mask(B, gh, patch, stride, H, W, ratio, device, gen=None):
    """Random per-sample patch mask (1 = masked/occluded), shape (B,1,H,W).

    Each sample masks `round(ratio * gh * gh)` patches of a gh x gh grid.
    Pixels to the right of / below the patch grid (when gh * patch is smaller
    than H or W) are left unmasked.

    Raises:
        ValueError: if stride != patch, if the patch grid does not fit inside
            (H, W), or if `ratio` is outside [0, 1].
    """
    if stride != patch:
        raise ValueError(
            f"make_patch_mask requires stride == patch (got stride={stride}, patch={patch})")
    if not 0.0 <= ratio <= 1.0:
        raise ValueError(f"mask ratio must be in [0, 1], got {ratio}")
    covered = gh * patch
    if covered > H or covered > W:
        raise ValueError(
            f"patch grid {gh}x{gh} of size {patch} covers {covered} pixels, "
            f"which exceeds the image size ({H}, {W})")

    npatch = gh * gh
    nmask = int(round(ratio * npatch))
    noise = torch.rand(B, npatch, device=device, generator=gen)
    idx = noise.argsort(dim=1)  # random permutation of patch indices per sample
    pm = torch.zeros(B, npatch, device=device)
    pm.scatter_(1, idx[:, :nmask], 1.0)
    pm = pm.view(B, gh, gh)
    M = pm.repeat_interleave(patch, 1).repeat_interleave(patch, 2)  # (B,covered,covered)
    if covered < H or covered < W:
        # Pad right/bottom with zeros so the mask matches the image size.
        M = F.pad(M, (0, W - covered, 0, H - covered))
    return M.unsqueeze(1)  # (B,1,H,W)


# --------------------------------------------------------------------------- #
# Gradient estimators
# --------------------------------------------------------------------------- #
def ep_param_grads(model, xbar, x, mask, T1, T2, eps, beta):
    """Centered EP. Returns (grads list, free-phase masked MSE for monitoring).

    The gradients are in the units of `masked_cost` (summed over the batch,
    not divided by the number of masked pixels).

    Raises:
        ValueError: if `beta` is zero (the estimator divides by 2 * beta).
    """
    if beta == 0:
        raise ValueError("centered EP requires a non-zero beta")
    z0, y0 = model.init_state(xbar)
    z0, y0 = model.relax(xbar, z0, y0, T1, eps)  # free phase, beta=0
    free_mse = masked_mse(y0, x, mask)
    # Both nudged phases start from the same free equilibrium.
    zp, yp = model.relax(xbar, z0.clone(), y0.clone(), T2, eps, x, mask, beta=+beta)
    zm, ym = model.relax(xbar, z0.clone(), y0.clone(), T2, eps, x, mask, beta=-beta)

    params = list(model.parameters())
    Ep = model.energy(xbar, zp, yp).sum()
    gp = torch.autograd.grad(Ep, params)
    Em = model.energy(xbar, zm, ym).sum()
    gm = torch.autograd.grad(Em, params)
    grads = [(a - b) / (2.0 * beta) for a, b in zip(gp, gm)]
    return grads, free_mse


def tbpte_loss(model, xbar, x, mask, T1, T2, eps):
    """Free relaxation (detached) then backprop through last T2 steps.

    Returns (loss, y), where loss = masked_cost / mask.sum() and y is the
    final reconstruction (still attached to the graph).
    """
    z, y = model.init_state(xbar)
    z, y = model.relax(xbar, z, y, T1, eps)  # detached
    z = z.detach()
    y = y.detach()
    for _ in range(T2):  # last T2 steps WITH graph
        z, y = model._grad_step(xbar, z, y, eps, create_graph=True)
    return masked_cost(y, x, mask) / mask.sum().clamp_min(1.0), y


def bptt_param_grads(model, xbar, x, mask, T1, eps):
    """Full backprop through ALL T1 relaxation steps (smoke-test reference only)."""
    z, y = model.init_state(xbar)
    for _ in range(T1):
        z, y = model._grad_step(xbar, z, y, eps, create_graph=True)
    loss = masked_cost(y, x, mask) / mask.sum().clamp_min(1.0)
    return torch.autograd.grad(loss, list(model.parameters()))


# --------------------------------------------------------------------------- #
# Data
# --------------------------------------------------------------------------- #
def get_loaders(batch, root='/tmp/cet_mvp/data', workers=4, dataset='cifar10'):
    """Return (train_loader, test_loader) for 'cifar10' or 'fashionmnist'.

    Images are scaled to [-1, 1]. The datasets are downloaded into `root`
    if missing.

    Raises:
        ImportError: if torchvision is not installed.
        ValueError: for an unknown dataset name.
    """
    if tv is None:
        raise ImportError("torchvision is required to load datasets")
    if dataset == 'cifar10':
        tf = transforms.Compose([transforms.ToTensor(),
                                 transforms.Normalize([0.5] * 3, [0.5] * 3)])  # -> [-1,1]
        tr = tv.datasets.CIFAR10(root, train=True, download=True, transform=tf)
        te = tv.datasets.CIFAR10(root, train=False, download=True, transform=tf)
    elif dataset == 'fashionmnist':
        tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
        tr = tv.datasets.FashionMNIST(root, train=True, download=True, transform=tf)
        te = tv.datasets.FashionMNIST(root, train=False, download=True, transform=tf)
    else:
        raise ValueError(dataset)
    trl = torch.utils.data.DataLoader(tr, batch, shuffle=True, num_workers=workers,
                                      drop_last=True, pin_memory=True)
    tel = torch.utils.data.DataLoader(te, batch, shuffle=False, num_workers=workers,
                                      pin_memory=True)
    return trl, tel


@torch.no_grad()
def evaluate(model, loader, cfg, device, max_batches=20):
    """Return the sample-weighted free-phase masked MSE over up to `max_batches` batches.

    Masks are drawn from a generator seeded with 0, so repeated calls use the
    same masks. The model's train/eval mode is restored on exit. Returns NaN
    if the loader yields no batches.
    """
    was_training = model.training
    model.eval()
    tot, n = 0.0, 0
    gen = torch.Generator(device=device).manual_seed(0)
    try:
        for i, (x, _) in enumerate(loader):
            if i >= max_batches:
                break
            x = x.to(device)
            M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride,
                                x.size(2), x.size(3), cfg.mask_ratio, device, gen)
            xbar = x * (1 - M)
            z, y = model.init_state(xbar)
            z, y = model.relax(xbar, z, y, cfg.T1, cfg.eps)
            tot += masked_mse(y, x, M) * x.size(0)
            n += x.size(0)
    finally:
        model.train(was_training)
    return tot / n if n else float('nan')


# --------------------------------------------------------------------------- #
# Train
# --------------------------------------------------------------------------- #
def train(cfg):
    """Train a CET in 'ep' or 'tbpte' mode; write results and weights to cfg.out.

    Returns the final test masked MSE (evaluated on up to 100 batches).

    Raises:
        ValueError: if the dataset's image shape does not match cfg.ch / cfg.img.
    """
    device = cfg.device
    torch.manual_seed(cfg.seed)
    model = CET(cfg.img, cfg.ch, cfg.patch, cfg.stride, cfg.D, cfg.heads, cfg.dh,
                cfg.mem, cfg.gamma).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr_min)
    trl, tel = get_loaders(cfg.batch, dataset=cfg.dataset)
    print(f"[{cfg.mode}] model params={sum(p.numel() for p in model.parameters())/1e3:.1f}K "
          f"N_tokens={model.N} D={cfg.D} | T1={cfg.T1} T2={cfg.T2} eps={cfg.eps} beta={cfg.beta}",
          flush=True)

    expected_shape = (cfg.ch, cfg.img, cfg.img)
    step, t0, run_loss = 0, time.time(), 0.0
    while step < cfg.steps:
        for x, _ in trl:
            if step >= cfg.steps:
                break
            if tuple(x.shape[1:]) != expected_shape:
                raise ValueError(
                    f"dataset '{cfg.dataset}' yields images of shape {tuple(x.shape[1:])} "
                    f"but the model was built for {expected_shape}; "
                    f"set --ch/--img (and --patch/--stride) to match")
            x = x.to(device, non_blocking=True)
            M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride,
                                x.size(2), x.size(3), cfg.mask_ratio, device)
            xbar = x * (1 - M)
            opt.zero_grad(set_to_none=True)
            if cfg.mode == 'ep':
                grads, tr_mse = ep_param_grads(model, xbar, x, M, cfg.T1, cfg.T2,
                                               cfg.eps, cfg.beta)
                for p, g in zip(model.parameters(), grads):
                    p.grad = g
            else:  # tbpte
                loss, y_hat = tbpte_loss(model, xbar, x, M, cfg.T1, cfg.T2, cfg.eps)
                loss.backward()
                # Log the same metric as the EP branch; the loss itself carries
                # an extra factor of 0.5 and a per-pixel normaliser.
                tr_mse = masked_mse(y_hat.detach(), x, M)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip)
            opt.step()
            sched.step()
            run_loss += tr_mse
            step += 1

            if step % cfg.log_every == 0:
                avg = run_loss / cfg.log_every
                run_loss = 0.0
                sps = step / (time.time() - t0)
                print(f"step {step:5d}/{cfg.steps} | train masked-MSE {avg:.5f} "
                      f"| lr {sched.get_last_lr()[0]:.2e} | {sps:.1f} it/s", flush=True)
            if step % cfg.eval_every == 0 or step == cfg.steps:
                te_mse = evaluate(model, tel, cfg, device)
                print(f" >> [eval] step {step} test masked-MSE {te_mse:.5f}", flush=True)

    final = evaluate(model, tel, cfg, device, max_batches=100)
    os.makedirs(cfg.out, exist_ok=True)
    res = {'mode': cfg.mode, 'final_test_masked_mse': final, 'steps': cfg.steps,
           'config': {k: getattr(cfg, k) for k in
                      ['T1', 'T2', 'eps', 'beta', 'D', 'heads', 'dh', 'mem', 'patch',
                       'stride', 'mask_ratio', 'batch', 'lr']}}
    with open(os.path.join(cfg.out, f'result_{cfg.mode}.json'), 'w', encoding='utf-8') as f:
        json.dump(res, f, indent=2)
    torch.save(model.state_dict(), os.path.join(cfg.out, f'cet_{cfg.mode}.pt'))
    print(f"[{cfg.mode}] DONE final test masked-MSE = {final:.5f}", flush=True)
    return final


# --------------------------------------------------------------------------- #
# Smoke test
# --------------------------------------------------------------------------- #
def _residual(model, xbar, z, y, eps):
    """Norm of the PGD update (proxy for ||grad E|| at the constrained equilibrium).

    Returns the relative step sizes (|dz| / |z|, |dy| / |y|) as floats. The
    inputs are cloned, so the caller's tensors are not modified.
    """
    with torch.enable_grad():
        zn, yn = model._grad_step(xbar, z.clone(), y.clone(), eps)
    rel_z = ((zn - z).norm() / (z.norm() + 1e-8)).item()
    rel_y = ((yn - y).norm() / (y.norm() + 1e-8)).item()
    return rel_z, rel_y


def smoke(cfg):
    """Sanity checks on random data with a small model.

    (a) prints the free-phase energy trajectory and whether it is
        non-increasing; (b) compares the EP gradient with the full-BPTT
        gradient by cosine similarity. The two gradients are scaled
        differently (summed cost vs. cost / mask.sum()), so only the cosines
        are comparable, not the printed norms.
    """
    device = cfg.device
    torch.manual_seed(0)
    model = CET(cfg.img, cfg.ch, cfg.patch, cfg.stride, D=32, heads=2, dh=16, mem=32,
                gamma=cfg.gamma).to(device)
    x = torch.randn(16, cfg.ch, cfg.img, cfg.img, device=device).clamp(-1, 1)
    M = make_patch_mask(16, model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, device)
    xbar = x * (1 - M)
    T1 = cfg.T1
    print(f"[smoke] T1={T1} T2={cfg.T2} eps={cfg.eps} beta={cfg.beta}")

    # (a) energy decreases during relaxation
    z, y = model.init_state(xbar)
    print("energy trajectory (free phase):")
    es = []
    for t in range(T1 + 1):
        e = model.energy(xbar, z, y).mean().item()
        es.append(e)
        if t % max(1, T1 // 6) == 0:
            rz, ry = _residual(model, xbar, z, y, cfg.eps)
            print(f" step {t:3d} E={e:12.4f} masked-MSE={masked_mse(y,x,M):.4f}"
                  f" rel-step |dz|={rz:.2e} |dy|={ry:.2e}")
        with torch.enable_grad():
            z, y = model._grad_step(xbar, z, y, cfg.eps)
        z, y = z.detach(), y.detach()
    mono = all(nxt <= cur + 1e-3 for cur, nxt in zip(es, es[1:]))
    print(f" monotonic non-increasing: {mono} (start {es[0]:.2f} -> end {es[-1]:.2f})")
    print(f" NaN in state: {torch.isnan(z).any().item() or torch.isnan(y).any().item()}")

    # (b) EP gradient vs full-BPTT gradient (key correctness gate)
    g_ep, _ = ep_param_grads(model, xbar, x, M, T1, cfg.T2, cfg.eps, beta=cfg.beta)
    g_bp = bptt_param_grads(model, xbar, x, M, T1, cfg.eps)
    fe = torch.cat([g.flatten() for g in g_ep])
    fb = torch.cat([g.flatten() for g in g_bp])
    cos = F.cosine_similarity(fe, fb, dim=0).item()
    names = [n for n, _ in model.named_parameters()]
    print(f"\nEP-vs-BPTT gradient cosine (global): {cos:.4f}")
    for n, a, b in zip(names, g_ep, g_bp):
        c = F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
        print(f" {n:6s} cos={c:+.3f} |ep|={a.norm():.3e} |bptt|={b.norm():.3e}")
    print(f"\nSMOKE {'PASS' if (mono and cos > 0.6) else 'CHECK'} "
          f"(want energy monotone & global cos>0.6)")


# --------------------------------------------------------------------------- #
def main():
    """Parse command-line options and run the smoke test or a training run."""
    p = argparse.ArgumentParser()
    p.add_argument('--mode', choices=['ep', 'tbpte', 'smoke'], default='smoke')
    p.add_argument('--dataset', choices=['cifar10', 'fashionmnist'], default='cifar10')
    p.add_argument('--steps', type=int, default=4000)
    p.add_argument('--batch', type=int, default=128)
    p.add_argument('--img', type=int, default=32)
    p.add_argument('--ch', type=int, default=3)
    p.add_argument('--patch', type=int, default=8)
    p.add_argument('--stride', type=int, default=8)
    p.add_argument('--D', type=int, default=128)
    p.add_argument('--heads', type=int, default=4)
    p.add_argument('--dh', type=int, default=32)
    p.add_argument('--mem', type=int, default=256)
    p.add_argument('--gamma', type=float, default=0.25)
    p.add_argument('--T1', type=int, default=30)
    p.add_argument('--T2', type=int, default=5)
    p.add_argument('--eps', type=float, default=0.5)
    p.add_argument('--beta', type=float, default=0.1)
    p.add_argument('--mask_ratio', type=float, default=0.5)
    p.add_argument('--lr', type=float, default=4e-4)
    p.add_argument('--lr_min', type=float, default=1e-6)
    p.add_argument('--wd', type=float, default=3e-5)
    p.add_argument('--clip', type=float, default=10.0)
    p.add_argument('--log_every', type=int, default=100)
    p.add_argument('--eval_every', type=int, default=1000)
    p.add_argument('--seed', type=int, default=0)
    p.add_argument('--out', type=str, default='/home/yurenh2/ept/runs')
    p.add_argument('--device', type=str,
                   default='cuda' if torch.cuda.is_available() else 'cpu')
    cfg = p.parse_args()
    print('config:', vars(cfg), flush=True)
    if cfg.mode == 'smoke':
        smoke(cfg)
    else:
        train(cfg)


if __name__ == '__main__':
    main()