summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-02 11:22:48 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-02 11:22:48 -0500
commit61204b6010e403b4c61b093f2a208a881b20fa11 (patch)
tree059002d3d603af8727a613f98fbb829a34e44fac /experiments
parent17e53bcf6971d93bd7061d6376b485132b30c825 (diff)
Add EP baseline implementation (Scellier & Bengio 2017) for CIFAR MLP
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/ep_baseline.py316
1 files changed, 316 insertions, 0 deletions
diff --git a/experiments/ep_baseline.py b/experiments/ep_baseline.py
new file mode 100644
index 0000000..e2e9074
--- /dev/null
+++ b/experiments/ep_baseline.py
@@ -0,0 +1,316 @@
+"""
+Equilibrium Propagation (Scellier & Bengio 2017) for ResidualMLP on CIFAR-10.
+Feedforward EP with energy-based state optimization.
+
+Usage: python ep_baseline.py --method ep --seed 42 --gpu 0
+"""
+import os, sys, json, argparse, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation
+import torchvision, torchvision.transforms as transforms
+
+
+def get_cifar10(bs=128):
+ tt = transforms.Compose([transforms.RandomCrop(32, 4), transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
+ tv = transforms.Compose([transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
+ return (DataLoader(torchvision.datasets.CIFAR10('./data', True, download=True, transform=tt), bs, True, num_workers=4, pin_memory=True),
+ DataLoader(torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv), bs, False, num_workers=4, pin_memory=True))
+
+
+def evaluate(m, tl, dev):
+ m.eval(); c, t = 0, 0
+ with torch.no_grad():
+ for x, y in tl:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ c += (m(x).argmax(1) == y).sum().item(); t += x.size(0)
+ return c / t
+
+
+def ep_energy(model, hiddens, lam=1.0):
+ """
+ Compute the EP energy E = 0.5 * sum_l ||h_{l+1} - h_l - F_l(h_l)||^2
+ hiddens: list of L+1 tensors, [h_0, h_1, ..., h_L]
+ F_l(h_l) is the residual branch output (block forward without the skip).
+ lam: weight for the state consistency term (kept at 1.0).
+ """
+ L = model.num_blocks
+ E = 0.0
+ for l in range(L):
+ f_l = model.blocks[l](hiddens[l]) # residual branch
+ residual = hiddens[l + 1] - hiddens[l] - f_l
+ E = E + 0.5 * (residual ** 2).sum(-1) # (batch,)
+ return E # (batch,)
+
+
+def ep_free_phase(model, x):
+ """
+ Free phase: standard forward pass. Returns hidden states h_0..h_L.
+ """
+ with torch.no_grad():
+ _, hiddens = model(x, return_hidden=True)
+ return hiddens # list of L+1 tensors
+
+
+def ep_nudged_phase(model, x, y, h_free, beta, T_nudge, alpha_nudge):
+ """
+ Nudged phase: minimize E(h) + beta * C(h_L, y) w.r.t. hidden states h_1..h_L.
+ h_0 is fixed (output of embed layer).
+ Returns list of nudged hidden states [h_0, h_1^*, ..., h_L^*].
+ """
+ L = model.num_blocks
+ # Initialize nudged states from free phase (detached)
+ h_nudged = [h.clone().detach() for h in h_free]
+ # h_0 is fixed (embed output)
+ h_nudged[0] = h_free[0].clone().detach()
+ # Optimize h_1 .. h_L
+ for i in range(1, L + 1):
+ h_nudged[i].requires_grad_(True)
+
+ params_to_opt = h_nudged[1:]
+ inner_opt = optim.SGD(params_to_opt, lr=alpha_nudge)
+
+ for _ in range(T_nudge):
+ # Energy over all layers
+ E = ep_energy(model, h_nudged) # (batch,)
+ # Cost at output layer: cross-entropy
+ logits = model.out_head(model.out_ln(h_nudged[L]))
+ C = F.cross_entropy(logits, y, reduction='none') # (batch,)
+ total = (E + beta * C).mean()
+ inner_opt.zero_grad()
+ total.backward()
+ inner_opt.step()
+
+ return [h.detach() for h in h_nudged]
+
+
+
+def train_ep(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01,
+ beta=0.5, T_nudge=20, alpha_nudge=0.1):
+ L = model.num_blocks
+
+ # Separate optimizers for different parts
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd)
+ all_opts = block_opts + [embed_opt, head_opt]
+ schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in all_opts]
+
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in trl:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+
+ # ---- FREE PHASE ----
+ # Standard forward pass to get free fixed point
+ with torch.no_grad():
+ _, h_free = model(x, return_hidden=True)
+
+ # ---- NUDGED PHASE ----
+ # Minimize E(h) + beta * C(h_L, y) w.r.t. hidden states
+ h_nudged = ep_nudged_phase(model, x, y, h_free, beta, T_nudge, alpha_nudge)
+
+ # ---- EP WEIGHT UPDATE ----
+ # Δθ ∝ (∂E_nudged/∂θ - ∂E_free/∂θ) / beta
+ # For blocks: dE/dθ_l comes from F_l(h_l) term in E
+ # For embed: dE/dθ_embed comes from h_0 = embed(x) being the base state
+
+ for o in all_opts:
+ o.zero_grad()
+
+ # Compute EP grads for residual blocks
+ # E = sum_l 0.5 ||h_{l+1} - h_l - F_l(h_l)||^2
+ # dE/dθ_l = - (h_{l+1} - h_l - F_l(h_l))^T * dF_l/dθ_l
+ # = - residual_l^T * dF_l/dθ_l
+
+ for l in range(L):
+ h_l_free = h_free[l].detach()
+ h_lp1_free = h_free[l + 1].detach()
+ h_l_nudge = h_nudged[l].detach()
+ h_lp1_nudge = h_nudged[l + 1].detach()
+
+ # Free phase: -residual_l_free dot dF_l/dtheta
+ # = -(h_{l+1}^free - h_l^free - F_l(h_l^free)) dot dF_l/dtheta
+ # We compute this by forward pass with grad for params
+ f_l_free = model.blocks[l](h_l_free)
+ res_free = h_lp1_free - h_l_free - f_l_free.detach()
+ # Gradient: d/dtheta [ -0.5 * res_free^2 ] = res_free * dF/dtheta ... actually we want
+ # to minimize E, so grad = dE/dtheta = -res * dF/dtheta
+ # To use autograd: compute -res_free.detach() * f_l_free, sum, backward
+ loss_free_l = -(res_free.detach() * f_l_free).sum()
+
+ f_l_nudge = model.blocks[l](h_l_nudge)
+ res_nudge = h_lp1_nudge - h_l_nudge - f_l_nudge.detach()
+ loss_nudge_l = -(res_nudge.detach() * f_l_nudge).sum()
+
+ # EP grad = (nudged - free) / beta [we want d(E_nudge - E_free)/dtheta / beta]
+ # Since loss_free_l = -res_free * F_l contributes dE/dtheta_free (the negative),
+ # and loss_nudge_l similarly, we need:
+ # grad_l = (dE_nudge/dtheta - dE_free/dtheta) / beta
+ # dE/dtheta = -res * dF/dtheta => computed via backward of (res * F).sum()
+ # So: ep_loss = (loss_nudge_l - loss_free_l) / beta
+ ep_loss_l = (loss_nudge_l - loss_free_l) / beta
+ ep_loss_l.backward()
+
+ # Grad for embed layer:
+ # h_0 = embed(x), so dE/dtheta_embed = dE/dh_0 * dh_0/dtheta_embed
+ # dE/dh_0: E depends on h_0 via (h_1 - h_0 - F_0(h_0)) term
+ # = -(h_1 - h_0 - F_0(h_0)) * (I + dF_0/dh_0)^T ... complex
+ # Simpler: treat h_0 as part of the system and use chain rule via autograd
+ h0_free = model.embed(x) # differentiable
+ # compute E contribution from layer 0 with h_0 = embed(x) fixed, block params fixed
+ with torch.no_grad():
+ f0_free = model.blocks[0](h0_free.detach())
+ res0_free = h_free[1].detach() - h0_free.detach() - f0_free
+ embed_loss_free = -(res0_free.detach() * h0_free).sum() / beta # approximation: -res * dh0/dtheta
+
+ h0_nudge_rg = model.embed(x)
+ with torch.no_grad():
+ f0_nudge = model.blocks[0](h_nudged[0].detach())
+ res0_nudge = h_nudged[1].detach() - h_nudged[0].detach() - f0_nudge
+ embed_loss_nudge = -(res0_nudge.detach() * h0_nudge_rg).sum() / beta
+
+ embed_ep = (embed_loss_nudge - embed_loss_free)
+ embed_ep.backward()
+
+ # Grad for out_head + out_ln: standard BP at nudged output
+ # EP: Δθ_out = ∂C_nudged/∂θ_out (since ∂E/∂θ_out = 0)
+ # This is equivalent to standard BP loss at nudged hidden state
+ logits_nudged = model.out_head(model.out_ln(h_nudged[L].detach()))
+ head_loss = F.cross_entropy(logits_nudged, y)
+ head_loss.backward()
+
+ # Clip and step
+ all_params = list(model.parameters())
+ torch.nn.utils.clip_grad_norm_(all_params, 1.0)
+ for o in all_opts:
+ o.step()
+
+ for s in schedulers:
+ s.step()
+ if ep % 20 == 0:
+ print(f" Ep {ep}: acc={evaluate(model, tel, dev):.4f}", flush=True)
+
+ return model
+
+
+def ep_credit_signals(model, x, y, beta, T_nudge, alpha_nudge):
+ """
+ Compute EP credit signals a_l^EP = (h_l^nudged - h_l^free) / beta for diagnostics.
+ """
+ with torch.no_grad():
+ _, h_free = model(x, return_hidden=True)
+ h_nudged = ep_nudged_phase(model, x, y, h_free, beta, T_nudge, alpha_nudge)
+ L = model.num_blocks
+ credits = [(h_nudged[l] - h_free[l]) / beta for l in range(L)]
+ return credits, h_free, h_nudged
+
+
+def compute_diagnostics(model, tel, dev, beta, T_nudge, alpha_nudge):
+ model.eval()
+ L = model.num_blocks
+
+ for x, y in tel:
+ x = x.view(x.size(0), -1).to(dev); y = y.to(dev)
+ break
+
+ # EP credit signals
+ ep_credits, h_free, h_nudged = ep_credit_signals(model, x, y, beta, T_nudge, alpha_nudge)
+
+ # BP gradients for comparison
+ h0 = model.embed(x.detach())
+ hs = [h0.clone().requires_grad_(True)]
+ for bl in model.blocks:
+ hs.append(hs[-1] + bl(hs[-1]))
+ lo = model.out_head(model.out_ln(hs[-1]))
+ loss = F.cross_entropy(lo, y)
+ gs = torch.autograd.grad(loss, hs)
+ bp_grads = {l: gs[l].detach() for l in range(L)}
+
+ # Gamma: cosine sim between EP credit and BP grad
+ gammas = []
+ for l in range(L):
+ g = cosine_similarity_batch(ep_credits[l], bp_grads[l])
+ gammas.append(g)
+
+ # rho: perturbation correlation using EP credit
+ rhos = []
+ with torch.no_grad():
+ _, hi = model(x, return_hidden=True)
+
+ for l in range(L):
+ h_l = hi[l].detach()
+ a_l = ep_credits[l].detach()
+
+ def mk(sl):
+ def f(h):
+ with torch.no_grad():
+ c = h
+ for i in range(sl, L):
+ c = c + model.blocks[i](c)
+ return F.cross_entropy(model.out_head(model.out_ln(c)), y, reduction='none')
+ return f
+
+ rhos.append(perturbation_correlation(h_l, a_l, mk(l), epsilon=1e-3, M=16))
+
+ # naive state error
+ with torch.no_grad():
+ _, hi2 = model(x, return_hidden=True)
+ nse = ((hi2[L // 2] - hi2[-1]).norm(-1) / hi2[-1].norm(-1).clamp(min=1e-8)).mean().item()
+
+ return {'Gamma': float(np.mean(gammas)), 'rho': float(np.mean(rhos)),
+ 'naive_StateErr': nse, 'gammas_per_layer': [float(g) for g in gammas],
+ 'rhos_per_layer': [float(r) for r in rhos]}
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--method', type=str, default='ep')
+ p.add_argument('--seed', type=int, required=True)
+ p.add_argument('--gpu', type=int, default=0)
+ p.add_argument('--output_dir', type=str, default='results/ep_baseline')
+ p.add_argument('--epochs', type=int, default=100)
+ p.add_argument('--beta', type=float, default=0.5, help='EP nudge strength')
+ p.add_argument('--T_nudge', type=int, default=20, help='Inner optimization steps for nudged phase')
+ p.add_argument('--alpha_nudge', type=float, default=0.1, help='Inner step size for nudged phase')
+ p.add_argument('--lr', type=float, default=1e-3)
+ p.add_argument('--wd', type=float, default=0.01)
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ dev = torch.device(f'cuda:{args.gpu}')
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+
+ trl, tel = get_cifar10()
+ L, d = 4, 256
+ model = ResidualMLP(3072, d, 10, L).to(dev)
+
+ print(f"[{args.method} s={args.seed}] Training EP beta={args.beta} T={args.T_nudge} alpha={args.alpha_nudge}", flush=True)
+
+ model = train_ep(model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd,
+ beta=args.beta, T_nudge=args.T_nudge, alpha_nudge=args.alpha_nudge)
+
+ acc = evaluate(model, tel, dev)
+ diag = compute_diagnostics(model, tel, dev, args.beta, args.T_nudge, args.alpha_nudge)
+
+ torch.save(model.state_dict(), os.path.join(args.output_dir, f'{args.method}_s{args.seed}.pt'))
+
+ result = {'method': args.method, 'seed': args.seed, 'acc': acc,
+ 'Gamma': diag['Gamma'], 'rho': diag['rho'],
+ 'naive_StateErr': diag['naive_StateErr'],
+ 'gammas_per_layer': diag['gammas_per_layer'],
+ 'rhos_per_layer': diag['rhos_per_layer'],
+ 'beta': args.beta, 'T_nudge': args.T_nudge, 'alpha_nudge': args.alpha_nudge}
+
+ with open(os.path.join(args.output_dir, f'{args.method}_s{args.seed}.json'), 'w') as f:
+ json.dump(result, f, indent=2, default=float)
+
+ print(f"[{args.method} s={args.seed}] acc={acc:.4f} Γ={diag['Gamma']:.4f} ρ={diag['rho']:.4f} nse={diag['naive_StateErr']:.4f}", flush=True)
+
+
+if __name__ == '__main__':
+ main()