"""CNN baseline for CIFAR-10: BP / DFA / EP on a small ConvNet.

One method+seed per invocation for clean process isolation.

Architecture (as built by ``SmallCNN``):
    Conv2d(3,32,3,padding=1)   -> BatchNorm -> ReLU
    Conv2d(32,64,3,padding=1)  -> BatchNorm -> ReLU -> MaxPool(2)   [32->16]
    Conv2d(64,128,3,padding=1) -> BatchNorm -> ReLU -> MaxPool(2)   [16->8]
    flatten -> FC(128*8*8=8192, 256) -> ReLU -> FC(256, 10)

Blocks (for local update):
    block 0 : Conv1 (Conv2d 3->32)
    block 1 : Conv2 (Conv2d 32->64) + MaxPool
    block 2 : Conv3 (Conv2d 64->128) + MaxPool
    block 3 : FC1   (Linear 8192->256)
    head    : FC2   (Linear 256->10) -- output head, always trained with loss

Hidden states (post-activation, for credit):
    h0 : (B, 32, 32, 32)  after block 0
    h1 : (B, 64, 16, 16)  after block 1
    h2 : (B, 128, 8, 8)   after block 2
    h3 : (B, 256)         after flatten + block 3

DFA: flatten each h_l to (B, d_l), random feedback B_l: (d_l, 10)
EP:  energy E = sum_l 0.5 ||h_l - F_l(h_{l-1})||^2 adapted for CNN

Usage:
    python cnn_baseline.py --method bp --seed 42 --gpu 0

Output:
    results/cnn_baseline/{method}_s{seed}.json + .pt checkpoint
"""

import argparse
import json
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Make the repository root importable so the project-local metrics resolve.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation  # noqa: E402


# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------

# Per-channel normalisation constants applied to both train and test images.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)


