diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-02 11:28:13 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-02 11:28:13 -0500 |
| commit | ef80d52840a1c6fb7f9a22985784ce311edc59a4 (patch) | |
| tree | 00e674c9bf71f31cecd330a115d75cb8a417cea8 /experiments | |
| parent | 61204b6010e403b4c61b093f2a208a881b20fa11 (diff) | |
Add CNN baseline: SmallCNN with BP/DFA/EP on CIFAR-10
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/cnn_baseline.py | 600 |
1 files changed, 600 insertions, 0 deletions
# File: experiments/cnn_baseline.py (new file added by this commit).
"""
CNN baseline for CIFAR-10: BP / DFA / EP on a small ConvNet.
One method+seed per invocation for clean process isolation.

Architecture:
    Conv2d(3,32,3,padding=1)   -> BN -> ReLU
    Conv2d(32,64,3,padding=1)  -> BN -> ReLU -> MaxPool(2)   [32->16]
    Conv2d(64,128,3,padding=1) -> BN -> ReLU -> MaxPool(2)   [16->8]
    flatten -> FC(128*8*8=8192, 256) -> ReLU -> FC(256, 10)

Blocks (for local update):
    block 0 : Conv1 (Conv2d 3->32)
    block 1 : Conv2 (Conv2d 32->64) + MaxPool
    block 2 : Conv3 (Conv2d 64->128) + MaxPool
    block 3 : FC1   (Linear 8192->256)
    block 4 : FC2   (Linear 256->10) -- output head, always trained with loss

Hidden states (post-activation, for credit):
    h0 : (B, 32, 32, 32)  after Conv1+BN+ReLU
    h1 : (B, 64, 16, 16)  after Conv2+BN+ReLU+MaxPool
    h2 : (B, 128, 8, 8)   after Conv3+BN+ReLU+MaxPool
    h3 : (B, 256)         after flatten+FC1+ReLU

DFA: flatten each h_l to (B, d_l), random feedback B_l: (d_l, 10)
EP:  energy E = sum_l 0.5 ||h_{l+1} - F_l(h_l)||^2 adapted for CNN

Usage:  python cnn_baseline.py --method bp --seed 42 --gpu 0
Output: results/cnn_baseline/{method}_s{seed}.json + .pt checkpoint
"""

import argparse
import json
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Make the repository root importable so `metrics` resolves when this file is
# run as a script from inside experiments/.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation


# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------

# Per-channel normalisation constants applied to both train and test images.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)


