summaryrefslogtreecommitdiff
path: root/scripts/cet_mvp.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/cet_mvp.py')
-rw-r--r--scripts/cet_mvp.py372
1 files changed, 372 insertions, 0 deletions
diff --git a/scripts/cet_mvp.py b/scripts/cet_mvp.py
new file mode 100644
index 0000000..deb07d9
--- /dev/null
+++ b/scripts/cet_mvp.py
@@ -0,0 +1,372 @@
+"""
+Convergent Energy Transformer (CET) trained with Equilibrium Propagation.
+MVP reproduction of Hoier, Kerjan & Scellier, "Training a Convergent Energy
+Transformer with Equilibrium Propagation" (ICLR 2026 AM workshop).
+
+Energy terms follow the paper's Appendix B exactly:
+ E = E_enc (eq10, conv-Hopfield) boundary: masked image -> tokens
+ + E_pos (eq12, per-token bias)
+ + E_att (eq14, ET LogSumExp attention) <- attention, trained WITHOUT BP under EP
+ + E_mem (eq15, modern Hopfield memory) <- plays the role of the FFN/MLP
+ + E_dec (eq16, conv decoder) tokens <-> reconstruction y
+Tokens z are normalised (mean 0, std 1) via projection after each PGD step.
+
+Training modes:
+ ep : free phase (T1) + two nudged phases (+/-beta, T2). Parameter gradient is the
+ centered-EP estimator (1/2beta)(dE/dtheta|_{+b} - dE/dtheta|_{-b}).
+ NO backpropagation through the relaxation dynamics; attention & memory
+ weights are updated purely from the two equilibria.
+ tbpte : same model, gradient via truncated backprop through the last T2 relaxation
+ steps (the paper's BP baseline; BPTE = "backprop through equilibration").
+"""
+import argparse, os, time, json, math
+import torch, torch.nn as nn, torch.nn.functional as F
+import torchvision as tv
+from torchvision import transforms
+
+
+# --------------------------------------------------------------------------- #
+# Model
+# --------------------------------------------------------------------------- #
+def token_norm(z, eps=1e-5):
+ """Project tokens onto the constraint set C: per-token mean 0, std 1 (over D_T)."""
+ return (z - z.mean(-1, keepdim=True)) / (z.std(-1, unbiased=False, keepdim=True) + eps)
+
+
+class CET(nn.Module):
+ def __init__(self, img=32, ch=3, patch=8, stride=8, D=128, heads=4, dh=32,
+ mem=256, gamma=0.25):
+ super().__init__()
+ self.ch, self.patch, self.stride, self.D = ch, patch, stride, D
+ self.heads, self.dh, self.gamma = heads, dh, gamma
+ gh = (img - patch) // stride + 1
+ self.gh = gh
+ self.N = gh * gh # number of tokens / patches
+
+ # Encoder (eq 10): conv kernel mapping patches -> token dim
+ self.Wenc = nn.Parameter(torch.empty(D, ch, patch, patch))
+ self.benc = nn.Parameter(torch.zeros(D))
+ # Positional bias (eq 12): per (token, dim), NOT shared across tokens
+ self.bpos = nn.Parameter(torch.zeros(self.N, D))
+ # Decoder (eq 16): conv kernel mapping reconstruction -> token dim
+ self.Wdec = nn.Parameter(torch.empty(D, ch, patch, patch))
+ self.bdec = nn.Parameter(torch.zeros(ch))
+ # Attention (eq 13/14): key/query projections, no value tensor (as in ET)
+ self.WQ = nn.Parameter(torch.empty(heads, dh, D))
+ self.WK = nn.Parameter(torch.empty(heads, dh, D))
+ # Memory (eq 15): modern Hopfield memory bank (role of the MLP)
+ self.Wmem = nn.Parameter(torch.empty(D, mem))
+
+ nn.init.kaiming_normal_(self.Wenc); self.Wenc.data *= 0.5
+ nn.init.kaiming_normal_(self.Wdec); self.Wdec.data *= 0.5
+ nn.init.normal_(self.WQ, std=1.0 / math.sqrt(D))
+ nn.init.normal_(self.WK, std=1.0 / math.sqrt(D))
+ nn.init.normal_(self.Wmem, std=1.0 / math.sqrt(D))
+
+ # -- patch <-> token conv helpers -------------------------------------- #
+ def encode(self, xbar): # (B,C,H,W) -> (B,N,D)
+ e = F.conv2d(xbar, self.Wenc, stride=self.stride)
+ return e.flatten(2).transpose(1, 2)
+
+ def decode_conv(self, y): # (B,C,H,W) -> (B,N,D)
+ d = F.conv2d(y, self.Wdec, stride=self.stride)
+ return d.flatten(2).transpose(1, 2)
+
+ # -- energy (per-sample, shape (B,)) ----------------------------------- #
+ def energy(self, xbar, z, y):
+ enc = self.encode(xbar) # (B,N,D)
+ E = 0.5 * (z ** 2).sum((1, 2)) - (enc * z).sum((1, 2)) - (z * self.benc).sum((1, 2))
+ E = E - (z * self.bpos).sum((1, 2)) # E_pos (eq12)
+ # E_att (eq14): per head, score(query m, key n) = <Q_m, K_n>; energy = -1/g sum_m lse_n
+ Q = torch.einsum('bnd,hjd->bhnj', z, self.WQ) # (B,H,N,dh)
+ K = torch.einsum('bnd,hjd->bhnj', z, self.WK)
+ A = torch.einsum('bhmj,bhnj->bhmn', Q, K) # (B,H,N,N)
+ lse = torch.logsumexp(self.gamma * A, dim=-1) # (B,H,N)
+ E = E - (1.0 / self.gamma) * lse.sum((1, 2))
+ # E_mem (eq15): -sum_token sum_mem relu(Wmem^T z)^2
+ proj = torch.einsum('bnd,dm->bnm', z, self.Wmem) # (B,N,M)
+ E = E - (F.relu(proj) ** 2).sum((1, 2))
+ # E_dec (eq16): 1/2 y^2 - <conv(y),z> - <y,bdec>
+ dc = self.decode_conv(y)
+ E = (E + 0.5 * (y ** 2).sum((1, 2, 3)) - (dc * z).sum((1, 2))
+ - (y * self.bdec[None, :, None, None]).sum((1, 2, 3)))
+ return E
+
+ def init_state(self, xbar):
+ z = token_norm(self.encode(xbar)).detach()
+ y = xbar.clone().detach()
+ return z, y
+
+ # -- one PGD step on F = E + beta*Cost --------------------------------- #
+ def _grad_step(self, xbar, z, y, eps, x=None, mask=None, beta=0.0, create_graph=False):
+ z = z.requires_grad_(True)
+ y = y.requires_grad_(True)
+ Etot = self.energy(xbar, z, y).sum()
+ if beta != 0.0:
+ Etot = Etot + beta * masked_cost(y, x, mask)
+ gz, gy = torch.autograd.grad(Etot, [z, y], create_graph=create_graph)
+ z = token_norm(z - eps * gz)
+ y = y - eps * gy
+ return z, y
+
+ @torch.no_grad()
+ def relax(self, xbar, z, y, steps, eps, x=None, mask=None, beta=0.0):
+ for _ in range(steps):
+ with torch.enable_grad():
+ z, y = self._grad_step(xbar, z, y, eps, x, mask, beta)
+ z, y = z.detach(), y.detach()
+ return z, y
+
+
+# --------------------------------------------------------------------------- #
+# Cost / masking
+# --------------------------------------------------------------------------- #
+def masked_cost(y, x, mask):
+ """0.5 * sum over masked pixels of (y-x)^2, summed over batch (energy units)."""
+ return 0.5 * (((y - x) ** 2) * mask).sum()
+
+
+def masked_mse(y, x, mask):
+ """Mean squared error over masked pixels only (reporting metric)."""
+ num = (((y - x) ** 2) * mask).sum()
+ den = mask.sum().clamp_min(1.0)
+ return (num / den).item()
+
+
+def make_patch_mask(B, gh, patch, stride, H, W, ratio, device, gen=None):
+ """Random per-sample patch mask (1 = masked/occluded). Assumes stride==patch."""
+ npatch = gh * gh
+ nmask = int(round(ratio * npatch))
+ noise = torch.rand(B, npatch, device=device, generator=gen)
+ idx = noise.argsort(dim=1)
+ pm = torch.zeros(B, npatch, device=device)
+ pm.scatter_(1, idx[:, :nmask], 1.0)
+ pm = pm.view(B, gh, gh)
+ M = pm.repeat_interleave(patch, 1).repeat_interleave(patch, 2) # (B,H,W)
+ return M.unsqueeze(1) # (B,1,H,W)
+
+
+# --------------------------------------------------------------------------- #
+# Gradient estimators
+# --------------------------------------------------------------------------- #
+def ep_param_grads(model, xbar, x, mask, T1, T2, eps, beta):
+ """Centered EP. Returns (grads list, free-phase masked MSE for monitoring)."""
+ z0, y0 = model.init_state(xbar)
+ z0, y0 = model.relax(xbar, z0, y0, T1, eps) # free phase, beta=0
+ free_mse = masked_mse(y0, x, mask)
+ zp, yp = model.relax(xbar, z0.clone(), y0.clone(), T2, eps, x, mask, beta=+beta)
+ zm, ym = model.relax(xbar, z0.clone(), y0.clone(), T2, eps, x, mask, beta=-beta)
+ params = [p for p in model.parameters()]
+ Ep = model.energy(xbar, zp, yp).sum()
+ gp = torch.autograd.grad(Ep, params)
+ Em = model.energy(xbar, zm, ym).sum()
+ gm = torch.autograd.grad(Em, params)
+ grads = [(a - b) / (2.0 * beta) for a, b in zip(gp, gm)]
+ return grads, free_mse
+
+
+def tbpte_loss(model, xbar, x, mask, T1, T2, eps):
+ """Free relaxation (detached) then backprop through last T2 steps. Returns loss."""
+ z, y = model.init_state(xbar)
+ z, y = model.relax(xbar, z, y, T1, eps) # detached
+ z = z.detach(); y = y.detach()
+ for _ in range(T2): # last T2 steps WITH graph
+ z, y = model._grad_step(xbar, z, y, eps, create_graph=True)
+ return masked_cost(y, x, mask) / mask.sum().clamp_min(1.0), y
+
+
+def bptt_param_grads(model, xbar, x, mask, T1, eps):
+ """Full backprop through ALL T1 relaxation steps (smoke-test reference only)."""
+ z, y = model.init_state(xbar)
+ for _ in range(T1):
+ z, y = model._grad_step(xbar, z, y, eps, create_graph=True)
+ loss = masked_cost(y, x, mask) / mask.sum().clamp_min(1.0)
+ return torch.autograd.grad(loss, [p for p in model.parameters()])
+
+
+# --------------------------------------------------------------------------- #
+# Data
+# --------------------------------------------------------------------------- #
+def get_loaders(batch, root='/tmp/cet_mvp/data', workers=4, dataset='cifar10'):
+ if dataset == 'cifar10':
+ tf = transforms.Compose([transforms.ToTensor(),
+ transforms.Normalize([0.5] * 3, [0.5] * 3)]) # -> [-1,1]
+ tr = tv.datasets.CIFAR10(root, train=True, download=True, transform=tf)
+ te = tv.datasets.CIFAR10(root, train=False, download=True, transform=tf)
+ elif dataset == 'fashionmnist':
+ tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
+ tr = tv.datasets.FashionMNIST(root, train=True, download=True, transform=tf)
+ te = tv.datasets.FashionMNIST(root, train=False, download=True, transform=tf)
+ else:
+ raise ValueError(dataset)
+ trl = torch.utils.data.DataLoader(tr, batch, shuffle=True, num_workers=workers,
+ drop_last=True, pin_memory=True)
+ tel = torch.utils.data.DataLoader(te, batch, shuffle=False, num_workers=workers,
+ pin_memory=True)
+ return trl, tel
+
+
+@torch.no_grad()
+def evaluate(model, loader, cfg, device, max_batches=20):
+ model.eval()
+ tot, n = 0.0, 0
+ gen = torch.Generator(device=device).manual_seed(0)
+ for i, (x, _) in enumerate(loader):
+ if i >= max_batches:
+ break
+ x = x.to(device)
+ M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride,
+ x.size(2), x.size(3), cfg.mask_ratio, device, gen)
+ xbar = x * (1 - M)
+ z, y = model.init_state(xbar)
+ z, y = model.relax(xbar, z, y, cfg.T1, cfg.eps)
+ tot += masked_mse(y, x, M) * x.size(0); n += x.size(0)
+ model.train()
+ return tot / n
+
+
+# --------------------------------------------------------------------------- #
+# Train
+# --------------------------------------------------------------------------- #
+def train(cfg):
+ device = cfg.device
+ torch.manual_seed(cfg.seed)
+ model = CET(cfg.img, cfg.ch, cfg.patch, cfg.stride, cfg.D, cfg.heads, cfg.dh,
+ cfg.mem, cfg.gamma).to(device)
+ opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr_min)
+ trl, tel = get_loaders(cfg.batch, dataset=cfg.dataset)
+ print(f"[{cfg.mode}] model params={sum(p.numel() for p in model.parameters())/1e3:.1f}K "
+ f"N_tokens={model.N} D={cfg.D} | T1={cfg.T1} T2={cfg.T2} eps={cfg.eps} beta={cfg.beta}",
+ flush=True)
+
+ step, t0, run_loss = 0, time.time(), 0.0
+ while step < cfg.steps:
+ for x, _ in trl:
+ if step >= cfg.steps:
+ break
+ x = x.to(device, non_blocking=True)
+ M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride,
+ x.size(2), x.size(3), cfg.mask_ratio, device)
+ xbar = x * (1 - M)
+
+ opt.zero_grad(set_to_none=True)
+ if cfg.mode == 'ep':
+ grads, tr_mse = ep_param_grads(model, xbar, x, M, cfg.T1, cfg.T2,
+ cfg.eps, cfg.beta)
+ for p, g in zip(model.parameters(), grads):
+ p.grad = g
+ else: # tbpte
+ loss, _ = tbpte_loss(model, xbar, x, M, cfg.T1, cfg.T2, cfg.eps)
+ loss.backward()
+ tr_mse = loss.item()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip)
+ opt.step(); sched.step()
+ run_loss += tr_mse; step += 1
+
+ if step % cfg.log_every == 0:
+ avg = run_loss / cfg.log_every; run_loss = 0.0
+ sps = step / (time.time() - t0)
+ print(f"step {step:5d}/{cfg.steps} | train masked-MSE {avg:.5f} "
+ f"| lr {sched.get_last_lr()[0]:.2e} | {sps:.1f} it/s", flush=True)
+ if step % cfg.eval_every == 0 or step == cfg.steps:
+ te_mse = evaluate(model, tel, cfg, device)
+ print(f" >> [eval] step {step} test masked-MSE {te_mse:.5f}", flush=True)
+
+ final = evaluate(model, tel, cfg, device, max_batches=100)
+ os.makedirs(cfg.out, exist_ok=True)
+ res = {'mode': cfg.mode, 'final_test_masked_mse': final, 'steps': cfg.steps,
+ 'config': {k: getattr(cfg, k) for k in
+ ['T1', 'T2', 'eps', 'beta', 'D', 'heads', 'dh', 'mem',
+ 'patch', 'stride', 'mask_ratio', 'batch', 'lr']}}
+ with open(os.path.join(cfg.out, f'result_{cfg.mode}.json'), 'w') as f:
+ json.dump(res, f, indent=2)
+ torch.save(model.state_dict(), os.path.join(cfg.out, f'cet_{cfg.mode}.pt'))
+ print(f"[{cfg.mode}] DONE final test masked-MSE = {final:.5f}", flush=True)
+ return final
+
+
+# --------------------------------------------------------------------------- #
+# Smoke test
+# --------------------------------------------------------------------------- #
+def _residual(model, xbar, z, y, eps):
+ """Norm of the PGD update (proxy for ||grad E|| at the constrained equilibrium)."""
+ with torch.enable_grad():
+ zn, yn = model._grad_step(xbar, z.clone(), y.clone(), eps)
+ return ((zn - z).norm() / (z.norm() + 1e-8)).item(), ((yn - y).norm() / (y.norm() + 1e-8)).item()
+
+
+def smoke(cfg):
+ device = cfg.device
+ torch.manual_seed(0)
+ model = CET(cfg.img, cfg.ch, cfg.patch, cfg.stride, D=32, heads=2, dh=16,
+ mem=32, gamma=cfg.gamma).to(device)
+ x = torch.randn(16, cfg.ch, cfg.img, cfg.img, device=device).clamp(-1, 1)
+ M = make_patch_mask(16, model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, device)
+ xbar = x * (1 - M)
+ T1 = cfg.T1
+ print(f"[smoke] T1={T1} T2={cfg.T2} eps={cfg.eps} beta={cfg.beta}")
+
+ # (a) energy decreases during relaxation
+ z, y = model.init_state(xbar)
+ print("energy trajectory (free phase):")
+ es = []
+ for t in range(T1 + 1):
+ e = model.energy(xbar, z, y).mean().item(); es.append(e)
+ if t % max(1, T1 // 6) == 0:
+ rz, ry = _residual(model, xbar, z, y, cfg.eps)
+ print(f" step {t:3d} E={e:12.4f} masked-MSE={masked_mse(y,x,M):.4f}"
+ f" rel-step |dz|={rz:.2e} |dy|={ry:.2e}")
+ with torch.enable_grad():
+ z, y = model._grad_step(xbar, z, y, cfg.eps)
+ z, y = z.detach(), y.detach()
+ mono = all(es[i+1] <= es[i] + 1e-3 for i in range(len(es)-1))
+ print(f" monotonic non-increasing: {mono} (start {es[0]:.2f} -> end {es[-1]:.2f})")
+ print(f" NaN in state: {torch.isnan(z).any().item() or torch.isnan(y).any().item()}")
+
+ # (b) EP gradient vs full-BPTT gradient (key correctness gate)
+ g_ep, _ = ep_param_grads(model, xbar, x, M, T1, cfg.T2, cfg.eps, beta=cfg.beta)
+ g_bp = bptt_param_grads(model, xbar, x, M, T1, cfg.eps)
+ fe = torch.cat([g.flatten() for g in g_ep])
+ fb = torch.cat([g.flatten() for g in g_bp])
+ cos = F.cosine_similarity(fe, fb, dim=0).item()
+ names = [n for n, _ in model.named_parameters()]
+ print(f"\nEP-vs-BPTT gradient cosine (global): {cos:.4f}")
+ for n, a, b in zip(names, g_ep, g_bp):
+ c = F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
+ print(f" {n:6s} cos={c:+.3f} |ep|={a.norm():.3e} |bptt|={b.norm():.3e}")
+ print(f"\nSMOKE {'PASS' if (mono and cos > 0.6) else 'CHECK'} "
+ f"(want energy monotone & global cos>0.6)")
+
+
+# --------------------------------------------------------------------------- #
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--mode', choices=['ep', 'tbpte', 'smoke'], default='smoke')
+ p.add_argument('--dataset', choices=['cifar10', 'fashionmnist'], default='cifar10')
+ p.add_argument('--steps', type=int, default=4000)
+ p.add_argument('--batch', type=int, default=128)
+ p.add_argument('--img', type=int, default=32); p.add_argument('--ch', type=int, default=3)
+ p.add_argument('--patch', type=int, default=8); p.add_argument('--stride', type=int, default=8)
+ p.add_argument('--D', type=int, default=128); p.add_argument('--heads', type=int, default=4)
+ p.add_argument('--dh', type=int, default=32); p.add_argument('--mem', type=int, default=256)
+ p.add_argument('--gamma', type=float, default=0.25)
+ p.add_argument('--T1', type=int, default=30); p.add_argument('--T2', type=int, default=5)
+ p.add_argument('--eps', type=float, default=0.5); p.add_argument('--beta', type=float, default=0.1)
+ p.add_argument('--mask_ratio', type=float, default=0.5)
+ p.add_argument('--lr', type=float, default=4e-4); p.add_argument('--lr_min', type=float, default=1e-6)
+ p.add_argument('--wd', type=float, default=3e-5); p.add_argument('--clip', type=float, default=10.0)
+ p.add_argument('--log_every', type=int, default=100); p.add_argument('--eval_every', type=int, default=1000)
+ p.add_argument('--seed', type=int, default=0)
+ p.add_argument('--out', type=str, default='/home/yurenh2/ept/runs')
+ p.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ cfg = p.parse_args()
+ print('config:', vars(cfg), flush=True)
+ if cfg.mode == 'smoke':
+ smoke(cfg)
+ else:
+ train(cfg)
+
+
+if __name__ == '__main__':
+ main()