"""
Phase B: Deep Residual MLP on CIFAR-10 (CIFAR-100 and FashionMNIST are also
supported). Compare BP, DFA, vanilla FA, State Bridge and Credit Bridge.

CRITICAL CONSTRAINT: No hidden BP anchor for non-BP methods.
All block updates use detached hidden states and local surrogates.
"""
import os
import sys
import json
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import copy
import time

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.residual_mlp import ResidualMLP
from models.value_net import ValueNet, create_ema_model, update_ema
from models.state_bridge import StateBridgeNet
from metrics.credit_metrics import (
    cosine_similarity_batch, perturbation_correlation, nudging_test,
    offline_bp_cosine, feature_drift
)


# name -> (dataset class, image side, random-crop padding, channel means,
#          channel stds, flattened input dim, number of classes)
_DATASET_SPECS = {
    'cifar10': (torchvision.datasets.CIFAR10, 32, 4,
                (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616),
                32 * 32 * 3, 10),
    'cifar100': (torchvision.datasets.CIFAR100, 32, 4,
                 (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761),
                 32 * 32 * 3, 100),
    'fashionmnist': (torchvision.datasets.FashionMNIST, 28, 2,
                     (0.2860,), (0.3530,),
                     28 * 28, 10),
}


def get_data(dataset='cifar10', batch_size=128, num_workers=4, root='./data'):
    """Build train/test DataLoaders for ``dataset``.

    Training images get random-crop + horizontal-flip augmentation; both splits
    are normalized with the per-channel statistics in ``_DATASET_SPECS``.
    Datasets are downloaded into ``root`` if missing.

    Returns:
        (train_loader, test_loader, input_dim, num_classes), where ``input_dim``
        is the flattened image size consumed by the MLP.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    try:
        ds_cls, side, pad, mean, std, input_dim, num_classes = _DATASET_SPECS[dataset]
    except KeyError:
        raise ValueError(f"Unknown dataset: {dataset}") from None

    normalize = transforms.Normalize(mean, std)
    transform_train = transforms.Compose([
        transforms.RandomCrop(side, padding=pad),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    trainset = ds_cls(root=root, train=True, download=True, transform=transform_train)
    testset = ds_cls(root=root, train=False, download=True, transform=transform_test)

    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader, input_dim, num_classes


# =============================================================================
# Shared helpers
# =============================================================================

def _prepare_batch(x, y, device):
    """Flatten images to (batch, features) and move inputs and labels to ``device``."""
    return x.view(x.size(0), -1).to(device), y.to(device)


def _maybe_random_targets(y, args):
    """Replace labels with i.i.d. uniform random classes when ``args.random_targets`` is set."""
    if getattr(args, 'random_targets', False):
        return torch.randint(0, args.num_classes, y.shape, device=y.device)
    return y


def _output_error(logits, y):
    """Return softmax(logits) - one_hot(y), the per-sample CE gradient w.r.t. the logits.

    All callers compute ``logits`` under ``torch.no_grad()``, so editing the
    softmax output in place is safe. The row index is created on the logits'
    device so CPU and CUDA indices are never mixed.
    """
    e_T = logits.softmax(dim=-1)
    e_T[torch.arange(logits.size(0), device=logits.device), y] -= 1
    return e_T


def _rms_normalize(a, eps=1e-6):
    """Scale each row of ``a`` to (approximately) unit RMS; ``eps`` guards all-zero rows."""
    rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + eps
    return a / rms


def _make_local_optimizers(model, args):
    """Create independent AdamW optimizers for the local-learning methods.

    There is one optimizer per residual block, one for the embedding and one
    for the output head (``out_head`` + ``out_ln``). Each optimizer gets its own
    cosine LR schedule over ``args.epochs`` epochs.

    Returns:
        (block_opts, embed_opt, head_opt, schedulers). The schedulers are
        ordered blocks..., embedding, head.
    """
    def adamw(params):
        return optim.AdamW(params, lr=args.lr, weight_decay=args.wd)

    block_opts = [adamw(block.parameters()) for block in model.blocks]
    embed_opt = adamw(model.embed.parameters())
    head_opt = adamw(list(model.out_head.parameters()) + list(model.out_ln.parameters()))
    schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs)
                  for o in block_opts + [embed_opt, head_opt]]
    return block_opts, embed_opt, head_opt, schedulers


def _update_head(model, head_opt, h_L, y):
    """Take one optimizer step on the output head using the exact mean CE at ``h_L``.

    ``h_L`` must be detached from the block stack so no gradient reaches hidden
    layers. If it has ``requires_grad`` set, its ``.grad`` holds the exact
    gradient of the mean CE at h_L afterwards; FA uses this as its starting
    credit.
    """
    logits_out = model.out_head(model.out_ln(h_L))
    loss_out = F.cross_entropy(logits_out, y)
    head_opt.zero_grad()
    loss_out.backward()
    head_opt.step()


def _update_block(block, block_opt, h_l, credit, penalty_lam=0.0):
    """Take one local surrogate step on a residual block.

    The step minimizes mean_b ⟨f_l(h_l), â_l⟩, where â_l is ``credit``
    RMS-normalized per sample. When ``penalty_lam`` > 0 it also adds
    ``penalty_lam * mean_b ||f_l(h_l)||^2``. ``h_l`` must already be detached.
    """
    a_norm = _rms_normalize(credit)
    f_l = block(h_l)
    local_loss = (f_l * a_norm).sum(dim=-1).mean()
    if penalty_lam > 0.0:
        local_loss = local_loss + penalty_lam * (f_l ** 2).sum(dim=-1).mean()
    block_opt.zero_grad()
    local_loss.backward()
    block_opt.step()


def _update_embedding(model, embed_opt, x, credit):
    """Take one local surrogate step on the embedding: minimize mean_b ⟨embed(x), â_0⟩."""
    a_0_norm = _rms_normalize(credit)
    h0 = model.embed(x)
    embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
    embed_opt.zero_grad()
    embed_loss.backward()
    embed_opt.step()


def _append_epoch(log, train_loss, train_acc, test_acc):
    """Append the three per-epoch metrics shared by every training method."""
    log['train_loss'].append(train_loss)
    log['train_acc'].append(train_acc)
    log['test_acc'].append(test_acc)


def _is_report_epoch(epoch):
    """Return True for epochs that should be printed: the first and every 10th."""
    return epoch % 10 == 0 or epoch == 1


def evaluate(model, test_loader, device):
    """Return the top-1 accuracy of ``model`` on ``test_loader``.

    Leaves the model in eval mode; the training loops switch back to train mode
    at the start of every epoch.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = _prepare_batch(x, y, device)
            logits = model(x)
            correct += (logits.argmax(1) == y).sum().item()
            total += x.size(0)
    return correct / total


