diff options
Diffstat (limited to 'scripts/cet_mvp.py')
| -rw-r--r-- | scripts/cet_mvp.py | 372 |
1 files changed, 372 insertions, 0 deletions
"""
Convergent Energy Transformer (CET) trained with Equilibrium Propagation.
MVP reproduction of Hoier, Kerjan & Scellier, "Training a Convergent Energy
Transformer with Equilibrium Propagation" (ICLR 2026 AM workshop).

Energy terms follow the paper's Appendix B exactly:
  E = E_enc (eq10, conv-Hopfield)        boundary: masked image -> tokens
    + E_pos (eq12, per-token bias)
    + E_att (eq14, ET LogSumExp attention)   <- attention, trained WITHOUT BP under EP
    + E_mem (eq15, modern Hopfield memory)   <- plays the role of the FFN/MLP
    + E_dec (eq16, conv decoder)             tokens <-> reconstruction y
Tokens z are normalised (mean 0, std 1) via projection after each PGD step.

Training modes:
  ep    : free phase (T1) + two nudged phases (+/-beta, T2). Parameter gradient is the
          centered-EP estimator (1/(2*beta)) * (dE/dtheta|_{+b} - dE/dtheta|_{-b}).
          NO backpropagation through the relaxation dynamics; attention & memory
          weights are updated purely from the two equilibria.
  tbpte : same model, gradient via truncated backprop through the last T2 relaxation
          steps (the paper's BP baseline; BPTE = "backprop through equilibration").
"""
import argparse
import json
import math
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F


# --------------------------------------------------------------------------- #
# Model
# --------------------------------------------------------------------------- #
def token_norm(z, eps=1e-5):
    """Project tokens onto the constraint set C: per-token mean 0, std 1 (over D_T).

    Normalisation is over the last dimension; `eps` is added to the (biased)
    standard deviation to avoid division by zero.
    """
    return (z - z.mean(-1, keepdim=True)) / (z.std(-1, unbiased=False, keepdim=True) + eps)