def get_cifar10(bs=128, root='./data', num_workers=4):
    """Build the CIFAR-10 train and test loaders.

    Args:
        bs: batch size used for both loaders.
        root: directory the dataset is downloaded to / read from.
        num_workers: DataLoader worker processes per loader.

    Returns:
        (train_loader, test_loader). The train loader shuffles and applies
        random-crop + horizontal-flip augmentation; the test loader does
        neither.
    """
    normalize = transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD)
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    train_set = torchvision.datasets.CIFAR10(
        root, train=True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10(
        root, train=False, download=True, transform=test_tf)
    train_loader = DataLoader(
        train_set, batch_size=bs, shuffle=True,
        num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(
        test_set, batch_size=bs, shuffle=False,
        num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------

class SmallCNN(nn.Module):
    """
    A small 3-conv CNN for CIFAR-10.

    Blocks (nn.ModuleList, mirrors the 5-block treatment):
        blocks[0] : Conv1 layer (Conv2d 3->32, BN, ReLU)
        blocks[1] : Conv2 layer (Conv2d 32->64, BN, ReLU, MaxPool)
        blocks[2] : Conv3 layer (Conv2d 64->128, BN, ReLU, MaxPool)
        blocks[3] : FC1 layer   (Linear 8192->256, ReLU)
        out_head  : FC2 layer   (Linear 256->10)

    forward(x, return_hidden=False):
        returns logits, or (logits, [h0, h1, h2, h3]) when return_hidden=True.
        h_l are post-activation tensors; h3 is (B,256) flat.
    """
    # flat dim of each hidden state
    FLAT_DIMS = [32 * 32 * 32, 64 * 16 * 16, 128 * 8 * 8, 256]
    NUM_BLOCKS = 4  # conv1, conv2, conv3, fc1 (out_head is separate)
    FC_BLOCK = 3    # index of the first block that expects a flat (B, D) input

    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([
            # block 0: Conv1
            nn.Sequential(
                nn.Conv2d(3, 32, 3, padding=1),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
            ),
            # block 1: Conv2 + MaxPool
            nn.Sequential(
                nn.Conv2d(32, 64, 3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ),
            # block 2: Conv3 + MaxPool
            nn.Sequential(
                nn.Conv2d(64, 128, 3, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ),
            # block 3: FC1
            nn.Sequential(
                nn.Linear(128 * 8 * 8, 256),
                nn.ReLU(inplace=True),
            ),
        ])
        self.out_head = nn.Linear(256, 10)
        self.num_blocks = self.NUM_BLOCKS
        self.flat_dims = self.FLAT_DIMS

    def run_block(self, idx, inp):
        """Apply blocks[idx] to `inp`, flattening first when idx is the FC block.

        This is the single place that encodes "the FC block takes a flat
        input", so training and diagnostic code need not repeat the special
        case.
        """
        if idx == self.FC_BLOCK and inp.dim() > 2:
            inp = inp.flatten(1)
        return self.blocks[idx](inp)

    def forward(self, x, return_hidden=False):
        """
        x: (B, 3, 32, 32)
        Returns logits (B,10), optionally with list of 4 hidden states.
        h0: (B,32,32,32)  h1: (B,64,16,16)  h2: (B,128,8,8)  h3: (B,256)
        """
        h0 = self.blocks[0](x)              # (B, 32, 32, 32)
        h1 = self.blocks[1](h0)             # (B, 64, 16, 16)
        h2 = self.blocks[2](h1)             # (B, 128, 8, 8)
        h3 = self.blocks[3](h2.flatten(1))  # (B, 256)
        logits = self.out_head(h3)          # (B, 10)
        if return_hidden:
            return logits, [h0, h1, h2, h3]
        return logits

    def forward_from(self, h, layer_idx):
        """
        Run the network from hidden state h at layer `layer_idx` to logits.
        layer_idx in {0, 1, 2, 3} (0=after block0, 3=after block3).
        h should be the post-activation tensor at that layer.
        """
        c = h
        for i in range(layer_idx + 1, self.num_blocks):
            c = self.run_block(i, c)
        # The head always takes a flat input; only relevant if the caller
        # passes a non-flat tensor with layer_idx == 3.
        if c.dim() > 2:
            c = c.flatten(1)
        return self.out_head(c)


def evaluate(model, loader, dev):
    """Return top-1 accuracy of `model` over `loader` (puts the model in eval mode)."""
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(dev), y.to(dev)
            predictions = model(x).argmax(1)
            n_correct += (predictions == y).sum().item()
            n_seen += x.size(0)
    return n_correct / n_seen


# ---------------------------------------------------------------------------
# Helper: flatten hidden state for credit computation
# ---------------------------------------------------------------------------

def flat(h):
    """Flatten spatial dims: (B, C, H, W) -> (B, C*H*W) or (B, D) -> (B, D)."""
    return h.flatten(1) if h.dim() > 2 else h


# ---------------------------------------------------------------------------
# Training: BP
# ---------------------------------------------------------------------------

def train_bp(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01):
    """Train end-to-end with backprop (AdamW + cosine LR schedule).

    Test accuracy is printed every 20 epochs. Returns the trained model.
    """
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    for ep in range(1, epochs + 1):
        model.train()  # evaluate() below switches to eval mode
        for x, y in trl:
            x, y = x.to(dev), y.to(dev)
            opt.zero_grad()
            loss = F.cross_entropy(model(x), y)
            loss.backward()
            opt.step()
        sch.step()
        if ep % 20 == 0:
            print(f" Ep {ep}: acc={evaluate(model, tel, dev):.4f}", flush=True)
    return model


# ---------------------------------------------------------------------------
# Training: DFA
# ---------------------------------------------------------------------------

def train_dfa(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01):
    """
    Direct Feedback Alignment for CNN.

    For each block l, a random matrix B_l: (flat_dim_l, 10) maps the global
    error signal e_T (softmax-CE gradient at output) back to the hidden space.
    The local surrogate loss is:
        L_l = < F_l(h_{l-1}), a_l / ||a_l||_rms >
    where a_l = B_l @ e_T (flattened credit).
    The out_head is trained with standard cross-entropy on the final hidden state.
    """
    num_blocks = model.num_blocks  # 4 blocks (conv1, conv2, conv3, fc1)
    num_classes = 10
    flat_dims = model.flat_dims    # [32768, 16384, 8192, 256]

    # Random feedback matrices (fixed, not trained)
    feedback = [
        torch.randn(flat_dims[l], num_classes, device=dev) / np.sqrt(num_classes)
        for l in range(num_blocks)
    ]

    # Per-block optimizers + head optimizer
    block_opts = [
        optim.AdamW(model.blocks[l].parameters(), lr=lr, weight_decay=wd)
        for l in range(num_blocks)
    ]
    head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
    all_opts = block_opts + [head_opt]
    schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in all_opts]

    for ep in range(1, epochs + 1):
        model.train()
        for x, y in trl:
            x, y = x.to(dev), y.to(dev)
            batch_size = x.size(0)

            # Forward pass (no grad) to get hidden states and global error
            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                e_T = logits.softmax(-1)                      # (B, 10)
                e_T[torch.arange(batch_size), y] -= 1.0       # softmax - onehot

            # --- Train out_head with standard CE on h3 (already graph-free) ---
            ce_loss = F.cross_entropy(model.out_head(hiddens[3]), y)
            head_opt.zero_grad()
            ce_loss.backward()
            head_opt.step()

            # --- Train each block with DFA local surrogate ---
            # Block l is re-run with grad on its own input only:
            #   l=0: x,  l=1: h0,  l=2: h1,  l=3: h2 (flattened by run_block)
            # The inputs come from the no-grad pass, so no gradient can leak
            # into earlier blocks.
            inputs = [x, hiddens[0], hiddens[1], hiddens[2]]

            for l in range(num_blocks):
                # DFA credit signal, RMS-normalised per sample
                a_l_flat = e_T @ feedback[l].T                          # (B, flat_dim_l)
                rms = (a_l_flat ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
                a_l_norm = a_l_flat / rms

                out_l = model.run_block(l, inputs[l])

                # Local surrogate: <F_l(inp), a_l_norm>, summed over features,
                # averaged over the batch.
                local_loss = (flat(out_l) * a_l_norm).sum(-1).mean()

                block_opts[l].zero_grad()
                local_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
                block_opts[l].step()

        for s in schedulers:
            s.step()
        if ep % 20 == 0:
            print(f" Ep {ep}: acc={evaluate(model, tel, dev):.4f}", flush=True)

    return model


# ---------------------------------------------------------------------------
# Training: EP (Equilibrium Propagation adapted for CNN)
# ---------------------------------------------------------------------------

def ep_energy_cnn(model, hiddens, x):
    """
    CNN EP energy: E = sum_l 0.5 ||h_l - F_l(inp_l)||^2 (flattened).

    hiddens[0] = h0 (B,32,32,32) -- target for block 0 applied to x
    hiddens[1] = h1 (B,64,16,16) -- target for block 1 applied to h0
    hiddens[2] = h2 (B,128,8,8)  -- target for block 2 applied to h1
    hiddens[3] = h3 (B,256)      -- target for block 3 applied to h2.flatten

    Returns a per-sample energy tensor of shape (B,).
    """
    inputs = [x, hiddens[0], hiddens[1], hiddens[2]]
    energy = 0.0
    for l in range(model.num_blocks):
        pred = model.run_block(l, inputs[l])
        residual = flat(hiddens[l]) - flat(pred)            # (B, d_l)
        energy = energy + 0.5 * (residual ** 2).sum(-1)     # (B,)
    return energy


def ep_nudged_phase_cnn(model, x, y, h_free, beta, T_nudge, alpha_nudge):
    """
    Nudged phase: minimize E(h) + beta * CE(out_head(h3), y)
    w.r.t. h0, h1, h2, h3 (all free hidden states).
    x is fixed (pixel input, not a hidden state).

    Runs `T_nudge` plain gradient-descent steps of size `alpha_nudge`, starting
    from the free-phase states, and returns the detached nudged states.

    Gradients are taken with respect to the hidden states only, so the model
    parameters' `.grad` buffers are left untouched.

    NOTE(review): the objective is batch-averaged, so each sample's gradient
    is scaled by 1/B; the effective per-sample step is alpha_nudge / B.
    Confirm this is intended.
    """
    # Initialise from free phase
    h_nudged = [h.clone().detach().requires_grad_(True) for h in h_free]

    for _ in range(T_nudge):
        energy = ep_energy_cnn(model, h_nudged, x)               # (B,)
        logits = model.out_head(h_nudged[-1])                    # (B, 10)
        cost = F.cross_entropy(logits, y, reduction='none')      # (B,)
        total = (energy + beta * cost).mean()
        grads = torch.autograd.grad(total, h_nudged)
        with torch.no_grad():
            for h, g in zip(h_nudged, grads):
                h.add_(g, alpha=-alpha_nudge)   # same arithmetic as vanilla SGD

    return [h.detach() for h in h_nudged]


def train_ep(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01,
             beta=0.5, T_nudge=20, alpha_nudge=0.05):
    """
    Equilibrium Propagation for the small CNN.

    Per-block surrogate gradient handed to the optimizer:
        g_l = (dE_nudged/dtheta_l - dE_free/dtheta_l) / beta
    i.e. the optimizer descends along g_l.
    For the out_head: standard CE on nudged output (no dE/dtheta_head term).
    """
    num_blocks = model.num_blocks

    block_opts = [
        optim.AdamW(model.blocks[l].parameters(), lr=lr, weight_decay=wd)
        for l in range(num_blocks)
    ]
    head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
    all_opts = block_opts + [head_opt]
    schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in all_opts]

    for ep in range(1, epochs + 1):
        model.train()
        for x, y in trl:
            x, y = x.to(dev), y.to(dev)

            # Free phase: standard forward pass
            with torch.no_grad():
                _, h_free = model(x, return_hidden=True)

            # Nudged phase
            h_nudged = ep_nudged_phase_cnn(model, x, y, h_free, beta, T_nudge, alpha_nudge)

            # Zero all grads
            for o in all_opts:
                o.zero_grad()

            # EP weight update per block:
            #   dE/dtheta_l = -residual_l * dF_l/dtheta_l  (same as MLP EP)
            # Both state lists are already detached from any graph.
            inputs_free = [x, h_free[0], h_free[1], h_free[2]]
            inputs_nudge = [x, h_nudged[0], h_nudged[1], h_nudged[2]]

            for l in range(num_blocks):
                f_free = flat(model.run_block(l, inputs_free[l]))
                f_nudge = flat(model.run_block(l, inputs_nudge[l]))

                # residuals (target state - block output), treated as constants
                res_free = flat(h_free[l]) - f_free.detach()       # (B, d_l)
                res_nudge = flat(h_nudged[l]) - f_nudge.detach()

                # Autograd trick: d/dtheta of -(res * F).sum() with res held
                # constant equals dE/dtheta for that phase.
                loss_free_l = -(res_free * f_free).sum()
                loss_nudge_l = -(res_nudge * f_nudge).sum()

                ep_loss_l = (loss_nudge_l - loss_free_l) / beta
                ep_loss_l.backward()

            # Head: CE on nudged h3
            head_loss = F.cross_entropy(model.out_head(h_nudged[3]), y)
            head_loss.backward()

            torch.nn.utils.clip_grad_norm_(list(model.parameters()), 1.0)
            for o in all_opts:
                o.step()

        for s in schedulers:
            s.step()
        if ep % 20 == 0:
            print(f" Ep {ep}: acc={evaluate(model, tel, dev):.4f}", flush=True)

    return model


# ---------------------------------------------------------------------------
# Diagnostics
# ---------------------------------------------------------------------------

def compute_bp_grads(model, x, y):
    """
    Compute BP gradients of the mean CE loss w.r.t. each hidden state h_l.

    The forward pass is kept as ONE connected autograd graph, so h0..h2
    receive their true gradients through the downstream blocks. Detaching
    between blocks would leave the loss dependent on h3 only and make every
    earlier gradient zero.

    Returns:
        (grads, hiddens): two lists of 4 detached tensors; grads[l] has the
        same shape as hiddens[l].
    """
    model.eval()
    with torch.enable_grad():
        # Requiring grad on the input guarantees every hidden state is part
        # of the graph even if the model parameters are frozen.
        inp = x.detach().requires_grad_(True)
        logits, hiddens = model(inp, return_hidden=True)
        loss = F.cross_entropy(logits, y)
        grads = torch.autograd.grad(loss, hiddens)
    return [g.detach() for g in grads], [h.detach() for h in hiddens]


def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20, alpha_nudge=0.05):
    """Compute credit-quality diagnostics on the first batch of `tel`.

    Gamma is the per-layer result of `cosine_similarity_batch(credit, bp_grad)`
    and rho the per-layer result of `perturbation_correlation`; both helpers
    come from metrics.credit_metrics.

    Returns:
        dict with keys 'Gamma', 'rho' (means over layers) and
        'gammas_per_layer', 'rhos_per_layer' (lists of floats).
    """
    model.eval()
    num_blocks = model.num_blocks

    # Grab one batch
    x, y = next(iter(tel))
    x, y = x.to(dev), y.to(dev)

    # BP gradients
    bp_grads, _ = compute_bp_grads(model, x, y)
    bp_grads_flat = [flat(g) for g in bp_grads]

    with torch.no_grad():
        _, hiddens = model(x, return_hidden=True)

    # Credit signals depending on method
    if method == 'ep':
        h_nudged = ep_nudged_phase_cnn(model, x, y, hiddens, beta, T_nudge, alpha_nudge)
        credits = [flat((h_nudged[l] - hiddens[l]) / beta) for l in range(num_blocks)]
    else:
        # For BP and DFA, use BP grads directly (BP self-cosine = 1 by definition).
        # NOTE(review): the DFA feedback matrices are not kept after training,
        # so DFA runs report BP credit here rather than the DFA signal.
        credits = bp_grads_flat

    # Gamma: cosine similarity between credit and BP grad
    gammas = [
        float(cosine_similarity_batch(credits[l], bp_grads_flat[l]))
        for l in range(num_blocks)
    ]

    # rho: perturbation correlation using forward_from
    def make_forward_fn(layer_idx, feature_shape):
        """Build: flat h at `layer_idx` -> per-sample CE loss via the rest of the net."""
        def forward_fn(h_flat):
            with torch.no_grad():
                # Restore the (C, H, W) layout the downstream conv blocks expect.
                h = h_flat.reshape(h_flat.size(0), *feature_shape)
                logits = model.forward_from(h, layer_idx)
                return F.cross_entropy(logits, y, reduction='none')
        return forward_fn

    rhos = []
    for l in range(num_blocks):
        h_l = flat(hiddens[l].detach())     # (B, d_l)
        a_l = credits[l].detach()           # (B, d_l)
        forward_fn = make_forward_fn(l, tuple(hiddens[l].shape[1:]))
        rho = perturbation_correlation(h_l, a_l, forward_fn, epsilon=1e-3, M=16)
        rhos.append(float(rho))

    return {
        'Gamma': float(np.mean(gammas)),
        'rho': float(np.mean(rhos)),
        'gammas_per_layer': gammas,
        'rhos_per_layer': rhos,
    }


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """CLI entry point: train one method/seed, run diagnostics, save results."""
    p = argparse.ArgumentParser(description='CNN baseline for CIFAR-10')
    p.add_argument('--method', type=str, required=True, choices=['bp', 'dfa', 'ep'])
    p.add_argument('--seed', type=int, required=True)
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--output_dir', type=str, default='results/cnn_baseline')
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--wd', type=float, default=0.01)
    # EP hyperparameters
    p.add_argument('--beta', type=float, default=0.5)
    p.add_argument('--T_nudge', type=int, default=20)
    p.add_argument('--alpha_nudge', type=float, default=0.05)
    args = p.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    # Fall back to CPU so the script still runs on hosts without CUDA.
    dev = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    trl, tel = get_cifar10()
    model = SmallCNN().to(dev)

    print(f"[{args.method} s={args.seed}] Training CNN on CIFAR-10 for {args.epochs} epochs...", flush=True)

    if args.method == 'bp':
        model = train_bp(model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd)
    elif args.method == 'dfa':
        model = train_dfa(model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd)
    elif args.method == 'ep':
        model = train_ep(model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd,
                         beta=args.beta, T_nudge=args.T_nudge, alpha_nudge=args.alpha_nudge)

    acc = evaluate(model, tel, dev)
    diag = compute_diagnostics(model, tel, dev, args.method,
                               beta=args.beta, T_nudge=args.T_nudge, alpha_nudge=args.alpha_nudge)

    # Save checkpoint
    ckpt_path = os.path.join(args.output_dir, f'{args.method}_s{args.seed}.pt')
    torch.save(model.state_dict(), ckpt_path)

    result = {
        'method': args.method,
        'seed': args.seed,
        'acc': float(acc),
        'Gamma': diag['Gamma'],
        'rho': diag['rho'],
        'gammas_per_layer': diag['gammas_per_layer'],
        'rhos_per_layer': diag['rhos_per_layer'],
        'epochs': args.epochs,
        'lr': args.lr,
        'wd': args.wd,
        'beta': args.beta,
        'T_nudge': args.T_nudge,
        'alpha_nudge': args.alpha_nudge,
    }

    json_path = os.path.join(args.output_dir, f'{args.method}_s{args.seed}.json')
    with open(json_path, 'w') as f:
        json.dump(result, f, indent=2, default=float)

    print(
        f"[{args.method} s={args.seed}] acc={acc:.4f} "
        f"Gamma={diag['Gamma']:.4f} rho={diag['rho']:.4f}",
        flush=True,
    )
    print(f" gammas_per_layer={[f'{g:.4f}' for g in diag['gammas_per_layer']]}", flush=True)
    print(f" rhos_per_layer ={[f'{r:.4f}' for r in diag['rhos_per_layer']]}", flush=True)
    print(f" Saved: {json_path}", flush=True)


if __name__ == '__main__':
    main()