# =============================================================================
# BP Baseline
# =============================================================================

def train_bp(model, train_loader, test_loader, device, args):
    """Run standard end-to-end backprop training (AdamW + cosine LR).

    Returns:
        log dict with per-epoch 'train_loss', 'train_acc' and 'test_acc'.
    """
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
    for epoch in range(1, args.epochs + 1):
        model.train()
        total_loss, correct, total = 0, 0, 0
        for x, y in train_loader:
            x, y = _prepare_batch(x, y, device)
            y = _maybe_random_targets(y, args)

            logits = model(x)
            loss = F.cross_entropy(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * x.size(0)
            correct += (logits.argmax(1) == y).sum().item()
            total += x.size(0)
        scheduler.step()

        train_loss = total_loss / total
        train_acc = correct / total
        test_acc = evaluate(model, test_loader, device)
        _append_epoch(log, train_loss, train_acc, test_acc)
        if _is_report_epoch(epoch):
            print(f" [BP] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}")
    return log


# =============================================================================
# DFA Baseline
# =============================================================================

def train_dfa(model, train_loader, test_loader, device, args):
    """
    DFA training with fixed random feedback matrices.

    Each block l is updated with the local surrogate
        L_l = mean_b ⟨f_l(h_l), â_l⟩,   a_l = e_T B_l^T,
    where e_T = softmax(logits) - one_hot(y) and â_l is a_l RMS-normalized per
    sample. The output head is updated with the exact CE gradient (h_L
    detached). The embedding is updated with the block-0 projection (B_0),
    which serves as the DFA credit at h_0.

    Returns:
        (log, Bs): per-epoch metrics and the L feedback matrices, each of
        shape (d_hidden, num_classes).
    """
    d = model.d_hidden
    num_classes = args.num_classes
    L = model.num_blocks
    penalty_lam = getattr(args, 'penalty_lam', 0.0)

    # Fixed random feedback matrices, one per block (entry variance 1/num_classes).
    Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)]

    block_opts, embed_opt, head_opt, all_schedulers = _make_local_optimizers(model, args)

    log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
    for epoch in range(1, args.epochs + 1):
        model.train()
        total_loss, correct, total = 0, 0, 0
        for x, y in train_loader:
            x, y = _prepare_batch(x, y, device)
            y = _maybe_random_targets(y, args)
            batch = x.size(0)

            # Forward pass; no graph is kept for the hidden states.
            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)
                e_T = _output_error(logits, y)  # (batch, num_classes)

            # 1. Output head: exact CE gradient with h_L detached.
            _update_head(model, head_opt, hiddens[-1].detach(), y)

            # 2. Blocks: DFA credit (batch, C) @ (C, d) -> (batch, d).
            for l in range(L):
                a_dfa = (e_T @ Bs[l].T).detach()
                _update_block(model.blocks[l], block_opts[l], hiddens[l].detach(), a_dfa, penalty_lam)

            # 3. Embedding: reuse the block-0 projection as the credit at h_0.
            _update_embedding(model, embed_opt, x, (e_T @ Bs[0].T).detach())

            total_loss += loss_val.item() * batch
            correct += (logits.argmax(1) == y).sum().item()
            total += batch

        for s in all_schedulers:
            s.step()

        train_loss = total_loss / total
        train_acc = correct / total
        test_acc = evaluate(model, test_loader, device)
        _append_epoch(log, train_loss, train_acc, test_acc)
        if _is_report_epoch(epoch):
            print(f" [DFA] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}")
    return log, Bs