class CET(nn.Module):
    """Energy-based model whose state (tokens z, reconstruction y) is relaxed by PGD.

    The token grid is `gh x gh` with `gh = (img - patch) // stride + 1`, i.e. the
    output size of an unpadded strided convolution over a square `img x img` input.
    """

    def __init__(self, img=32, ch=3, patch=8, stride=8, D=128, heads=4, dh=32,
                 mem=256, gamma=0.25):
        super().__init__()
        self.ch, self.patch, self.stride, self.D = ch, patch, stride, D
        self.heads, self.dh, self.gamma = heads, dh, gamma
        gh = (img - patch) // stride + 1
        self.gh = gh
        self.N = gh * gh  # number of tokens / patches

        # Encoder (eq 10): conv kernel mapping patches -> token dim
        self.Wenc = nn.Parameter(torch.empty(D, ch, patch, patch))
        self.benc = nn.Parameter(torch.zeros(D))
        # Positional bias (eq 12): per (token, dim), NOT shared across tokens
        self.bpos = nn.Parameter(torch.zeros(self.N, D))
        # Decoder (eq 16): conv kernel mapping reconstruction -> token dim
        self.Wdec = nn.Parameter(torch.empty(D, ch, patch, patch))
        self.bdec = nn.Parameter(torch.zeros(ch))
        # Attention (eq 13/14): key/query projections, no value tensor (as in ET)
        self.WQ = nn.Parameter(torch.empty(heads, dh, D))
        self.WK = nn.Parameter(torch.empty(heads, dh, D))
        # Memory (eq 15): modern Hopfield memory bank (role of the MLP)
        self.Wmem = nn.Parameter(torch.empty(D, mem))

        nn.init.kaiming_normal_(self.Wenc)
        self.Wenc.data *= 0.5
        nn.init.kaiming_normal_(self.Wdec)
        self.Wdec.data *= 0.5
        nn.init.normal_(self.WQ, std=1.0 / math.sqrt(D))
        nn.init.normal_(self.WK, std=1.0 / math.sqrt(D))
        nn.init.normal_(self.Wmem, std=1.0 / math.sqrt(D))

    # -- patch <-> token conv helpers -------------------------------------- #
    def encode(self, xbar):  # (B,C,H,W) -> (B,N,D)
        """Strided conv of the masked image with `Wenc`, flattened to a token sequence."""
        e = F.conv2d(xbar, self.Wenc, stride=self.stride)
        return e.flatten(2).transpose(1, 2)

    def decode_conv(self, y):  # (B,C,H,W) -> (B,N,D)
        """Strided conv of the reconstruction with `Wdec`, flattened to a token sequence."""
        d = F.conv2d(y, self.Wdec, stride=self.stride)
        return d.flatten(2).transpose(1, 2)

    # -- energy (per-sample, shape (B,)) ----------------------------------- #
    def energy(self, xbar, z, y):
        """Return the total energy E(xbar, z, y) per sample, shape (B,).

        Sums the encoder, positional, attention, memory and decoder terms listed in
        the module docstring. The nudging cost is NOT included here.
        """
        enc = self.encode(xbar)  # (B,N,D)
        # E_enc (eq10): 1/2 z^2 - <conv(xbar), z> - <z, benc>
        E = 0.5 * (z ** 2).sum((1, 2)) - (enc * z).sum((1, 2)) - (z * self.benc).sum((1, 2))
        E = E - (z * self.bpos).sum((1, 2))  # E_pos (eq12)
        # E_att (eq14): per head, score(query m, key n) = <Q_m, K_n>; energy = -1/g sum_m lse_n
        Q = torch.einsum('bnd,hjd->bhnj', z, self.WQ)  # (B,H,N,dh)
        K = torch.einsum('bnd,hjd->bhnj', z, self.WK)
        A = torch.einsum('bhmj,bhnj->bhmn', Q, K)  # (B,H,N,N)
        lse = torch.logsumexp(self.gamma * A, dim=-1)  # (B,H,N)
        E = E - (1.0 / self.gamma) * lse.sum((1, 2))
        # E_mem (eq15): -sum_token sum_mem relu(Wmem^T z)^2
        proj = torch.einsum('bnd,dm->bnm', z, self.Wmem)  # (B,N,M)
        E = E - (F.relu(proj) ** 2).sum((1, 2))
        # E_dec (eq16): 1/2 y^2 - <conv(y),z> - <y,bdec>
        dc = self.decode_conv(y)
        E = (E + 0.5 * (y ** 2).sum((1, 2, 3)) - (dc * z).sum((1, 2))
             - (y * self.bdec[None, :, None, None]).sum((1, 2, 3)))
        return E

    def init_state(self, xbar):
        """Return the detached initial state: z = normalised encoding, y = copy of xbar."""
        z = token_norm(self.encode(xbar)).detach()
        y = xbar.clone().detach()
        return z, y

    # -- one PGD step on F = E + beta*Cost --------------------------------- #
    def _grad_step(self, xbar, z, y, eps, x=None, mask=None, beta=0.0, create_graph=False):
        """Take one projected-gradient step on F = E + beta * masked_cost.

        z is updated by a gradient step of size `eps` followed by `token_norm`;
        y by a plain gradient step. With `create_graph=True` the step stays
        differentiable w.r.t. the parameters (used by the backprop baselines).

        Raises:
            ValueError: if `beta != 0` but `x` or `mask` is missing.
        """
        if beta != 0.0 and (x is None or mask is None):
            raise ValueError("x and mask are required when beta != 0")
        # Gradients w.r.t. the state are needed even when the caller runs under no_grad.
        with torch.enable_grad():
            # Work on detached views instead of calling requires_grad_ on the caller's
            # tensors, so the caller's state is never mutated. States that already
            # carry a graph (create_graph=True chains) are kept as they are.
            if not z.requires_grad:
                z = z.detach().requires_grad_(True)
            if not y.requires_grad:
                y = y.detach().requires_grad_(True)
            Etot = self.energy(xbar, z, y).sum()
            if beta != 0.0:
                Etot = Etot + beta * masked_cost(y, x, mask)
            gz, gy = torch.autograd.grad(Etot, [z, y], create_graph=create_graph)
            z = token_norm(z - eps * gz)
            y = y - eps * gy
        return z, y

    @torch.no_grad()
    def relax(self, xbar, z, y, steps, eps, x=None, mask=None, beta=0.0):
        """Run `steps` PGD steps and return the detached final state (z, y).

        No graph is kept across steps, so memory use does not grow with `steps`.
        """
        for _ in range(steps):
            z, y = self._grad_step(xbar, z, y, eps, x, mask, beta)
            z, y = z.detach(), y.detach()
        # Also detaches when steps == 0, so the result never carries a graph.
        return z.detach(), y.detach()


