"""
CNN baseline for CIFAR-10: BP / DFA / EP / State Bridge / Credit Bridge on a
small ConvNet. One method+seed per invocation for clean process isolation.

Architecture (each conv layer is followed by BatchNorm2d):
    Conv2d(3,32,3,padding=1)   -> BN -> ReLU
    Conv2d(32,64,3,padding=1)  -> BN -> ReLU -> MaxPool(2)   [32->16]
    Conv2d(64,128,3,padding=1) -> BN -> ReLU -> MaxPool(2)   [16->8]
    flatten -> FC(128*8*8=8192, 256) -> ReLU -> FC(256, 10)

Blocks (for local update):
    block 0  : Conv1 (Conv2d 3->32)
    block 1  : Conv2 (Conv2d 32->64) + MaxPool
    block 2  : Conv3 (Conv2d 64->128) + MaxPool
    block 3  : FC1   (Linear 8192->256)
    out_head : FC2   (Linear 256->10) -- output head, always trained with CE loss

Hidden states (post-activation, for credit):
    h0 : (B, 32, 32, 32)  after Conv1+BN+ReLU
    h1 : (B, 64, 16, 16)  after Conv2+BN+ReLU+MaxPool
    h2 : (B, 128, 8, 8)   after Conv3+BN+ReLU+MaxPool
    h3 : (B, 256)         after flatten+FC1+ReLU

DFA: flatten each h_l to (B, d_l), random feedback B_l: (d_l, 10)
EP:  energy E = sum_l 0.5 ||h_l - F_l(h_{l-1})||^2 (with h_{-1} = x), adapted for CNN

Usage:
    python cnn_baseline.py --method bp --seed 42 --gpu 0

Output:
    results/cnn_baseline/{method}_s{seed}.json + .pt checkpoint
"""
import os, sys, json, argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation
from models.state_bridge import StateBridgeNet
from models.value_net import ValueNet, create_ema_model, update_ema

import torchvision, torchvision.transforms as transforms


# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------

# Per-channel CIFAR-10 normalisation statistics, shared by train and test transforms.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)


def get_cifar10(bs=128):
    """
    Build the CIFAR-10 train/test loaders (downloaded to ./data if missing).

    Args:
        bs: batch size for both loaders.

    Returns:
        (train_loader, test_loader). The train loader shuffles and applies
        RandomCrop(32, padding=4) + RandomHorizontalFlip; the test loader is
        unshuffled and only normalises.
    """
    normalize = transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)
    tt = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    tv = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    trl = DataLoader(
        torchvision.datasets.CIFAR10('./data', True, download=True, transform=tt),
        batch_size=bs, shuffle=True, num_workers=4, pin_memory=True)
    tel = DataLoader(
        torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv),
        batch_size=bs, shuffle=False, num_workers=4, pin_memory=True)
    return trl, tel


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------