# =============================================================================
# Vanilla FA (Lillicrap 2016)
# =============================================================================

def train_fa(model, train_loader, test_loader, device, args):
    """
    Vanilla Feedback Alignment (Lillicrap et al. 2016).

    DFA projects the output error directly to every block through a fixed
    d×C matrix. FA instead propagates credit sequentially backward through the
    block stack using fixed random d×d feedback matrices (row-vector form):

        a_{L-1} = exact gradient of the mean CE at h_L (through out_head + out_ln only)
        a_{l-1} = a_l B_l

    B_l stands in for the whole backward map of block l; the identity skip
    path is not added.

    Each block l is updated with the same local surrogate as DFA,
    L_l = mean_b ⟨f_l(h_l), â_l⟩. The embedding uses a_0 B_0.

    Returns:
        (log, Bs): per-epoch metrics and the L feedback matrices, each of
        shape (d_hidden, d_hidden).
    """
    d = model.d_hidden
    L = model.num_blocks
    penalty_lam = getattr(args, 'penalty_lam', 0.0)

    # Fixed random d × d feedback matrices, one per block (entry variance 1/d).
    # Contrast with DFA's d × num_classes matrices.
    Bs = [torch.randn(d, d, device=device) / np.sqrt(d) for _ in range(L)]

    block_opts, embed_opt, head_opt, all_schedulers = _make_local_optimizers(model, args)

    log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
    for epoch in range(1, args.epochs + 1):
        model.train()
        total_loss, correct, total = 0, 0, 0
        for x, y in train_loader:
            x, y = _prepare_batch(x, y, device)
            y = _maybe_random_targets(y, args)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)

            # 1. Output head (exact CE gradient). h_L is a fresh leaf, so after the
            #    step its .grad is FA's starting credit signal.
            hL_leaf = hiddens[-1].detach().requires_grad_(True)
            _update_head(model, head_opt, hL_leaf, y)
            a_credit = hL_leaf.grad.detach()  # (batch, d)

            # 2. Blocks, last to first. Each block uses the current credit, then the
            #    credit is pushed through that block's random feedback matrix.
            for l in reversed(range(L)):
                _update_block(model.blocks[l], block_opts[l], hiddens[l].detach(), a_credit, penalty_lam)
                a_credit = (a_credit @ Bs[l]).detach()

            # 3. Embedding: credit left after propagating through B_0.
            _update_embedding(model, embed_opt, x, a_credit)

            total_loss += loss_val.item() * batch
            correct += (logits.argmax(1) == y).sum().item()
            total += batch

        for s in all_schedulers:
            s.step()

        train_loss = total_loss / total
        train_acc = correct / total
        test_acc = evaluate(model, test_loader, device)
        _append_epoch(log, train_loss, train_acc, test_acc)
        if _is_report_epoch(epoch):
            print(f" [FA] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, test={test_acc:.4f}")
    return log, Bs


# =============================================================================
# State Bridge
# =============================================================================