# --------------------------------------------------------------------------- #
# Cost / masking
# --------------------------------------------------------------------------- #
def masked_cost(y, x, mask):
    """0.5 * sum over masked pixels of (y-x)^2, summed over batch (energy units)."""
    return 0.5 * (((y - x) ** 2) * mask).sum()


def masked_mse(y, x, mask):
    """Squared error on masked entries divided by the number of mask entries (a float).

    The denominator is `mask.sum()` (clamped to >= 1). With a (B,1,H,W) mask that
    broadcasts over C channels this is the per-pixel squared error summed over
    channels, i.e. C times the per-element mean.
    """
    num = (((y - x) ** 2) * mask).sum()
    den = mask.sum().clamp_min(1.0)
    return (num / den).item()


def make_patch_mask(B, gh, patch, stride, H, W, ratio, device, gen=None):
    """Random per-sample patch mask (1 = masked/occluded), shape (B,1,H,W).

    For every sample, `round(ratio * gh*gh)` patches (clamped to [0, gh*gh]) are
    chosen uniformly at random. A pixel is masked when it lies inside at least one
    chosen patch footprint; pixels not covered by the patch grid are left unmasked.

    Raises:
        ValueError: if `gh < 1` or the patch grid does not fit inside (H, W).
    """
    if gh < 1:
        raise ValueError(f"gh must be >= 1, got {gh}")
    span = (gh - 1) * stride + patch  # extent covered by the patch grid
    if span > H or span > W:
        raise ValueError(f"patch grid spans {span} pixels but the image is {H}x{W}")
    npatch = gh * gh
    nmask = min(max(int(round(ratio * npatch)), 0), npatch)
    noise = torch.rand(B, npatch, device=device, generator=gen)
    idx = noise.argsort(dim=1)
    pm = torch.zeros(B, npatch, device=device)
    pm.scatter_(1, idx[:, :nmask], 1.0)
    pm = pm.view(B, 1, gh, gh)
    if stride == patch:
        # Non-overlapping patches: each grid cell expands to one patch x patch block.
        M = pm.repeat_interleave(patch, 2).repeat_interleave(patch, 3)
    else:
        # Overlapping / gapped patches: stamp each footprint, then take the union.
        footprint = torch.ones(1, 1, patch, patch, device=device)
        M = F.conv_transpose2d(pm, footprint, stride=stride).clamp_max(1.0)
    if span != H or span != W:
        # F.pad order for the last two dims is (left, right, top, bottom).
        M = F.pad(M, (0, W - span, 0, H - span))
    return M  # (B,1,H,W)


# --------------------------------------------------------------------------- #
# Gradient estimators
# --------------------------------------------------------------------------- #
def ep_param_grads(model, xbar, x, mask, T1, T2, eps, beta):
    """Centered EP. Returns (grads list, free-phase masked MSE for monitoring).

    Runs a free phase of T1 steps, then two nudged phases of T2 steps at +beta and
    -beta, both starting from the free equilibrium. The gradients are ordered like
    `model.parameters()`.

    Raises:
        ValueError: if `beta == 0` (the estimator divides by 2*beta).
    """
    if beta == 0:
        raise ValueError("beta must be non-zero for the centered EP estimator")
    z0, y0 = model.init_state(xbar)
    z0, y0 = model.relax(xbar, z0, y0, T1, eps)  # free phase, beta=0
    free_mse = masked_mse(y0, x, mask)
    zp, yp = model.relax(xbar, z0.clone(), y0.clone(), T2, eps, x, mask, beta=+beta)
    zm, ym = model.relax(xbar, z0.clone(), y0.clone(), T2, eps, x, mask, beta=-beta)
    params = list(model.parameters())
    Ep = model.energy(xbar, zp, yp).sum()
    gp = torch.autograd.grad(Ep, params)
    Em = model.energy(xbar, zm, ym).sum()
    gm = torch.autograd.grad(Em, params)
    grads = [(a - b) / (2.0 * beta) for a, b in zip(gp, gm)]
    return grads, free_mse