def get_cifar10(bs=128, root='./data', num_workers=4):
    """Build the CIFAR-10 train and test loaders.

    Args:
        bs: batch size used by both loaders.
        root: directory the dataset is downloaded to / read from.
        num_workers: number of DataLoader worker processes.

    Returns:
        (train_loader, test_loader). The train loader shuffles and applies
        random crop + horizontal flip; the test loader does neither.
    """
    normalize = transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    # Pinned host memory is only useful when batches are copied to a GPU.
    pin = torch.cuda.is_available()
    train_set = torchvision.datasets.CIFAR10(
        root, train=True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10(
        root, train=False, download=True, transform=test_tf)
    trl = DataLoader(train_set, batch_size=bs, shuffle=True,
                     num_workers=num_workers, pin_memory=pin)
    tel = DataLoader(test_set, batch_size=bs, shuffle=False,
                     num_workers=num_workers, pin_memory=pin)
    return trl, tel


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------

class SmallCNN(nn.Module):
    """A small 3-conv CNN for CIFAR-10.

    Blocks (nn.ModuleList):
        blocks[0] : Conv2d 3->32,   BN, ReLU
        blocks[1] : Conv2d 32->64,  BN, ReLU, MaxPool
        blocks[2] : Conv2d 64->128, BN, ReLU, MaxPool
        blocks[3] : Linear 8192->256, ReLU
        out_head  : Linear 256->10

    ``forward(x, return_hidden=False)`` returns logits, or
    ``(logits, [h0, h1, h2, h3])`` when ``return_hidden=True``. The h_l are
    the block outputs; h3 is flat with shape (B, 256).
    """

    # Flattened size of each hidden state h_l.
    FLAT_DIMS = [32 * 32 * 32, 64 * 16 * 16, 128 * 8 * 8, 256]
    # Per-sample shape of each hidden state h_l (used to un-flatten).
    HIDDEN_SHAPES = [(32, 32, 32), (64, 16, 16), (128, 8, 8), (256,)]
    NUM_BLOCKS = 4  # conv1, conv2, conv3, fc1 (out_head is separate)
    FC_BLOCK = 3    # index of the block that needs a flattened input

    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([
            # block 0: Conv1
            nn.Sequential(
                nn.Conv2d(3, 32, 3, padding=1),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
            ),
            # block 1: Conv2 + MaxPool
            nn.Sequential(
                nn.Conv2d(32, 64, 3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ),
            # block 2: Conv3 + MaxPool
            nn.Sequential(
                nn.Conv2d(64, 128, 3, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ),
            # block 3: FC1
            nn.Sequential(
                nn.Linear(128 * 8 * 8, 256),
                nn.ReLU(inplace=True),
            ),
        ])
        self.out_head = nn.Linear(256, 10)
        self.num_blocks = self.NUM_BLOCKS
        self.flat_dims = self.FLAT_DIMS

    def run_block(self, idx, inp):
        """Apply ``blocks[idx]`` to ``inp``, flattening first for the FC block."""
        if idx == self.FC_BLOCK and inp.dim() > 2:
            inp = inp.flatten(1)
        return self.blocks[idx](inp)

    def forward(self, x, return_hidden=False):
        """Run the full network.

        Args:
            x: input batch, (B, 3, 32, 32).
            return_hidden: when True also return the list of block outputs.

        Returns:
            logits (B, 10), or ``(logits, [h0, h1, h2, h3])`` with
            h0 (B,32,32,32), h1 (B,64,16,16), h2 (B,128,8,8), h3 (B,256).
        """
        hiddens = []
        h = x
        for idx in range(self.num_blocks):
            h = self.run_block(idx, h)
            hiddens.append(h)
        logits = self.out_head(h)  # (B, 10)
        if return_hidden:
            return logits, hiddens
        return logits

    def forward_from(self, h, layer_idx):
        """Run the network from the hidden state at ``layer_idx`` to logits.

        Args:
            h: output of block ``layer_idx``. May be given either in its
                native shape or flattened to (B, d_l); a flat conv
                activation is reshaped back using ``HIDDEN_SHAPES``.
            layer_idx: in {0, 1, 2, 3} (0 = after block 0, 3 = after block 3).

        Returns:
            logits, (B, 10).
        """
        c = h
        # Conv blocks need a spatial tensor; restore it if the caller
        # handed us the flattened view of a conv-layer activation.
        if c.dim() == 2 and 0 <= layer_idx < self.FC_BLOCK:
            c = c.reshape(c.size(0), *self.HIDDEN_SHAPES[layer_idx])
        for idx in range(layer_idx + 1, self.num_blocks):
            c = self.run_block(idx, c)
        if c.dim() > 2:
            c = c.flatten(1)
        return self.out_head(c)


def evaluate(model, loader, dev):
    """Return top-1 accuracy of ``model`` over ``loader`` (0.0 if it is empty).

    Puts the model in eval mode as a side effect.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(dev), y.to(dev)
            correct += (model(x).argmax(1) == y).sum().item()
            total += x.size(0)
    return correct / total if total else 0.0


# ---------------------------------------------------------------------------
# Helper: flatten hidden state for credit computation
# ---------------------------------------------------------------------------

def flat(h):
    """Flatten spatial dims: (B, C, H, W) -> (B, C*H*W); (B, D) is unchanged."""
    return h.flatten(1) if h.dim() > 2 else h


def _report_epoch(model, tel, dev, ep, every=20):
    """Print test accuracy on every ``every``-th epoch."""
    if ep % every == 0:
        print(f" Ep {ep}: acc={evaluate(model, tel, dev):.4f}", flush=True)


def _make_local_optimizers(model, lr, wd, epochs):
    """Create one AdamW per block plus one for the head, with cosine schedules.

    Returns:
        (block_opts, head_opt, schedulers); schedulers cover every optimizer.
    """
    block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd)
                  for block in model.blocks]
    head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
    schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs)
                  for o in block_opts + [head_opt]]
    return block_opts, head_opt, schedulers


# ---------------------------------------------------------------------------
# Training: BP
# ---------------------------------------------------------------------------

def train_bp(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01):
    """Train end-to-end with backprop (AdamW + cosine annealing).

    Args:
        model: network to train (updated in place).
        trl / tel: train and test loaders.
        dev: device the batches are moved to.
        epochs, lr, wd: number of epochs, learning rate, weight decay.

    Returns:
        The trained model (same object as ``model``).
    """
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    for ep in range(1, epochs + 1):
        model.train()
        for x, y in trl:
            x, y = x.to(dev), y.to(dev)
            loss = F.cross_entropy(model(x), y)
            loss.backward()
            opt.step()
            opt.zero_grad()
        sch.step()
        _report_epoch(model, tel, dev, ep)
    return model


# ---------------------------------------------------------------------------
# Training: DFA
# ---------------------------------------------------------------------------

def train_dfa(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01):
    """Direct Feedback Alignment for the CNN.

    For each block l, a fixed random matrix B_l: (flat_dim_l, 10) maps the
    global error signal e_T (softmax minus one-hot) back to the hidden
    space. The local surrogate loss is

        L_l = < F_l(h_{l-1}), a_l / ||a_l||_rms >,   a_l = B_l @ e_T

    so the gradient of L_l with respect to the block output is the
    normalised feedback a_l. The out_head is trained with standard
    cross-entropy on the detached final hidden state.

    Returns:
        The trained model (same object as ``model``).
    """
    L = model.num_blocks            # 4 blocks (conv1, conv2, conv3, fc1)
    C = 10
    flat_dims = model.flat_dims     # [32768, 16384, 8192, 256]

    # Random feedback matrices (fixed, not trained).
    Bs = [torch.randn(flat_dims[l], C, device=dev) / np.sqrt(C)
          for l in range(L)]

    block_opts, head_opt, schedulers = _make_local_optimizers(
        model, lr, wd, epochs)

    for ep in range(1, epochs + 1):
        model.train()
        for x, y in trl:
            x, y = x.to(dev), y.to(dev)

            # Forward pass (no grad) for hidden states and the global error.
            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                probs = logits.softmax(-1)                         # (B, 10)
                e_T = probs - F.one_hot(y, C).to(probs.dtype)      # (B, 10)

            # --- Train out_head with standard CE on detached h3 ---
            ce_loss = F.cross_entropy(model.out_head(hiddens[3].detach()), y)
            head_opt.zero_grad()
            ce_loss.backward()
            head_opt.step()

            # --- Train each block with the DFA local surrogate ---
            # Block l is re-run with grad from its (detached) input:
            #   l=0: x, l=1: h0, l=2: h1, l=3: h2 (flattened by run_block).
            inputs = [x] + [h.detach() for h in hiddens[:-1]]
            for l in range(L):
                # DFA credit signal, RMS-normalised per sample.
                a_flat = e_T @ Bs[l].T                             # (B, d_l)
                rms = (a_flat ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
                a_norm = a_flat / rms

                out_flat = flat(model.run_block(l, inputs[l]))     # (B, d_l)
                # Summed over features, averaged over the batch.
                local_loss = (out_flat * a_norm).sum(-1).mean()

                block_opts[l].zero_grad()
                local_loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    model.blocks[l].parameters(), 1.0)
                block_opts[l].step()

        for s in schedulers:
            s.step()
        _report_epoch(model, tel, dev, ep)
    return model


# ---------------------------------------------------------------------------
# Training: EP (Equilibrium Propagation adapted for CNN)
# ---------------------------------------------------------------------------

def ep_energy_cnn(model, hiddens, x):
    """CNN EP energy: E = sum_l 0.5 ||h_l - F_l(inp_l)||^2 (flattened).

    hiddens[0] = h0 (B,32,32,32) -- target for block 0 applied to x
    hiddens[1] = h1 (B,64,16,16) -- target for block 1 applied to h0
    hiddens[2] = h2 (B,128,8,8)  -- target for block 2 applied to h1
    hiddens[3] = h3 (B,256)      -- target for block 3 applied to h2.flatten

    Returns:
        Per-sample energy, shape (B,).
    """
    inputs = [x] + list(hiddens[:-1])
    E = 0.0
    for l in range(model.num_blocks):
        pred = model.run_block(l, inputs[l])
        residual = flat(hiddens[l]) - flat(pred)          # (B, d_l)
        E = E + 0.5 * (residual ** 2).sum(-1)             # (B,)
    return E


def ep_nudged_phase_cnn(model, x, y, h_free, beta, T_nudge, alpha_nudge):
    """Nudged phase: descend E(h) + beta * CE(out_head(h3), y) in h.

    All hidden states h0..h3 are free variables initialised from
    ``h_free``; ``x`` is fixed. ``T_nudge`` plain gradient steps of size
    ``alpha_nudge`` are taken on the batch-mean objective.

    Returns:
        List of detached nudged hidden states, same shapes as ``h_free``.
    """
    h_nudged = [h.clone().detach().requires_grad_(True) for h in h_free]

    for _ in range(T_nudge):
        E = ep_energy_cnn(model, h_nudged, x)                      # (B,)
        logits = model.out_head(h_nudged[3])                       # (B, 10)
        C_loss = F.cross_entropy(logits, y, reduction='none')      # (B,)
        total = (E + beta * C_loss).mean()
        # Differentiate w.r.t. the hidden states only: model parameters
        # are not updated here, so their gradients are not needed.
        grads = torch.autograd.grad(total, h_nudged)
        with torch.no_grad():
            for h, g in zip(h_nudged, grads):
                h.sub_(alpha_nudge * g)

    return [h.detach() for h in h_nudged]


def train_ep(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01,
             beta=0.5, T_nudge=20, alpha_nudge=0.05):
    """Equilibrium Propagation for the small CNN.

    Per-block weight gradient: (dE_nudged/dtheta - dE_free/dtheta) / beta.
    The out_head is trained with standard CE on the nudged h3.

    Returns:
        The trained model (same object as ``model``).
    """
    L = model.num_blocks
    block_opts, head_opt, schedulers = _make_local_optimizers(
        model, lr, wd, epochs)
    all_opts = block_opts + [head_opt]

    for ep in range(1, epochs + 1):
        model.train()
        for x, y in trl:
            x, y = x.to(dev), y.to(dev)

            # Free phase: standard forward pass.
            with torch.no_grad():
                _, h_free = model(x, return_hidden=True)

            # Nudged phase.
            h_nudged = ep_nudged_phase_cnn(
                model, x, y, h_free, beta, T_nudge, alpha_nudge)

            for o in all_opts:
                o.zero_grad()

            inputs_free = [x] + [h.detach() for h in h_free[:-1]]
            inputs_nudge = [x] + [h.detach() for h in h_nudged[:-1]]

            for l in range(L):
                f_free = flat(model.run_block(l, inputs_free[l]))
                f_nudge = flat(model.run_block(l, inputs_nudge[l]))

                # Residuals (target - block output), treated as constants.
                res_free = flat(h_free[l]).detach() - f_free.detach()
                res_nudge = flat(h_nudged[l]).detach() - f_nudge.detach()

                # dE/dtheta = -(res * dF/dtheta). With res held constant,
                # autograd on -(res * F).sum() yields exactly that.
                loss_free_l = -(res_free * f_free).sum()
                loss_nudge_l = -(res_nudge * f_nudge).sum()
                ep_loss_l = (loss_nudge_l - loss_free_l) / beta
                ep_loss_l.backward()

            # Head: CE on nudged h3.
            head_loss = F.cross_entropy(
                model.out_head(h_nudged[3].detach()), y)
            head_loss.backward()

            torch.nn.utils.clip_grad_norm_(list(model.parameters()), 1.0)
            for o in all_opts:
                o.step()

        for s in schedulers:
            s.step()
        _report_epoch(model, tel, dev, ep)
    return model


# ---------------------------------------------------------------------------
# Diagnostics
# ---------------------------------------------------------------------------

def compute_bp_grads(model, x, y):
    """Compute BP gradients of the CE loss w.r.t. each hidden state h_l.

    The forward pass is kept connected from h0 onward so the loss depends
    on every h_l; detaching between blocks would leave only h3 with a
    non-zero gradient.

    Returns:
        (grads, hiddens): ``grads[l]`` has the same shape as ``hiddens[l]``;
        both lists hold detached tensors. Puts the model in eval mode.
    """
    model.eval()
    # h0 is made a leaf so no graph through block 0's parameters is needed.
    with torch.no_grad():
        h0 = model.run_block(0, x)
    h0.requires_grad_(True)

    hiddens = [h0]
    for l in range(1, model.num_blocks):
        hiddens.append(model.run_block(l, hiddens[-1]))

    loss = F.cross_entropy(model.out_head(hiddens[-1]), y)
    grads = torch.autograd.grad(loss, hiddens)
    return [g.detach() for g in grads], [h.detach() for h in hiddens]


def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20,
                        alpha_nudge=0.05):
    """Compute credit-quality diagnostics on the first test batch.

    Gamma is ``cosine_similarity_batch(credit_l, bp_grad_l)`` and rho is
    ``perturbation_correlation(...)`` per layer; both are averaged over
    the four hidden layers.

    Returns:
        dict with keys 'Gamma', 'rho', 'gammas_per_layer', 'rhos_per_layer'.
    """
    model.eval()
    L = model.num_blocks

    # Grab one batch.
    x, y = next(iter(tel))
    x, y = x.to(dev), y.to(dev)

    bp_grads, _ = compute_bp_grads(model, x, y)
    bp_grads_flat = [flat(g) for g in bp_grads]

    with torch.no_grad():
        _, hiddens = model(x, return_hidden=True)

    # Credit signals depending on method.
    if method == 'ep':
        h_nudged = ep_nudged_phase_cnn(
            model, x, y, hiddens, beta, T_nudge, alpha_nudge)
        credits = [flat((h_nudged[l] - hiddens[l]) / beta) for l in range(L)]
    else:
        # For BP and DFA, use BP grads directly (self-cosine = 1).
        # NOTE(review): the DFA feedback matrices are not kept after
        # training, so for 'dfa' this measures the BP gradient, not the
        # DFA credit signal -- confirm this is intended.
        credits = list(bp_grads_flat)

    # Gamma: cosine similarity between credit and BP grad.
    gammas = [float(cosine_similarity_batch(credits[l], bp_grads_flat[l]))
              for l in range(L)]

    def make_forward_fn(layer_idx):
        def forward_fn(h_flat):
            """Map a flat (B, d_l) state at ``layer_idx`` to per-sample CE."""
            with torch.no_grad():
                # forward_from restores the spatial shape when needed.
                logits = model.forward_from(h_flat, layer_idx)
                return F.cross_entropy(logits, y, reduction='none')
        return forward_fn

    # rho: perturbation correlation, running the net from layer l onward.
    rhos = []
    for l in range(L):
        h_l = flat(hiddens[l].detach())     # (B, d_l)
        a_l = credits[l].detach()           # (B, d_l)
        rho = perturbation_correlation(
            h_l, a_l, make_forward_fn(l), epsilon=1e-3, M=16)
        rhos.append(float(rho))

    return {
        'Gamma': float(np.mean(gammas)),
        'rho': float(np.mean(rhos)),
        'gammas_per_layer': gammas,
        'rhos_per_layer': rhos,
    }


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Parse CLI args, train with the chosen method, and save results."""
    p = argparse.ArgumentParser(description='CNN baseline for CIFAR-10')
    p.add_argument('--method', type=str, required=True,
                   choices=['bp', 'dfa', 'ep'])
    p.add_argument('--seed', type=int, required=True)
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--output_dir', type=str, default='results/cnn_baseline')
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--wd', type=float, default=0.01)
    # EP hyperparameters
    p.add_argument('--beta', type=float, default=0.5)
    p.add_argument('--T_nudge', type=int, default=20)
    p.add_argument('--alpha_nudge', type=float, default=0.05)
    args = p.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    if torch.cuda.is_available():
        dev = torch.device(f'cuda:{args.gpu}')
    else:
        # Keep the script runnable on CPU-only machines.
        print("CUDA not available; falling back to CPU.", flush=True)
        dev = torch.device('cpu')

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    trl, tel = get_cifar10()
    model = SmallCNN().to(dev)

    print(f"[{args.method} s={args.seed}] Training CNN on CIFAR-10 for {args.epochs} epochs...",
          flush=True)

    if args.method == 'bp':
        model = train_bp(model, trl, tel, dev,
                         epochs=args.epochs, lr=args.lr, wd=args.wd)
    elif args.method == 'dfa':
        model = train_dfa(model, trl, tel, dev,
                          epochs=args.epochs, lr=args.lr, wd=args.wd)
    elif args.method == 'ep':
        model = train_ep(model, trl, tel, dev,
                         epochs=args.epochs, lr=args.lr, wd=args.wd,
                         beta=args.beta, T_nudge=args.T_nudge,
                         alpha_nudge=args.alpha_nudge)

    acc = evaluate(model, tel, dev)
    diag = compute_diagnostics(model, tel, dev, args.method,
                               beta=args.beta, T_nudge=args.T_nudge,
                               alpha_nudge=args.alpha_nudge)

    # Save checkpoint
    ckpt_path = os.path.join(args.output_dir,
                             f'{args.method}_s{args.seed}.pt')
    torch.save(model.state_dict(), ckpt_path)

    result = {
        'method': args.method,
        'seed': args.seed,
        'acc': float(acc),
        'Gamma': diag['Gamma'],
        'rho': diag['rho'],
        'gammas_per_layer': diag['gammas_per_layer'],
        'rhos_per_layer': diag['rhos_per_layer'],
        'epochs': args.epochs,
        'lr': args.lr,
        'wd': args.wd,
        'beta': args.beta,
        'T_nudge': args.T_nudge,
        'alpha_nudge': args.alpha_nudge,
    }
    json_path = os.path.join(args.output_dir,
                             f'{args.method}_s{args.seed}.json')
    with open(json_path, 'w') as f:
        json.dump(result, f, indent=2, default=float)

    print(
        f"[{args.method} s={args.seed}] acc={acc:.4f} "
        f"Gamma={diag['Gamma']:.4f} rho={diag['rho']:.4f}",
        flush=True,
    )
    print(f" gammas_per_layer={[f'{g:.4f}' for g in diag['gammas_per_layer']]}", flush=True)
    print(f" rhos_per_layer ={[f'{r:.4f}' for r in diag['rhos_per_layer']]}", flush=True)
    print(f" Saved: {json_path}", flush=True)


if __name__ == '__main__':
    main()