def train_state_bridge(model, train_loader, test_loader, device, args):
    """
    State Bridge training.

    A predictor G_psi(h_l, t_l, s) is regressed onto the terminal state h_L.
    Block credit is then taken through the real output head:
        a_l = grad_{h_l} CE(out_head(out_ln(G_psi(h_l, t_l, s))), y),
    with t_l = l / L and s = e_T, the detached output error.

    Per minibatch:
      (1) one predictor step;
      (2) credits computed from the *updated* predictor;
      (3) an exact head step at detached h_L;
      (4) local block steps;
      (5) an embedding step with a_0.

    Returns:
        (log, state_pred). The log also contains a per-epoch
        'state_pred_error': the sample-weighted mean of the normalized state
        regression loss.
    """
    d = model.d_hidden
    num_classes = args.num_classes
    L = model.num_blocks
    penalty_lam = getattr(args, 'penalty_lam', 0.0)

    state_pred = StateBridgeNet(
        d_hidden=d, s_dim=num_classes, time_embed_dim=32, hidden_dim=256, num_layers=3
    ).to(device)

    block_opts, embed_opt, head_opt, all_schedulers = _make_local_optimizers(model, args)
    state_opt = optim.Adam(state_pred.parameters(), lr=args.lr_fb)

    log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'state_pred_error': []}

    for epoch in range(1, args.epochs + 1):
        model.train()
        state_pred.train()
        total_loss, correct, total = 0, 0, 0
        total_se = 0

        for x, y in train_loader:
            x, y = _prepare_batch(x, y, device)
            y = _maybe_random_targets(y, args)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)
                s = _output_error(logits, y)
            hL_det = hiddens[-1].detach()
            layer_times = [torch.full((batch,), l / L, device=device) for l in range(L)]

            # (1) Train the state predictor G_psi(h_l, t_l, s) -> h_L.
            #     The target is h_L itself. Whether StateBridgeNet adds h_l
            #     internally as a residual is defined in models.state_bridge.
            #     The error is divided by ||h_L|| (clamped at 1) for stability.
            target = hL_det
            target_norm = target.norm(dim=-1, keepdim=True).clamp(min=1.0)
            state_loss = 0.0
            for l in range(L):
                pred_hL = state_pred(hiddens[l].detach(), layer_times[l], s)
                state_loss = state_loss + (((pred_hL - target) / target_norm) ** 2).sum(dim=-1).mean()
            state_loss = state_loss / L

            state_opt.zero_grad()
            state_loss.backward()
            state_opt.step()
            total_se += state_loss.item() * batch

            # (2) Credits: a_l = grad_{h_l} CE(out_head(out_ln(G(h_l, t_l, s))), y).
            #     autograd.grad only differentiates w.r.t. h_l, so no parameter
            #     .grad is touched here.
            credits = []
            for l in range(L):
                h_l_leaf = hiddens[l].detach().requires_grad_(True)
                pred_hL = state_pred(h_l_leaf, layer_times[l], s)
                pred_logits = model.out_head(model.out_ln(pred_hL))
                pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
                credits.append(torch.autograd.grad(pred_loss, h_l_leaf, create_graph=False)[0].detach())

            # (3) Output head.
            _update_head(model, head_opt, hL_det, y)

            # (4) Blocks.
            for l in range(L):
                _update_block(model.blocks[l], block_opts[l], hiddens[l].detach(), credits[l], penalty_lam)

            # (5) Embedding, using the credit at layer 0.
            _update_embedding(model, embed_opt, x, credits[0])

            total_loss += loss_val.item() * batch
            correct += (logits.argmax(1) == y).sum().item()
            total += batch

        for sch in all_schedulers:
            sch.step()

        train_loss = total_loss / total
        train_acc = correct / total
        test_acc = evaluate(model, test_loader, device)
        se = total_se / total
        _append_epoch(log, train_loss, train_acc, test_acc)
        log['state_pred_error'].append(se)
        if _is_report_epoch(epoch):
            print(f" [SB] Epoch {epoch}: loss={train_loss:.4f}, train={train_acc:.4f}, "
                  f"test={test_acc:.4f}, state_err={se:.4f}")
    return log, state_pred


# =============================================================================
# Credit Bridge
# =============================================================================