def tbpte_loss(model, xbar, x, mask, T1, T2, eps):
    """Free relaxation (detached) then backprop through last T2 steps.

    Returns (loss, y): the masked cost divided by the number of mask entries, and
    the final reconstruction (still attached to the graph).
    """
    z, y = model.init_state(xbar)
    z, y = model.relax(xbar, z, y, T1, eps)  # detached
    z = z.detach()
    y = y.detach()
    for _ in range(T2):  # last T2 steps WITH graph
        z, y = model._grad_step(xbar, z, y, eps, create_graph=True)
    return masked_cost(y, x, mask) / mask.sum().clamp_min(1.0), y


def bptt_param_grads(model, xbar, x, mask, T1, eps):
    """Full backprop through ALL T1 relaxation steps (smoke-test reference only)."""
    z, y = model.init_state(xbar)
    for _ in range(T1):
        z, y = model._grad_step(xbar, z, y, eps, create_graph=True)
    loss = masked_cost(y, x, mask) / mask.sum().clamp_min(1.0)
    return torch.autograd.grad(loss, list(model.parameters()))


# --------------------------------------------------------------------------- #
# Data
# --------------------------------------------------------------------------- #
def get_loaders(batch, root='/tmp/cet_mvp/data', workers=4, dataset='cifar10'):
    """Return (train_loader, test_loader) for 'cifar10' or 'fashionmnist'.

    Images are scaled to [-1, 1]. The datasets are downloaded to `root` if missing.

    Raises:
        ValueError: for an unknown dataset name.
    """
    # Imported lazily so smoke mode and plain imports of this module do not need torchvision.
    import torchvision as tv
    from torchvision import transforms

    if dataset == 'cifar10':
        tf = transforms.Compose([transforms.ToTensor(),
                                 transforms.Normalize([0.5] * 3, [0.5] * 3)])  # -> [-1,1]
        tr = tv.datasets.CIFAR10(root, train=True, download=True, transform=tf)
        te = tv.datasets.CIFAR10(root, train=False, download=True, transform=tf)
    elif dataset == 'fashionmnist':
        tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
        tr = tv.datasets.FashionMNIST(root, train=True, download=True, transform=tf)
        te = tv.datasets.FashionMNIST(root, train=False, download=True, transform=tf)
    else:
        raise ValueError(dataset)
    pin = torch.cuda.is_available()  # pinned memory only helps host -> GPU copies
    trl = torch.utils.data.DataLoader(tr, batch, shuffle=True, num_workers=workers,
                                      drop_last=True, pin_memory=pin)
    tel = torch.utils.data.DataLoader(te, batch, shuffle=False, num_workers=workers,
                                      pin_memory=pin)
    return trl, tel


@torch.no_grad()
def evaluate(model, loader, cfg, device, max_batches=20):
    """Return the sample-weighted masked MSE over at most `max_batches` batches.

    Masks are drawn from a generator seeded with 0, so repeated evaluations use the
    same masks. The model's train/eval mode is restored on exit. Returns NaN when
    no batch was evaluated.
    """
    was_training = model.training
    model.eval()
    tot, n = 0.0, 0
    gen = torch.Generator(device=device).manual_seed(0)
    for i, (x, _) in enumerate(loader):
        if i >= max_batches:
            break
        x = x.to(device)
        M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride,
                            x.size(2), x.size(3), cfg.mask_ratio, device, gen)
        xbar = x * (1 - M)
        z, y = model.init_state(xbar)
        z, y = model.relax(xbar, z, y, cfg.T1, cfg.eps)
        tot += masked_mse(y, x, M) * x.size(0)
        n += x.size(0)
    model.train(was_training)
    return tot / n if n else float('nan')