class SmallCNN(nn.Module):
    """
    A small 3-conv CNN for CIFAR-10.

    Blocks (nn.ModuleList, mirrors the 5-block treatment):
        blocks[0] : Conv1 layer (Conv2d 3->32, BN, ReLU)
        blocks[1] : Conv2 layer (Conv2d 32->64, BN, ReLU, MaxPool)
        blocks[2] : Conv3 layer (Conv2d 64->128, BN, ReLU, MaxPool)
        blocks[3] : FC1 layer   (Linear 8192->256, ReLU)
        out_head  : FC2 layer   (Linear 256->10)

    forward(x, return_hidden=False):
        returns logits, or (logits, [h0, h1, h2, h3]) when return_hidden=True.
        h_l are post-activation tensors; h3 is (B, 256) flat.
    """

    # flat dim of each hidden state
    FLAT_DIMS = [32 * 32 * 32, 64 * 16 * 16, 128 * 8 * 8, 256]
    NUM_BLOCKS = 4  # conv1, conv2, conv3, fc1 (out_head is separate)

    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([
            # block 0: Conv1
            nn.Sequential(
                nn.Conv2d(3, 32, 3, padding=1),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
            ),
            # block 1: Conv2 + MaxPool
            nn.Sequential(
                nn.Conv2d(32, 64, 3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ),
            # block 2: Conv3 + MaxPool
            nn.Sequential(
                nn.Conv2d(64, 128, 3, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ),
            # block 3: FC1
            nn.Sequential(
                nn.Linear(128 * 8 * 8, 256),
                nn.ReLU(inplace=True),
            ),
        ])
        self.out_head = nn.Linear(256, 10)
        self.num_blocks = self.NUM_BLOCKS
        self.flat_dims = self.FLAT_DIMS

    def forward(self, x, return_hidden=False):
        """
        x: (B, 3, 32, 32)
        Returns logits (B, 10), optionally with the list of 4 hidden states:
            h0: (B,32,32,32)  h1: (B,64,16,16)  h2: (B,128,8,8)  h3: (B,256)
        """
        h0 = self.blocks[0](x)              # (B, 32, 32, 32)
        h1 = self.blocks[1](h0)             # (B, 64, 16, 16)
        h2 = self.blocks[2](h1)             # (B, 128, 8, 8)
        h3 = self.blocks[3](h2.flatten(1))  # (B, 256)
        logits = self.out_head(h3)          # (B, 10)
        if return_hidden:
            return logits, [h0, h1, h2, h3]
        return logits

    def forward_from(self, h, layer_idx):
        """
        Run the network from hidden state ``h`` at layer ``layer_idx`` to logits.

        layer_idx in {0, 1, 2, 3} (0 = after block 0, 3 = after block 3).
        ``h`` must be the post-activation tensor at that layer in its native
        shape (spatial for layers 0-2, flat for layer 3).
        """
        c = h
        for i in range(layer_idx + 1, self.num_blocks):
            if i == 3 and c.dim() > 2:
                c = c.flatten(1)  # FC1 consumes the flattened conv feature map
            c = self.blocks[i](c)
        # Flatten only after the loop: the conv blocks need spatial input.
        return self.out_head(c if c.dim() == 2 else c.flatten(1))


def evaluate(model, loader, dev):
    """Return top-1 accuracy of ``model`` over ``loader`` (leaves model in eval mode)."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(dev), y.to(dev)
            correct += (model(x).argmax(1) == y).sum().item()
            total += x.size(0)
    return correct / total


# ---------------------------------------------------------------------------
# Helper: flatten hidden state for credit computation
# ---------------------------------------------------------------------------

def flat(h):
    """Flatten spatial dims: (B, C, H, W) -> (B, C*H*W) or (B, D) -> (B, D)."""
    return h.flatten(1) if h.dim() > 2 else h


def _block_forward(model, l, inp):
    """Apply model.blocks[l] to ``inp``, flattening spatial input for the FC block (l == 3)."""
    if l == 3 and inp.dim() > 2:
        inp = inp.flatten(1)
    return model.blocks[l](inp)


# ---------------------------------------------------------------------------
# Training: BP
# ---------------------------------------------------------------------------

def train_bp(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01):
    """End-to-end backprop baseline: AdamW + cosine LR schedule on cross-entropy."""
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    for ep in range(1, epochs + 1):
        model.train()
        for x, y in trl:
            x, y = x.to(dev), y.to(dev)
            F.cross_entropy(model(x), y).backward()
            opt.step()
            opt.zero_grad()
        sch.step()
        if ep % 20 == 0:
            print(f" Ep {ep}: acc={evaluate(model, tel, dev):.4f}", flush=True)
    return model


# ---------------------------------------------------------------------------
# Training: DFA
# ---------------------------------------------------------------------------

def train_dfa(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01):
    """
    Direct Feedback Alignment for CNN.

    For each block l, a fixed random matrix B_l: (flat_dim_l, 10) maps the
    global error signal e_T (softmax-CE gradient at the output) back to the
    hidden space. The local surrogate loss is

        L_l = <F_l(h_{l-1}), a_l / ||a_l||_rms>,   a_l = e_T @ B_l^T

    with a_l in flattened form. The out_head is trained with standard
    cross-entropy on the (detached) final hidden state.

    The feedback matrices are also stored on ``model.dfa_feedback`` (a plain
    attribute, not part of ``state_dict``) so that diagnostics can measure
    the credit DFA actually used.
    """
    L = model.num_blocks            # 4 blocks (conv1, conv2, conv3, fc1)
    C = 10
    flat_dims = model.flat_dims     # [32768, 16384, 8192, 256]

    # Random feedback matrices (fixed, not trained)
    Bs = [torch.randn(flat_dims[l], C, device=dev) / np.sqrt(C) for l in range(L)]
    model.dfa_feedback = Bs

    # Per-block optimizers + head optimizer
    block_opts = [optim.AdamW(model.blocks[l].parameters(), lr=lr, weight_decay=wd)
                  for l in range(L)]
    head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
    all_opts = block_opts + [head_opt]
    schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in all_opts]

    for ep in range(1, epochs + 1):
        model.train()
        for x, y in trl:
            x, y = x.to(dev), y.to(dev)
            B = x.size(0)

            # Forward pass (no grad) to get hidden states and the global error
            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                probs = logits.softmax(-1)              # (B, 10)
                e_T = probs.clone()
                e_T[torch.arange(B), y] -= 1.0          # (B, 10)

            # --- Train out_head with standard CE on detached h3 ---
            h3_det = hiddens[3].detach()
            ce_loss = F.cross_entropy(model.out_head(h3_det), y)
            head_opt.zero_grad()
            ce_loss.backward()
            head_opt.step()

            # --- Train each block with the DFA local surrogate ---
            # Input to block l: x for l=0, else the detached previous hidden state.
            inputs = [x, hiddens[0].detach(), hiddens[1].detach(), hiddens[2].detach()]
            for l in range(L):
                a_l_flat = (e_T @ Bs[l].T).detach()                     # (B, flat_dim_l)
                rms = (a_l_flat ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
                a_l_norm = a_l_flat / rms

                out_l = _block_forward(model, l, inputs[l].detach())
                out_flat = flat(out_l)                                  # (B, flat_dim_l)
                # Summed over features, averaged over batch.
                local_loss = (out_flat * a_l_norm).sum(-1).mean()

                block_opts[l].zero_grad()
                local_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
                block_opts[l].step()

        for sch in schedulers:
            sch.step()
        if ep % 20 == 0:
            print(f" Ep {ep}: acc={evaluate(model, tel, dev):.4f}", flush=True)
    return model


# ---------------------------------------------------------------------------
# Training: EP (Equilibrium Propagation adapted for CNN)
# ---------------------------------------------------------------------------

def ep_energy_cnn(model, hiddens, x):
    """
    CNN EP energy: E = sum_l 0.5 ||h_l - F_l(inp_l)||^2 (flattened), per sample.

    hiddens[0] = h0 (B,32,32,32) -- target for block 0 applied to x
    hiddens[1] = h1 (B,64,16,16) -- target for block 1 applied to h0
    hiddens[2] = h2 (B,128,8,8)  -- target for block 2 applied to h1
    hiddens[3] = h3 (B,256)      -- target for block 3 applied to h2.flatten

    Returns a (B,) tensor.
    """
    inputs = [x, hiddens[0], hiddens[1], hiddens[2]]
    E = 0.0
    for l in range(model.num_blocks):
        pred = _block_forward(model, l, inputs[l])
        residual = flat(hiddens[l]) - flat(pred)        # (B, d_l)
        E = E + 0.5 * (residual ** 2).sum(-1)           # (B,)
    return E


def ep_nudged_phase_cnn(model, x, y, h_free, beta, T_nudge, alpha_nudge):
    """
    Nudged phase: minimise E(h) + beta * CE(out_head(h3), y) w.r.t. h0..h3.

    Plain gradient descent (step ``alpha_nudge``, ``T_nudge`` steps) on the
    batch-mean objective, starting from the free-phase states. x is fixed.
    Gradients are taken w.r.t. the hidden states only, so no ``.grad`` is
    accumulated on the model parameters.

    Returns the list of detached nudged states.
    """
    h_nudged = [h.clone().detach().requires_grad_(True) for h in h_free]
    for _ in range(T_nudge):
        E = ep_energy_cnn(model, h_nudged, x)                        # (B,)
        logits = model.out_head(h_nudged[3])                         # (B, 10)
        C_loss = F.cross_entropy(logits, y, reduction='none')        # (B,)
        total = (E + beta * C_loss).mean()
        grads = torch.autograd.grad(total, h_nudged)
        with torch.no_grad():
            for h, g in zip(h_nudged, grads):
                h.add_(g, alpha=-alpha_nudge)                        # SGD step
    return [h.detach() for h in h_nudged]


def train_ep(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01,
             beta=0.5, T_nudge=20, alpha_nudge=0.05):
    """
    Equilibrium Propagation for the small CNN.

    Weight update rule: dtheta ∝ (dE_nudged/dtheta - dE_free/dtheta) / beta
    For the out_head: standard CE on the nudged h3 (no dE/dtheta_head term).
    """
    L = model.num_blocks
    block_opts = [optim.AdamW(model.blocks[l].parameters(), lr=lr, weight_decay=wd)
                  for l in range(L)]
    head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
    all_opts = block_opts + [head_opt]
    schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in all_opts]

    for ep in range(1, epochs + 1):
        model.train()
        for x, y in trl:
            x, y = x.to(dev), y.to(dev)

            # Free phase: standard forward pass
            with torch.no_grad():
                _, h_free = model(x, return_hidden=True)

            # Nudged phase
            h_nudged = ep_nudged_phase_cnn(model, x, y, h_free, beta, T_nudge, alpha_nudge)

            for o in all_opts:
                o.zero_grad()

            # EP weight update per block:
            # dE/dtheta_l = -residual_l * dF_l/dtheta_l, realised via the
            # autograd trick loss = -(res.detach() * F_l).sum().
            inputs_free = [x, h_free[0].detach(), h_free[1].detach(), h_free[2].detach()]
            inputs_nudge = [x, h_nudged[0].detach(), h_nudged[1].detach(), h_nudged[2].detach()]
            for l in range(L):
                f_free = _block_forward(model, l, inputs_free[l].detach())
                f_nudge = _block_forward(model, l, inputs_nudge[l].detach())

                # residuals (detached target - computed output), (B, d_l)
                res_free = flat(h_free[l]).detach() - flat(f_free).detach()
                res_nudge = flat(h_nudged[l]).detach() - flat(f_nudge).detach()

                loss_free_l = -(res_free * flat(f_free)).sum()     # grad = dE_free/dtheta
                loss_nudge_l = -(res_nudge * flat(f_nudge)).sum()  # grad = dE_nudge/dtheta
                ep_loss_l = (loss_nudge_l - loss_free_l) / beta
                ep_loss_l.backward()

            # Head: CE on nudged h3
            head_loss = F.cross_entropy(model.out_head(h_nudged[3].detach()), y)
            head_loss.backward()

            torch.nn.utils.clip_grad_norm_(list(model.parameters()), 1.0)
            for o in all_opts:
                o.step()

        for sch in schedulers:
            sch.step()
        if ep % 20 == 0:
            print(f" Ep {ep}: acc={evaluate(model, tel, dev):.4f}", flush=True)
    return model


# ---------------------------------------------------------------------------
# Training: State Bridge
# ---------------------------------------------------------------------------

def train_state_bridge(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01, lr_fb=3e-4):
    """
    State Bridge for CNN.

    Per-layer predictor G_psi(h_l_flat, t_l, s) -> predicted h3 (256-dim
    terminal state), with s = e_T (10-dim softmax-CE error) and t_l = l / L.
    Credit: a_l = grad_{h_l_flat} CE(out_head(G_psi(h_l_flat, t_l, s)), y).
    Local update: minimise <F_l(h_{l-1}), a_l / ||a_l||_rms> w.r.t. the
    parameters of block l (same surrogate as DFA, with the learned credit).

    Returns:
        (model, state_preds) where state_preds is the ModuleList of predictors.
    """
    from models.value_net import SinusoidalTimeEmbed

    L = model.num_blocks            # 4
    C = 10
    flat_dims = model.flat_dims     # [32768, 16384, 8192, 256]
    d_terminal = 256                # h3 is the terminal hidden state

    class CNNStateBridge(nn.Module):
        """MLP on [LayerNorm(h), time embedding(t), s] -> predicted terminal state."""

        def __init__(self, in_dim, out_dim, s_dim, te_dim=32):
            super().__init__()
            self.ln = nn.LayerNorm(in_dim)
            self.te = SinusoidalTimeEmbed(te_dim)
            total = in_dim + te_dim + s_dim
            self.net = nn.Sequential(
                nn.Linear(total, 256), nn.GELU(),
                nn.Linear(256, 256), nn.GELU(),
                nn.Linear(256, out_dim),
            )

        def forward(self, h, t, s):
            return self.net(torch.cat([self.ln(h), self.te(t), s], dim=-1))

    state_preds = nn.ModuleList([
        CNNStateBridge(flat_dims[l], d_terminal, C).to(dev) for l in range(L)
    ])

    block_opts = [optim.AdamW(model.blocks[l].parameters(), lr=lr, weight_decay=wd)
                  for l in range(L)]
    head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
    state_opts = [optim.Adam(state_preds[l].parameters(), lr=lr_fb) for l in range(L)]
    all_main_opts = block_opts + [head_opt]
    schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in all_main_opts]

    for ep in range(1, epochs + 1):
        model.train()
        for sp in state_preds:
            sp.train()
        for x, y in trl:
            x, y = x.to(dev), y.to(dev)
            B = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                probs = logits.softmax(-1)
                e_T = probs.clone()
                e_T[torch.arange(B), y] -= 1.0
                s = e_T.detach()
                h3_det = hiddens[3].detach()  # (B, 256) terminal hidden state

            # --- Train each state predictor: G_psi_l(h_l_flat, t_l, s) -> h3 ---
            for l in range(L):
                h_l_flat = flat(hiddens[l]).detach()
                t_l = torch.full((B,), l / L, device=dev)
                pred_h3 = state_preds[l](h_l_flat, t_l, s)
                target = h3_det
                # Scale-normalised squared error; clamp avoids blow-up for tiny targets.
                target_norm = target.norm(dim=-1, keepdim=True).clamp(min=1.0)
                state_loss = (((pred_h3 - target) / target_norm) ** 2).sum(dim=-1).mean()
                state_opts[l].zero_grad()
                state_loss.backward()
                state_opts[l].step()

            # --- Credits: a_l = grad_{h_l_flat} CE(out_head(G(h_l_flat, t_l, s)), y) ---
            credits = []
            for l in range(L):
                h_l_flat_req = flat(hiddens[l]).detach().requires_grad_(True)
                t_l = torch.full((B,), l / L, device=dev)
                pred_h3 = state_preds[l](h_l_flat_req, t_l, s)
                pred_logits = model.out_head(pred_h3)
                pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
                a_l = torch.autograd.grad(pred_loss, h_l_flat_req, create_graph=False)[0]
                credits.append(a_l.detach())  # (B, flat_dim_l)

            # --- Train out_head with CE on detached h3 ---
            ce_loss = F.cross_entropy(model.out_head(h3_det), y)
            head_opt.zero_grad()
            ce_loss.backward()
            head_opt.step()

            # --- Train each block with the local surrogate ---
            inputs = [x, hiddens[0].detach(), hiddens[1].detach(), hiddens[2].detach()]
            for l in range(L):
                a_l = credits[l]
                rms = (a_l ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
                a_l_norm = a_l / rms
                out_flat = flat(_block_forward(model, l, inputs[l].detach()))
                local_loss = (out_flat * a_l_norm).sum(-1).mean()
                block_opts[l].zero_grad()
                local_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
                block_opts[l].step()

        for sch in schedulers:
            sch.step()
        if ep % 20 == 0:
            print(f" Ep {ep}: acc={evaluate(model, tel, dev):.4f}", flush=True)
    return model, state_preds
# ---------------------------------------------------------------------------
# Training: Credit Bridge
# ---------------------------------------------------------------------------

def train_credit_bridge(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01,
                        lr_fb=3e-4, lam=0.1, K=4, sigma_bridge=0.01,
                        ema_momentum=0.99, term_grad_weight=0.1):
    """
    Credit Bridge for CNN.

    Per-layer ValueNet V(h_l_flat, t_l, s) -> scalar, with s = e_T and t_l = l / L.
    Credit: a_l = grad_{h_l_flat} V.
    Value-net training: the last layer regresses the true per-sample CE loss
    (terminal boundary); earlier layers regress the soft-min bridge target
    -lam * log mean_k exp(-V_ema_{l+1}(h_{l+1} + noise_k) / lam).

    Schedule: for the first max(1, epochs // 5) epochs the blocks train on
    DFA fallback credit only; afterwards the credit is a linear blend that
    reaches pure value-net credit after another warmup-length period.

    NOTE: ``term_grad_weight`` is accepted for API/CLI compatibility but is
    not used -- no terminal gradient-matching term is implemented here.

    Returns:
        (model, value_nets, value_nets_ema)
    """
    L = model.num_blocks            # 4
    C = 10
    flat_dims = model.flat_dims     # [32768, 16384, 8192, 256]
    warmup_epochs = max(1, epochs // 5)

    # One ValueNet per layer (each takes flat_dim_l as h input)
    value_nets = nn.ModuleList([
        ValueNet(d_hidden=flat_dims[l], s_dim=C, time_embed_dim=32,
                 hidden_dim=256, num_layers=3).to(dev)
        for l in range(L)
    ])
    value_nets_ema = nn.ModuleList([create_ema_model(value_nets[l]) for l in range(L)])

    # DFA fallback matrices (fixed)
    Bs_fallback = [torch.randn(flat_dims[l], C, device=dev) / np.sqrt(C) for l in range(L)]

    block_opts = [optim.AdamW(model.blocks[l].parameters(), lr=lr, weight_decay=wd)
                  for l in range(L)]
    head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
    value_opts = [optim.Adam(value_nets[l].parameters(), lr=lr_fb) for l in range(L)]
    all_main_opts = block_opts + [head_opt]
    schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in all_main_opts]

    print(f" [CB] Warmup phase: {warmup_epochs} epochs (DFA fallback + value net training)")

    for ep in range(1, epochs + 1):
        model.train()
        for vn in value_nets:
            vn.train()
        if ep <= warmup_epochs:
            credit_blend = 0.0
        else:
            credit_blend = min(1.0, (ep - warmup_epochs) / max(1, warmup_epochs))

        for x, y in trl:
            x, y = x.to(dev), y.to(dev)
            B = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                probs = logits.softmax(-1)
                e_T = probs.clone()
                e_T[torch.arange(B), y] -= 1.0
                s = e_T.detach()
                true_loss = F.cross_entropy(logits, y, reduction='none').detach()
                h3_det = flat(hiddens[3]).detach()  # (B, 256) terminal

            # --- Train value nets (always, including warmup) ---
            for l in range(L):
                h_l_flat = flat(hiddens[l]).detach()
                t_l = torch.full((B,), l / L, device=dev)
                t_l_next = torch.full((B,), (l + 1) / L, device=dev)
                V_l = value_nets[l](h_l_flat, t_l, s)

                if l == L - 1:
                    # Terminal boundary: regress the true per-sample CE loss.
                    loss_term = ((V_l - true_loss) ** 2).mean()
                else:
                    # Bridge consistency: V_l ~ -lam * log E[exp(-V_{l+1}/lam)]
                    h_next_flat = flat(hiddens[l + 1]).detach()
                    with torch.no_grad():
                        log_terms = []
                        for _ in range(K):
                            noise = sigma_bridge * torch.randn_like(h_next_flat)
                            V_next = value_nets_ema[l + 1](h_next_flat + noise, t_l_next, s)
                            log_terms.append(-V_next / lam)
                        log_stack = torch.stack(log_terms, dim=-1)
                        # log-mean-exp via logsumexp - log K (numerically stable)
                        V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K))
                    loss_term = ((V_l - V_target.detach()) ** 2).mean()

                value_opts[l].zero_grad()
                loss_term.backward()
                torch.nn.utils.clip_grad_norm_(value_nets[l].parameters(), 1.0)
                value_opts[l].step()
                update_ema(value_nets[l], value_nets_ema[l], ema_momentum)

            # --- Compute value-net credits ---
            cb_credits = []
            for l in range(L):
                h_l_flat_req = flat(hiddens[l]).detach().requires_grad_(True)
                t_l = torch.full((B,), l / L, device=dev)
                V_l = value_nets[l](h_l_flat_req, t_l, s)
                a_l = torch.autograd.grad(V_l.sum(), h_l_flat_req, create_graph=False)[0]
                cb_credits.append(a_l.detach())

            dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)]

            # Blend: pure DFA during warmup, pure CB after the ramp, RMS-matched mix between.
            credits = []
            for l in range(L):
                if credit_blend >= 1.0:
                    a = cb_credits[l]
                elif credit_blend <= 0.0:
                    a = dfa_credits[l]
                else:
                    cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
                    dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
                    a = (credit_blend * (cb_credits[l] / cb_rms)
                         + (1 - credit_blend) * (dfa_credits[l] / dfa_rms))
                credits.append(a)

            # --- Train out_head with CE on detached h3 ---
            ce_loss = F.cross_entropy(model.out_head(h3_det), y)
            head_opt.zero_grad()
            ce_loss.backward()
            head_opt.step()

            # --- Train each block with the local surrogate <F_l, a_l / rms(a_l)> ---
            inputs = [x, hiddens[0].detach(), hiddens[1].detach(), hiddens[2].detach()]
            for l in range(L):
                a_l = credits[l]
                rms = (a_l ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
                a_l_norm = a_l / rms
                inp = inputs[l].detach()
                if l == 3 and inp.dim() > 2:
                    inp = inp.flatten(1)
                out_flat = flat(model.blocks[l](inp))
                local_loss = (out_flat * a_l_norm).sum(-1).mean()
                block_opts[l].zero_grad()
                local_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
                block_opts[l].step()

        for sch in schedulers:
            sch.step()
        if ep % 20 == 0:
            phase = "warmup" if ep <= warmup_epochs else f"blend={credit_blend:.2f}"
            print(f" Ep {ep} ({phase}): acc={evaluate(model, tel, dev):.4f}", flush=True)
    return model, value_nets, value_nets_ema


# ---------------------------------------------------------------------------
# Diagnostics
# ---------------------------------------------------------------------------

def compute_bp_grads(model, x, y):
    """
    Compute BP gradients of the batch-mean CE loss w.r.t. each hidden state h_l.

    Puts the model in eval mode. Returns (grads, h): ``grads`` is a list of
    detached tensors shaped like h_l (zeros if a state is unused), and ``h``
    is the list of hidden states, still attached to the autograd graph.
    """
    model.eval()
    L = model.num_blocks
    # Build the forward manually, keeping every h[l] connected to the graph
    # (no detach between layers) so gradients can be taken w.r.t. each one.
    h = [None] * L
    inp = x
    for l in range(L):
        if l == 3:
            inp = inp.flatten(1) if inp.dim() > 2 else inp
        out = model.blocks[l](inp)
        h[l] = out
        inp = out
    logits = model.out_head(h[3])
    loss = F.cross_entropy(logits, y)
    gs = torch.autograd.grad(loss, h, allow_unused=True)
    grads = [g.detach() if g is not None else torch.zeros_like(h[i]) for i, g in enumerate(gs)]
    return grads, h


def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20, alpha_nudge=0.05,
                        state_preds=None, value_nets=None):
    """
    Measure credit-assignment quality on the first batch of ``tel``.

    Gamma: per-layer cosine similarity between the method's credit signal and
           the true BP gradient dL/dh_l (reported mean + per layer).
    rho:   per-layer perturbation correlation between the credit and the
           change in per-sample CE when h_l is perturbed and re-propagated
           through the remaining blocks.

    Credit per method:
        ep            : -(h_nudged - h_free) / beta
        state_bridge  : grad_h CE(out_head(G_psi(h, t, s)), y)
        credit_bridge : grad_h V(h, t, s)
        dfa           : e_T @ B_l^T when ``model.dfa_feedback`` is present
                        (set by train_dfa); otherwise falls back to BP grads
        bp            : the BP gradient itself (Gamma = 1 by construction)

    Returns:
        dict with 'Gamma', 'rho', 'gammas_per_layer', 'rhos_per_layer'.
    """
    model.eval()
    for nets in (state_preds, value_nets):
        if nets is not None:
            for net in nets:
                net.eval()
    L = model.num_blocks

    # Grab one batch
    x, y = next(iter(tel))
    x, y = x.to(dev), y.to(dev)
    B = x.size(0)

    # BP gradients (reference direction)
    bp_grads, _ = compute_bp_grads(model, x, y)
    bp_grads_flat = [flat(g) for g in bp_grads]

    dfa_feedback = getattr(model, 'dfa_feedback', None)

    # Credit signals depending on method
    if method == 'ep':
        with torch.no_grad():
            _, h_free = model(x, return_hidden=True)
        h_nudged = ep_nudged_phase_cnn(model, x, y, h_free, beta, T_nudge, alpha_nudge)
        # Negate: EP nudge moves h toward lower loss, opposite to BP grad direction
        credits = [-flat((h_nudged[l] - h_free[l]) / beta) for l in range(L)]
    elif method in ('state_bridge', 'credit_bridge'):
        with torch.no_grad():
            logits, hiddens = model(x, return_hidden=True)
            e_T = logits.softmax(-1).clone()
            e_T[torch.arange(B), y] -= 1.0
            s = e_T.detach()
        credits = []
        for l in range(L):
            h_l_flat_req = flat(hiddens[l]).detach().requires_grad_(True)
            t_l = torch.full((B,), l / L, device=dev)
            if method == 'state_bridge':
                pred_h3 = state_preds[l](h_l_flat_req, t_l, s)
                pred_loss = F.cross_entropy(model.out_head(pred_h3), y, reduction='sum')
                a_l = torch.autograd.grad(pred_loss, h_l_flat_req, create_graph=False)[0]
            else:  # credit_bridge
                V_l = value_nets[l](h_l_flat_req, t_l, s)
                a_l = torch.autograd.grad(V_l.sum(), h_l_flat_req, create_graph=False)[0]
            credits.append(a_l.detach())
    elif method == 'dfa' and dfa_feedback is not None:
        # The credit DFA actually trains with: a_l = e_T @ B_l^T.
        with torch.no_grad():
            e_T = model(x).softmax(-1).clone()
            e_T[torch.arange(B), y] -= 1.0
        credits = [e_T @ dfa_feedback[l].T for l in range(L)]
    else:
        # BP (and DFA without stored feedback): BP grads themselves; Gamma = 1.
        credits = list(bp_grads_flat)

    # Gamma: cosine similarity between credit and BP grad
    gammas = [float(cosine_similarity_batch(credits[l], bp_grads_flat[l])) for l in range(L)]

    # rho: perturbation correlation
    with torch.no_grad():
        _, hiddens = model(x, return_hidden=True)

    def make_forward_fn(layer_idx, shape):
        """Return fn mapping a flat (B, d_l) state at ``layer_idx`` to per-sample CE."""
        def forward_fn(h_flat):
            with torch.no_grad():
                # Restore the native shape so conv blocks receive 4-D input.
                c = h_flat.reshape(-1, *shape)
                for i in range(layer_idx + 1, L):
                    if i == 3:
                        c = c.flatten(1)  # FC1 expects the flattened feature map
                    c = model.blocks[i](c)
                logits = model.out_head(flat(c))
                return F.cross_entropy(logits, y, reduction='none')
        return forward_fn

    rhos = []
    for l in range(L):
        h_l = flat(hiddens[l].detach())        # (B, d_l)
        a_l = credits[l].detach()              # (B, d_l)
        shape = tuple(hiddens[l].shape[1:])
        rho = perturbation_correlation(h_l, a_l, make_forward_fn(l, shape),
                                       epsilon=1e-3, M=16)
        rhos.append(float(rho))

    return {
        'Gamma': float(np.mean(gammas)),
        'rho': float(np.mean(rhos)),
        'gammas_per_layer': gammas,
        'rhos_per_layer': rhos,
    }


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Parse CLI args, train one method/seed, run diagnostics, save .pt + .json."""
    p = argparse.ArgumentParser(description='CNN baseline for CIFAR-10')
    p.add_argument('--method', type=str, required=True,
                   choices=['bp', 'dfa', 'ep', 'state_bridge', 'credit_bridge'])
    p.add_argument('--seed', type=int, required=True)
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--output_dir', type=str, default='results/cnn_baseline')
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--wd', type=float, default=0.01)
    # EP hyperparameters
    p.add_argument('--beta', type=float, default=0.5)
    p.add_argument('--T_nudge', type=int, default=20)
    p.add_argument('--alpha_nudge', type=float, default=0.05)
    # SB/CB hyperparameters
    p.add_argument('--lr_fb', type=float, default=3e-4,
                   help='Learning rate for SB/CB feedback nets')
    p.add_argument('--lam', type=float, default=0.1, help='CB soft-min temperature')
    p.add_argument('--K', type=int, default=4, help='CB bridge consistency samples')
    p.add_argument('--sigma_bridge', type=float, default=0.01, help='CB bridge noise std')
    p.add_argument('--ema_momentum', type=float, default=0.99, help='CB EMA momentum')
    p.add_argument('--term_grad_weight', type=float, default=0.1,
                   help='CB terminal gradient matching weight')
    args = p.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    if torch.cuda.is_available():
        dev = torch.device(f'cuda:{args.gpu}')
    else:
        dev = torch.device('cpu')
        print("[warn] CUDA not available; running on CPU", flush=True)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    trl, tel = get_cifar10()
    model = SmallCNN().to(dev)

    print(f"[{args.method} s={args.seed}] Training CNN on CIFAR-10 for {args.epochs} epochs...",
          flush=True)

    state_preds = None
    value_nets = None
    if args.method == 'bp':
        model = train_bp(model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd)
    elif args.method == 'dfa':
        model = train_dfa(model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd)
    elif args.method == 'ep':
        model = train_ep(model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd,
                         beta=args.beta, T_nudge=args.T_nudge, alpha_nudge=args.alpha_nudge)
    elif args.method == 'state_bridge':
        model, state_preds = train_state_bridge(
            model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd,
            lr_fb=args.lr_fb)
    elif args.method == 'credit_bridge':
        model, value_nets, _ = train_credit_bridge(
            model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd,
            lr_fb=args.lr_fb, lam=args.lam, K=args.K, sigma_bridge=args.sigma_bridge,
            ema_momentum=args.ema_momentum, term_grad_weight=args.term_grad_weight)

    acc = evaluate(model, tel, dev)
    diag = compute_diagnostics(model, tel, dev, args.method,
                               beta=args.beta, T_nudge=args.T_nudge,
                               alpha_nudge=args.alpha_nudge,
                               state_preds=state_preds, value_nets=value_nets)

    # Save checkpoint
    ckpt_path = os.path.join(args.output_dir, f'{args.method}_s{args.seed}.pt')
    torch.save(model.state_dict(), ckpt_path)

    result = {
        'method': args.method,
        'seed': args.seed,
        'acc': float(acc),
        'Gamma': diag['Gamma'],
        'rho': diag['rho'],
        'gammas_per_layer': diag['gammas_per_layer'],
        'rhos_per_layer': diag['rhos_per_layer'],
        'epochs': args.epochs,
        'lr': args.lr,
        'wd': args.wd,
        'beta': args.beta,
        'T_nudge': args.T_nudge,
        'alpha_nudge': args.alpha_nudge,
        # SB/CB hyperparameters, recorded so every run is reproducible from its JSON.
        'lr_fb': args.lr_fb,
        'lam': args.lam,
        'K': args.K,
        'sigma_bridge': args.sigma_bridge,
        'ema_momentum': args.ema_momentum,
        'term_grad_weight': args.term_grad_weight,
    }
    json_path = os.path.join(args.output_dir, f'{args.method}_s{args.seed}.json')
    with open(json_path, 'w') as f:
        json.dump(result, f, indent=2, default=float)

    print(
        f"[{args.method} s={args.seed}] acc={acc:.4f} "
        f"Gamma={diag['Gamma']:.4f} rho={diag['rho']:.4f}",
        flush=True,
    )
    print(f" gammas_per_layer={[f'{g:.4f}' for g in diag['gammas_per_layer']]}", flush=True)
    print(f" rhos_per_layer ={[f'{r:.4f}' for r in diag['rhos_per_layer']]}", flush=True)
    print(f" Saved: {json_path}", flush=True)


if __name__ == '__main__':
    main()