def train_credit_bridge(model, train_loader, test_loader, device, args):
    """
    Credit Bridge training.

    Learn a value V_phi(h_l, t_l, s) that estimates the per-sample terminal CE
    loss, and use a_l = grad_{h_l} V_phi as block credit (t_l = l / L,
    s = e_T). The value net is trained on every minibatch with three terms:

      * Terminal boundary: V(h_L, 1, s) ≈ per-sample CE(logits, y).
      * Bridge consistency: V(h_l, t_l, s) ≈ a soft-min over K noisy EMA values
        at the next hidden state,
            V_target = -lam * (logsumexp_k(-V_ema(h_{l+1} + sigma * eps_k) / lam) - log K).
      * Terminal gradient matching, weighted by term_grad_weight:
            grad_{h_L} V ≈ grad_{h_L} CE(out_head(out_ln(h_L)), y).
        This gradient is local to the output layer; it is NOT hidden BP.

    Schedule:
      * For the first warmup_epochs = max(1, epochs // 5) epochs, blocks use
        DFA credits from fixed random matrices while the value net trains.
      * Afterwards the RMS-normalized CB and DFA credits are blended linearly,
        reaching pure CB credit warmup_epochs epochs after warmup ends.

    Returns:
        (log, value_net, value_net_ema). The log also contains a per-epoch
        'value_loss'.
    """
    d = model.d_hidden
    num_classes = args.num_classes
    L = model.num_blocks
    penalty_lam = getattr(args, 'penalty_lam', 0.0)
    warmup_epochs = max(1, args.epochs // 5)  # 20% warmup

    value_net = ValueNet(
        d_hidden=d, s_dim=num_classes, time_embed_dim=32, hidden_dim=256, num_layers=3
    ).to(device)
    value_net_ema = create_ema_model(value_net)

    # DFA fallback matrices used during warmup and blending.
    Bs_fallback = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)]

    block_opts, embed_opt, head_opt, all_schedulers = _make_local_optimizers(model, args)
    value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb)

    lam = args.lam
    K_samples = args.K
    sigma_bridge = args.sigma_bridge
    ema_momentum = args.ema_momentum
    term_grad_weight = args.term_grad_weight

    log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'value_loss': []}

    print(f" [CB] Warmup phase: {warmup_epochs} epochs (DFA fallback + value net training)")

    for epoch in range(1, args.epochs + 1):
        model.train()
        value_net.train()
        total_loss, correct, total = 0, 0, 0
        total_vloss = 0

        # Blend factor: 0 during warmup, then ramps linearly to 1.
        if epoch <= warmup_epochs:
            credit_blend = 0.0
        else:
            credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))

        for x, y in train_loader:
            x, y = _prepare_batch(x, y, device)
            y = _maybe_random_targets(y, args)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)
                e_T = _output_error(logits, y)
                true_loss = F.cross_entropy(logits, y, reduction='none')
            s = e_T
            hL_det = hiddens[-1].detach()
            layer_times = [torch.full((batch,), l / L, device=device) for l in range(L + 1)]

            # ---- Train value net (always) ----
            t_L = torch.ones(batch, device=device)
            V_terminal = value_net(hL_det, t_L, s)
            loss_term = ((V_terminal - true_loss) ** 2).mean()

            # Terminal gradient matching. grad_V_L keeps its graph
            # (create_graph=True) so this term trains value_net; the CE target
            # gradient is detached.
            loss_tgrad = torch.tensor(0.0, device=device)
            if term_grad_weight > 0:
                hL_req = hL_det.clone().requires_grad_(True)
                V_at_L = value_net(hL_req, t_L, s)
                grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]

                hL_req2 = hL_det.clone().requires_grad_(True)
                logits_tgt = model.out_head(model.out_ln(hL_req2))
                ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum')
                a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach()
                loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()

            # Bridge consistency against a soft-min of K noisy EMA evaluations at h_{l+1}.
            loss_bridge = 0.0
            for l in range(L):
                V_l = value_net(hiddens[l].detach(), layer_times[l], s)
                with torch.no_grad():
                    h_next = hiddens[l + 1].detach()
                    log_terms = [
                        -value_net_ema(h_next + sigma_bridge * torch.randn_like(h_next),
                                       layer_times[l + 1], s) / lam
                        for _ in range(K_samples)
                    ]
                    log_stack = torch.stack(log_terms, dim=-1)
                    V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K_samples))
                loss_bridge = loss_bridge + ((V_l - V_target) ** 2).mean()
            loss_bridge = loss_bridge / L

            value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad
            value_opt.zero_grad()
            value_loss.backward()
            torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
            value_opt.step()
            update_ema(value_net, value_net_ema, ema_momentum)
            total_vloss += value_loss.item() * batch

            # ---- Compute credits ----
            # Credit-bridge credits: a_l = grad_{h_l} V_phi(h_l, t_l, s).
            cb_credits = []
            for l in range(L):
                h_l_leaf = hiddens[l].detach().requires_grad_(True)
                V_l = value_net(h_l_leaf, layer_times[l], s)
                cb_credits.append(torch.autograd.grad(V_l.sum(), h_l_leaf, create_graph=False)[0].detach())

            # DFA fallback credits.
            dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)]

            # Blend; normalize both sources first so neither dominates by scale.
            if credit_blend >= 1.0:
                credits = cb_credits
            elif credit_blend <= 0.0:
                credits = dfa_credits
            else:
                credits = [credit_blend * _rms_normalize(cb) + (1 - credit_blend) * _rms_normalize(dfa)
                           for cb, dfa in zip(cb_credits, dfa_credits)]

            # ---- Update output head ----
            _update_head(model, head_opt, hL_det, y)

            # ---- Update blocks ----
            for l in range(L):
                _update_block(model.blocks[l], block_opts[l], hiddens[l].detach(), credits[l], penalty_lam)

            # ---- Update embedding ----
            _update_embedding(model, embed_opt, x, credits[0])

            total_loss += loss_val.item() * batch
            correct += (logits.argmax(1) == y).sum().item()
            total += batch

        for sch in all_schedulers:
            sch.step()

        train_loss = total_loss / total
        train_acc = correct / total
        test_acc = evaluate(model, test_loader, device)
        vloss = total_vloss / total
        _append_epoch(log, train_loss, train_acc, test_acc)
        log['value_loss'].append(vloss)
        if _is_report_epoch(epoch):
            phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}"
            print(f" [CB] Epoch {epoch} ({phase}): loss={train_loss:.4f}, train={train_acc:.4f}, "
                  f"test={test_acc:.4f}, vloss={vloss:.6f}")
    return log, value_net, value_net_ema