# --------------------------------------------------------------------------- #
# Train
# --------------------------------------------------------------------------- #
def train(cfg):
    """Train a CET in `cfg.mode` ('ep', anything else is treated as tbpte).

    Writes `result_<mode>.json` and `cet_<mode>.pt` into `cfg.out` and returns the
    final test masked MSE.
    """
    device = cfg.device
    torch.manual_seed(cfg.seed)
    # Fail fast: an unwritable output directory should abort before training, not after it.
    os.makedirs(cfg.out, exist_ok=True)
    model = CET(cfg.img, cfg.ch, cfg.patch, cfg.stride, cfg.D, cfg.heads, cfg.dh,
                cfg.mem, cfg.gamma).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, cfg.steps, eta_min=cfg.lr_min)
    trl, tel = get_loaders(cfg.batch, dataset=cfg.dataset)
    print(f"[{cfg.mode}] model params={sum(p.numel() for p in model.parameters())/1e3:.1f}K "
          f"N_tokens={model.N} D={cfg.D} | T1={cfg.T1} T2={cfg.T2} eps={cfg.eps} beta={cfg.beta}",
          flush=True)

    step, t0, run_loss = 0, time.time(), 0.0
    while step < cfg.steps:
        for x, _ in trl:
            if step >= cfg.steps:
                break
            x = x.to(device, non_blocking=True)
            M = make_patch_mask(x.size(0), model.gh, cfg.patch, cfg.stride,
                                x.size(2), x.size(3), cfg.mask_ratio, device)
            xbar = x * (1 - M)

            opt.zero_grad(set_to_none=True)
            if cfg.mode == 'ep':
                grads, tr_mse = ep_param_grads(model, xbar, x, M, cfg.T1, cfg.T2,
                                               cfg.eps, cfg.beta)
                for p, g in zip(model.parameters(), grads):
                    p.grad = g
            else:  # tbpte
                loss, y = tbpte_loss(model, xbar, x, M, cfg.T1, cfg.T2, cfg.eps)
                loss.backward()
                # Log the same metric as EP mode; the loss itself carries a 0.5 factor.
                tr_mse = masked_mse(y.detach(), x, M)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip)
            opt.step()
            sched.step()
            run_loss += tr_mse
            step += 1

            if step % cfg.log_every == 0:
                avg = run_loss / cfg.log_every
                run_loss = 0.0
                sps = step / (time.time() - t0)
                print(f"step {step:5d}/{cfg.steps} | train masked-MSE {avg:.5f} "
                      f"| lr {sched.get_last_lr()[0]:.2e} | {sps:.1f} it/s", flush=True)
            if step % cfg.eval_every == 0 or step == cfg.steps:
                te_mse = evaluate(model, tel, cfg, device)
                print(f" >> [eval] step {step} test masked-MSE {te_mse:.5f}", flush=True)

    final = evaluate(model, tel, cfg, device, max_batches=100)
    res = {'mode': cfg.mode, 'final_test_masked_mse': final, 'steps': cfg.steps,
           'config': {k: getattr(cfg, k) for k in
                      ['T1', 'T2', 'eps', 'beta', 'D', 'heads', 'dh', 'mem',
                       'patch', 'stride', 'mask_ratio', 'batch', 'lr',
                       # needed to rebuild the model / reproduce the run
                       'img', 'ch', 'gamma', 'dataset', 'seed']}}
    with open(os.path.join(cfg.out, f'result_{cfg.mode}.json'), 'w') as f:
        json.dump(res, f, indent=2)
    torch.save(model.state_dict(), os.path.join(cfg.out, f'cet_{cfg.mode}.pt'))
    print(f"[{cfg.mode}] DONE final test masked-MSE = {final:.5f}", flush=True)
    return final


# --------------------------------------------------------------------------- #
# Smoke test
# --------------------------------------------------------------------------- #
def _residual(model, xbar, z, y, eps):
    """Norm of the PGD update (proxy for ||grad E|| at the constrained equilibrium).

    Returns the relative step sizes (|dz|/|z|, |dy|/|y|) as floats.
    """
    with torch.enable_grad():
        zn, yn = model._grad_step(xbar, z.clone(), y.clone(), eps)
    rel_z = ((zn - z).norm() / (z.norm() + 1e-8)).item()
    rel_y = ((yn - y).norm() / (y.norm() + 1e-8)).item()
    return rel_z, rel_y