# =============================================================================
# Diagnostics
# =============================================================================

def _make_suffix_loss_fn(model, start_l, y):
    """Return h -> per-sample CE after running blocks start_l..L-1 and the head on h.

    The returned function runs under ``torch.no_grad()`` and applies the
    residual update h <- h + block(h) for each remaining block.
    """
    L = model.num_blocks

    def fwd_fn(h):
        with torch.no_grad():
            curr = h
            for i in range(start_l, L):
                curr = curr + model.blocks[i](curr)
            out = model.out_head(model.out_ln(curr))
            return F.cross_entropy(out, y, reduction='none')
    return fwd_fn


def compute_diagnostics(model, method_name, test_loader, device, args,
                        value_net=None, state_predictor=None, dfa_Bs=None):
    """Compute credit-quality diagnostics for a trained model on the first test batch.

    For every block index l, the method's credit a_l (as used in training) is
    compared with the true BP gradient of the mean CE at h_l. The comparisons
    are batch cosine similarity, perturbation correlation and nudging tests
    (eta in 0.001, 0.003, 0.01), all against the loss of the remaining blocks
    plus the head.

    Args:
        method_name: one of 'bp', 'dfa', 'fa', 'state_bridge', 'credit_bridge'.
        value_net: the trained value net (required for 'credit_bridge').
        state_predictor: the trained state predictor (required for 'state_bridge').
        dfa_Bs: feedback matrices, d×C for 'dfa' or d×d for 'fa'.

    Returns:
        A dict containing:
          * 'bp_cosine', 'perturbation_rho', and 'nudging' (one list per eta),
            each with one entry per block;
          * 'hidden_norms_per_layer' and 'bp_grad_norms_per_layer' (medians
            over the batch, L+1 entries each).

    Raises:
        ValueError: for an unknown ``method_name`` (when the model has blocks).
    """
    model.eval()
    if value_net is not None:
        value_net.eval()
    if state_predictor is not None:
        state_predictor.eval()

    L = model.num_blocks

    # One batch for diagnostics.
    x, y = next(iter(test_loader))
    x, y = _prepare_batch(x, y, device)
    batch = x.size(0)

    # True BP gradients of the mean CE at every hidden state h_0..h_L.
    logits_bp, hiddens_bp = model(x, return_hidden=True)
    for l in range(L + 1):
        hiddens_bp[l].retain_grad()
    loss_bp = F.cross_entropy(logits_bp, y)
    loss_bp.backward()
    bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)}
    # backward() also accumulated parameter gradients; clear them so the
    # returned model carries no stale .grad.
    model.zero_grad(set_to_none=True)

    # Forward again without grad for clean hidden states.
    with torch.no_grad():
        logits, hiddens = model(x, return_hidden=True)
        e_T = _output_error(logits, y)
    s = e_T

    # Per-layer hidden norms and BP grad norms (per-sample L2, median over the batch).
    hidden_norms_per_layer = [float(hiddens[l].detach().norm(dim=-1).median().item()) for l in range(L + 1)]
    bp_grad_norms_per_layer = [float(bp_grads[l].norm(dim=-1).median().item()) for l in range(L + 1)]

    results = {
        'bp_cosine': [],
        'perturbation_rho': [],
        'nudging': {'0.001': [], '0.003': [], '0.01': []},
        'hidden_norms_per_layer': hidden_norms_per_layer,
        'bp_grad_norms_per_layer': bp_grad_norms_per_layer,
    }

    # FA credits, reproducing the training recursion: block L-1 gets the exact
    # gradient at h_L, then credit[l] = credit[l+1] @ B_{l+1}.
    fa_credits = None
    if method_name == 'fa' and dfa_Bs is not None:
        hL_leaf = hiddens[L].detach().requires_grad_(True)
        logits_fa = model.out_head(model.out_ln(hL_leaf))
        loss_fa = F.cross_entropy(logits_fa, y, reduction='sum')
        fa_credits = [None] * L
        fa_credits[L - 1] = torch.autograd.grad(loss_fa, hL_leaf)[0].detach()
        for ll in range(L - 2, -1, -1):
            fa_credits[ll] = (fa_credits[ll + 1] @ dfa_Bs[ll + 1]).detach()

    for l in range(L):
        h_l = hiddens[l].detach()
        t_l = torch.full((batch,), l / L, device=device)

        # Credit used for block l by the method under test.
        if method_name == 'bp':
            a_l = bp_grads[l]
        elif method_name == 'dfa':
            a_l = (e_T @ dfa_Bs[l].T).detach()
        elif method_name == 'fa':
            a_l = fa_credits[l]
        elif method_name == 'state_bridge':
            h_l_req = h_l.clone().requires_grad_(True)
            pred_hL = state_predictor(h_l_req, t_l, s)
            pred_logits = model.out_head(model.out_ln(pred_hL))
            pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
            a_l = torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach()
        elif method_name == 'credit_bridge':
            h_l_req = h_l.clone().requires_grad_(True)
            V_l = value_net(h_l_req, t_l, s)
            a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach()
        else:
            raise ValueError(f"Unknown method: {method_name}")

        results['bp_cosine'].append(cosine_similarity_batch(a_l, bp_grads[l]))

        fwd_fn = _make_suffix_loss_fn(model, l, y)
        results['perturbation_rho'].append(perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16))
        for eta in (0.001, 0.003, 0.01):
            results['nudging'][str(eta)].append(nudging_test(h_l, a_l, fwd_fn, eta=eta))

    return results


# =============================================================================
# Main
# =============================================================================

# Methods are always run in this order, whatever order --methods lists them in.
_METHOD_ORDER = ('bp', 'dfa', 'fa', 'state_bridge', 'credit_bridge')
_METHOD_TITLES = {
    'bp': 'BP',
    'dfa': 'DFA',
    'fa': 'FA',  # vanilla Feedback Alignment (Lillicrap 2016)
    'state_bridge': 'State Bridge',
    'credit_bridge': 'Credit Bridge',
}


def _seed_everything(seed):
    """Seed the torch (CPU and all CUDA devices) and numpy global RNGs."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)


def _to_jsonable(obj):
    """Recursively convert a results structure into JSON-serializable Python types.

    Conversions:
      * dict keys become strings;
      * numpy floats and integers become float, and ``np.bool_`` becomes bool;
      * tuples become lists;
      * numpy arrays and torch tensors become nested lists.
    """
    if isinstance(obj, dict):
        return {str(k): _to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(v) for v in obj]
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, (np.floating, np.integer)):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    return obj


def _save_results(all_results, args):
    """Write all results plus the run config to ``output_dir`` and return the path."""
    save_data = _to_jsonable(all_results)
    save_data['config'] = _to_jsonable(vars(args))
    out_path = os.path.join(args.output_dir, f'results_{args.dataset}.json')
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(save_data, f, indent=2)
    return out_path


def _train_method(method, seed, input_dim, num_classes, train_loader, test_loader, device, args):
    """Train one method from a freshly seeded model, then compute diagnostics and drift.

    Every method, BP included, is re-seeded immediately before model
    construction, so all methods start from identical initial weights.

    Returns:
        {'log': ..., 'diagnostics': ..., 'drift': ...}
    """
    print(f"\n--- {_METHOD_TITLES[method]} ---")
    _seed_everything(seed)
    model = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device)
    init_params = {n: p.clone().detach() for n, p in model.named_parameters()}

    if method == 'bp':
        log = train_bp(model, train_loader, test_loader, device, args)
        diag = compute_diagnostics(model, 'bp', test_loader, device, args)
    elif method in ('dfa', 'fa'):
        trainer = train_dfa if method == 'dfa' else train_fa
        log, Bs = trainer(model, train_loader, test_loader, device, args)
        diag = compute_diagnostics(model, method, test_loader, device, args, dfa_Bs=Bs)
    elif method == 'state_bridge':
        log, state_pred = train_state_bridge(model, train_loader, test_loader, device, args)
        diag = compute_diagnostics(model, 'state_bridge', test_loader, device, args,
                                   state_predictor=state_pred)
    elif method == 'credit_bridge':
        log, vnet, _vnet_ema = train_credit_bridge(model, train_loader, test_loader, device, args)
        diag = compute_diagnostics(model, 'credit_bridge', test_loader, device, args, value_net=vnet)
    else:
        raise ValueError(f"Unknown method: {method}")

    drift = feature_drift(init_params, {n: p.detach() for n, p in model.named_parameters()})
    print(f" Final test acc: {log['test_acc'][-1]:.4f}")
    return {'log': log, 'diagnostics': diag, 'drift': drift}


def run_experiment(args):
    """Run every requested method for every seed, then save all results as JSON.

    Results are checkpointed to disk after each seed, so an interrupted run
    keeps the seeds that finished.

    Returns:
        {seed: {method: {'log', 'diagnostics', 'drift'}}}
    """
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)

    methods_to_run = getattr(args, 'methods', ['bp', 'dfa', 'state_bridge', 'credit_bridge'])
    num_workers = getattr(args, 'num_workers', 4)

    all_results = {}
    for seed in args.seeds:
        print(f"\n{'='*60}")
        print(f"Seed {seed}")
        print(f"{'='*60}")
        _seed_everything(seed)

        train_loader, test_loader, input_dim, num_classes = get_data(
            args.dataset, args.batch_size, num_workers=num_workers)
        args.num_classes = num_classes

        seed_results = {}
        for method in _METHOD_ORDER:
            if method in methods_to_run:
                seed_results[method] = _train_method(
                    method, seed, input_dim, num_classes, train_loader, test_loader, device, args)

        all_results[seed] = seed_results
        _save_results(all_results, args)  # checkpoint after every seed

    out_path = _save_results(all_results, args)
    print(f"\nAll results saved to {out_path}")
    return all_results


def main():
    """Parse command-line arguments and run the experiment."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='cifar10')
    parser.add_argument('--d_hidden', type=int, default=512)
    parser.add_argument('--num_blocks', type=int, default=12)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr_fb', type=float, default=1e-3)
    parser.add_argument('--wd', type=float, default=0.01)
    parser.add_argument('--lam', type=float, default=0.1)
    parser.add_argument('--K', type=int, default=4)
    parser.add_argument('--sigma_bridge', type=float, default=0.05)
    parser.add_argument('--ema_momentum', type=float, default=0.995)
    parser.add_argument('--term_grad_weight', type=float, default=1.0)
    parser.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456])
    parser.add_argument('--gpu', type=int, default=1)
    parser.add_argument('--output_dir', type=str, default='results/cifar10')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='DataLoader worker processes for both train and test loaders.')
    parser.add_argument('--methods', type=str, nargs='+',
                        default=['bp', 'dfa', 'fa', 'state_bridge', 'credit_bridge'],
                        choices=list(_METHOD_ORDER),
                        help='Subset of methods to run. fa = vanilla Feedback Alignment (Lillicrap 2016).')
    parser.add_argument('--random_targets', action='store_true',
                        help='Replace each minibatch label with i.i.d. random class targets (Mode 1 data-agnostic test).')
    parser.add_argument('--penalty_lam', type=float, default=0.0,
                        help='Per-block residual-branch penalty strength: add penalty_lam * mean(||f_l(h_l)||^2) '
                             'to each block local loss for DFA/FA/SB/CB. Codex round 38 Mode 2 cross-method test.')
    args = parser.parse_args()
    run_experiment(args)


if __name__ == '__main__':
    main()