def smoke(cfg):
    """Sanity checks on random data with a tiny model.

    (a) prints the free-phase energy trajectory and whether it is non-increasing;
    (b) compares the EP gradient with the full-BPTT gradient by cosine similarity.
    """
    device = cfg.device
    torch.manual_seed(0)
    model = CET(cfg.img, cfg.ch, cfg.patch, cfg.stride, D=32, heads=2, dh=16,
                mem=32, gamma=cfg.gamma).to(device)
    x = torch.randn(16, cfg.ch, cfg.img, cfg.img, device=device).clamp(-1, 1)
    M = make_patch_mask(16, model.gh, cfg.patch, cfg.stride, cfg.img, cfg.img, 0.5, device)
    xbar = x * (1 - M)
    T1 = cfg.T1
    print(f"[smoke] T1={T1} T2={cfg.T2} eps={cfg.eps} beta={cfg.beta}")

    # (a) energy decreases during relaxation
    z, y = model.init_state(xbar)
    print("energy trajectory (free phase):")
    es = []
    for t in range(T1 + 1):
        e = model.energy(xbar, z, y).mean().item()
        es.append(e)
        if t % max(1, T1 // 6) == 0:
            rz, ry = _residual(model, xbar, z, y, cfg.eps)
            print(f" step {t:3d} E={e:12.4f} masked-MSE={masked_mse(y,x,M):.4f}"
                  f" rel-step |dz|={rz:.2e} |dy|={ry:.2e}")
        with torch.enable_grad():
            z, y = model._grad_step(xbar, z, y, cfg.eps)
        z, y = z.detach(), y.detach()
    mono = all(es[i+1] <= es[i] + 1e-3 for i in range(len(es)-1))
    print(f" monotonic non-increasing: {mono} (start {es[0]:.2f} -> end {es[-1]:.2f})")
    print(f" NaN in state: {torch.isnan(z).any().item() or torch.isnan(y).any().item()}")

    # (b) EP gradient vs full-BPTT gradient (key correctness gate)
    g_ep, _ = ep_param_grads(model, xbar, x, M, T1, cfg.T2, cfg.eps, beta=cfg.beta)
    g_bp = bptt_param_grads(model, xbar, x, M, T1, cfg.eps)
    fe = torch.cat([g.flatten() for g in g_ep])
    fb = torch.cat([g.flatten() for g in g_bp])
    cos = F.cosine_similarity(fe, fb, dim=0).item()
    names = [n for n, _ in model.named_parameters()]
    print(f"\nEP-vs-BPTT gradient cosine (global): {cos:.4f}")
    for n, a, b in zip(names, g_ep, g_bp):
        c = F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
        print(f" {n:6s} cos={c:+.3f} |ep|={a.norm():.3e} |bptt|={b.norm():.3e}")
    print(f"\nSMOKE {'PASS' if (mono and cos > 0.6) else 'CHECK'} "
          f"(want energy monotone & global cos>0.6)")


# --------------------------------------------------------------------------- #
def main():
    """Parse the command line and run the smoke test or a training run."""
    p = argparse.ArgumentParser()
    p.add_argument('--mode', choices=['ep', 'tbpte', 'smoke'], default='smoke')
    p.add_argument('--dataset', choices=['cifar10', 'fashionmnist'], default='cifar10')
    p.add_argument('--steps', type=int, default=4000)
    p.add_argument('--batch', type=int, default=128)
    p.add_argument('--img', type=int, default=32)
    p.add_argument('--ch', type=int, default=3)
    p.add_argument('--patch', type=int, default=8)
    p.add_argument('--stride', type=int, default=8)
    p.add_argument('--D', type=int, default=128)
    p.add_argument('--heads', type=int, default=4)
    p.add_argument('--dh', type=int, default=32)
    p.add_argument('--mem', type=int, default=256)
    p.add_argument('--gamma', type=float, default=0.25)
    p.add_argument('--T1', type=int, default=30)
    p.add_argument('--T2', type=int, default=5)
    p.add_argument('--eps', type=float, default=0.5)
    p.add_argument('--beta', type=float, default=0.1)
    p.add_argument('--mask_ratio', type=float, default=0.5)
    p.add_argument('--lr', type=float, default=4e-4)
    p.add_argument('--lr_min', type=float, default=1e-6)
    p.add_argument('--wd', type=float, default=3e-5)
    p.add_argument('--clip', type=float, default=10.0)
    p.add_argument('--log_every', type=int, default=100)
    p.add_argument('--eval_every', type=int, default=1000)
    p.add_argument('--seed', type=int, default=0)
    p.add_argument('--out', type=str, default='/home/yurenh2/ept/runs')
    p.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    cfg = p.parse_args()
    print('config:', vars(cfg), flush=True)
    if cfg.mode == 'smoke':
        smoke(cfg)
    else:
        train(cfg)


if __name__ == '__main__':
    main()
