""" Train BP/FA/DFA on a specified architecture and compute protocol diagnostics. Usage: python reproduce/train_methods.py --arch resmlp --methods bp fa dfa \ --seeds 42 123 456 --epochs 100 --gpu 0 --output_dir results/main_audit Architectures: resmlp (d=256 L=4), resmlp_d512_L2, vit, resnet """ import os, sys, json, argparse import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader import torchvision, torchvision.transforms as transforms sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.residual_mlp import ResidualMLP from models.vit_mini import ViTMini from models.small_resnet import SmallResNet from metrics.credit_metrics import cosine_similarity_batch, nudging_test # ─── Data ──────────────────────────────────────────────────────────────── def get_data(dataset='cifar10', batch_size=128): if dataset == 'cifar10': mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616) Dataset = torchvision.datasets.CIFAR10 num_classes = 10 else: mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761) Dataset = torchvision.datasets.CIFAR100 num_classes = 100 tv_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean, std)]) tv_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) tr = Dataset('./data', True, download=True, transform=tv_train) te = Dataset('./data', False, download=True, transform=tv_test) return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2), num_classes) def evaluate(model, loader, device, is_conv=False): model.eval() c = n = 0 with torch.no_grad(): for x, y in loader: x, y = x.to(device), y.to(device) if not is_conv: x = x.view(x.size(0), -1) c += (model(x).argmax(-1) == y).sum().item() n += x.size(0) return c / n # ─── Model construction ───────────────────────────────────────────────── def make_model(arch, num_classes, device): if arch == 'resmlp': return ResidualMLP(3072, 256, num_classes, 4).to(device), False elif arch == 'resmlp_d512_L2': return ResidualMLP(3072, 512, num_classes, 2).to(device), False elif arch == 'vit': return ViTMini(d_model=128, n_heads=4, num_blocks=4, num_classes=num_classes).to(device), True elif arch == 'resnet': return SmallResNet(64, num_classes, 4).to(device), True else: raise ValueError(f"Unknown arch: {arch}") # ─── Training functions ───────────────────────────────────────────────── def train_bp(model, train_loader, test_loader, device, epochs, is_conv): opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01) sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) log = {'train_loss': [], 'train_acc': [], 'test_acc': []} for ep in range(1, epochs + 1): model.train() tl, tc, tn = 0, 0, 0 for x, y in train_loader: x, y = x.to(device), y.to(device) if not is_conv: x = x.view(x.size(0), -1) logits = model(x) loss = F.cross_entropy(logits, y) opt.zero_grad(); loss.backward(); opt.step() tl += loss.item() * x.size(0); tc += (logits.argmax(1) == y).sum().item(); tn += x.size(0) sch.step() log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn) log['test_acc'].append(evaluate(model, test_loader, device, is_conv)) if ep % 10 == 0 or ep == epochs: print(f" [BP] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True) return log def _get_embed_head_params(model, is_conv): 
"""Get embed and head parameter groups.""" if is_conv and hasattr(model, 'stem_conv'): embed_params = list(model.stem_conv.parameters()) + list(model.stem_bn.parameters()) head_params = list(model.out_head.parameters()) elif hasattr(model, 'patch_embed'): # ViT embed_params = list(model.patch_embed.parameters()) + [model.cls_token, model.pos_embed] head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters()) else: # ResMLP embed_params = list(model.embed.parameters()) head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters()) return embed_params, head_params def _pool_hidden(h): if h.dim() == 4: return F.adaptive_avg_pool2d(h, 1).flatten(1) if h.dim() == 3: return h[:, 0] # cls token return h def _get_head_logits(model, h_pool): if hasattr(model, 'out_ln'): return model.out_head(model.out_ln(h_pool)) return model.out_head(h_pool) def _block_residual(model, block, h_l, is_conv): """Compute block residual f_l = block(h_l) - h_l for blocks with internal skip.""" out = block(h_l) if is_conv or hasattr(block, 'attn'): # ResNet/ViT blocks include skip internally return out - h_l return out # ResMLP blocks return f_l only def train_dfa(model, train_loader, test_loader, device, epochs, is_conv, num_classes): d = model.d_hidden if hasattr(model, 'd_hidden') else model.d_model L = model.num_blocks C = num_classes Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks] embed_params, head_params = _get_embed_head_params(model, is_conv) embed_opt = optim.AdamW(embed_params, lr=1e-3, weight_decay=0.01) head_opt = optim.AdamW(head_params, lr=1e-3, weight_decay=0.01) all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)] log = {'train_loss': [], 'train_acc': [], 'test_acc': []} for ep in range(1, epochs + 1): model.train() tl, tc, tn = 0, 0, 0 for x, y in train_loader: x, y = x.to(device), y.to(device) if not is_conv: x = x.view(x.size(0), -1) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) loss_val = F.cross_entropy(logits, y) e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 h_pool = _pool_hidden(hiddens[-1].detach()) head_opt.zero_grad() F.cross_entropy(_get_head_logits(model, h_pool), y).backward() head_opt.step() for l in range(L): h_l = hiddens[l].detach() a = (e_T @ Bs[l].T).detach() rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 a_norm = a / rms f_l = _block_residual(model, model.blocks[l], h_l, is_conv) if f_l.dim() > 2: a_b = a_norm.unsqueeze(-1).unsqueeze(-1).expand_as(f_l) local_loss = (f_l * a_b).sum(dim=1).mean() else: local_loss = (f_l * a_norm).sum(-1).mean() block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step() # Embed a0 = (e_T @ Bs[0].T).detach() rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 if is_conv: h0 = model.embed(x) if hasattr(model, 'embed') else model.stem(x) else: h0 = model.embed(x) a0_n = a0 / rms0 if h0.dim() > 2: a0_b = a0_n.unsqueeze(-1).unsqueeze(-1).expand_as(h0) embed_loss = (h0 * a0_b).sum(dim=1).mean() else: embed_loss = (h0 * a0_n).sum(-1).mean() embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step() for s in all_sch: s.step() tl += loss_val.item() * batch; tc += (logits.argmax(1) == y).sum().item(); tn += batch log['train_loss'].append(tl / tn); log['train_acc'].append(tc / 
def train_fa(model, train_loader, test_loader, device, epochs, is_conv, num_classes):
    d = model.d_hidden if hasattr(model, 'd_hidden') else model.d_model
    L = model.num_blocks
    # Fixed random (d x d) feedback matrices used to propagate credit top-down.
    Bs = [torch.randn(d, d, device=device) / np.sqrt(d) for _ in range(L)]
    block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks]
    embed_params, head_params = _get_embed_head_params(model, is_conv)
    embed_opt = optim.AdamW(embed_params, lr=1e-3, weight_decay=0.01)
    head_opt = optim.AdamW(head_params, lr=1e-3, weight_decay=0.01)
    all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
              [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
               optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
    log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
    for ep in range(1, epochs + 1):
        model.train()
        tl, tc, tn = 0, 0, 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            if not is_conv:
                x = x.view(x.size(0), -1)
            batch = x.size(0)
            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)
            # Head: take the gradient at its input before stepping, then step.
            h_pool = _pool_hidden(hiddens[-1].detach()).requires_grad_(True)
            logits_out = _get_head_logits(model, h_pool)
            loss_out = F.cross_entropy(logits_out, y)
            head_opt.zero_grad()
            loss_out.backward()
            a_credit = h_pool.grad.detach()
            head_opt.step()
            # Blocks, top-down: local alignment loss, then propagate the credit through B_l.
            for l in range(L - 1, -1, -1):
                h_l = hiddens[l].detach()
                rms = (a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
                a_norm = a_credit / rms
                f_l = _block_residual(model, model.blocks[l], h_l, is_conv)
                if f_l.dim() == 4:    # conv maps: per-channel credit, broadcast over H, W
                    local_loss = (f_l * a_norm.unsqueeze(-1).unsqueeze(-1)).sum(dim=1).mean()
                elif f_l.dim() == 3:  # token sequences: per-feature credit, broadcast over tokens
                    local_loss = (f_l * a_norm.unsqueeze(1)).sum(-1).mean()
                else:
                    local_loss = (f_l * a_norm).sum(-1).mean()
                block_opts[l].zero_grad()
                local_loss.backward()
                block_opts[l].step()
                a_credit = (a_credit @ Bs[l]).detach()
            # Embed: same local loss with the credit that reached the bottom.
            rms0 = (a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
            a0_n = a_credit / rms0
            if hasattr(model, 'embed'):
                h0 = model.embed(x)
            elif hasattr(model, 'patch_embed'):  # ViT
                h0 = model.patch_embed(x)
            else:  # ResNet stem
                h0 = model.stem(x)
            if h0.dim() == 4:
                embed_loss = (h0 * a0_n.unsqueeze(-1).unsqueeze(-1)).sum(dim=1).mean()
            elif h0.dim() == 3:
                embed_loss = (h0 * a0_n.unsqueeze(1)).sum(-1).mean()
            else:
                embed_loss = (h0 * a0_n).sum(-1).mean()
            embed_opt.zero_grad()
            embed_loss.backward()
            embed_opt.step()
            tl += loss_val.item() * batch
            tc += (logits.argmax(1) == y).sum().item()
            tn += batch
        for s in all_sch:  # step cosine schedules once per epoch (T_max=epochs)
            s.step()
        log['train_loss'].append(tl / tn)
        log['train_acc'].append(tc / tn)
        log['test_acc'].append(evaluate(model, test_loader, device, is_conv))
        if ep % 10 == 0 or ep == epochs:
            print(f"  [FA] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True)
    return log, Bs


# ─── Diagnostics ─────────────────────────────────────────────────────────────
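# compute_diagnostics() contrasts each method's credit signal with true backprop
# on a fixed eval batch: per layer, the cosine between the method's credit a_l
# and the BP gradient dLoss/dh_l, the median BP gradient norm, and the median
# hidden-state norm. For BP the cosine is 1.0 by construction.
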
def compute_diagnostics(model, x_eval, y_eval, device, method_name,
                        dfa_Bs=None, fa_Bs=None, is_conv=False):
    """Compute per-layer credit cosine vs. BP, ||g_l||, and ||h_l|| on the eval batch."""
    model.eval()
    L = model.num_blocks
    with torch.no_grad():
        logits, hiddens = model(x_eval, return_hidden=True)
        h_norms = [float(_pool_hidden(h).norm(dim=-1).median().item()) for h in hiddens]
    # BP grads: re-run the forward pass block by block so we can differentiate w.r.t. every h_l.
    if hasattr(model, 'embed'):
        h0 = model.embed(x_eval)
    elif hasattr(model, 'patch_embed'):  # ViT
        h0 = model.patch_embed(x_eval)
    else:  # ResNet stem
        h0 = model.stem(x_eval)
    hs = [h0.clone().requires_grad_(True)]
    for block in model.blocks:
        # h_{l+1} = h_l + f_l; _block_residual returns f_l for all three block families.
        hs.append(hs[-1] + _block_residual(model, block, hs[-1], is_conv))
    h_final = _pool_hidden(hs[-1])
    if hasattr(model, 'out_ln'):
        h_final = model.out_ln(h_final)
    out_logits = model.out_head(h_final)
    loss = F.cross_entropy(out_logits, y_eval)
    grads = torch.autograd.grad(loss, hs)
    g_norms = [float(_pool_hidden(g).norm(dim=-1).median().item()) for g in grads]
    # Per-layer cosine between the method's credit signal and the BP gradient.
    with torch.no_grad():
        e_T = out_logits.softmax(-1)
        e_T[torch.arange(x_eval.size(0)), y_eval] -= 1
    bp_cosine = []
    if method_name == 'bp':
        bp_cosine = [1.0] * L
    elif method_name == 'dfa' and dfa_Bs is not None:
        for l in range(L):
            a = (e_T @ dfa_Bs[l].T).detach()
            g_pool = _pool_hidden(grads[l]).detach()
            bp_cosine.append(cosine_similarity_batch(a, g_pool))
    elif method_name == 'fa' and fa_Bs is not None:
        hL_pool = _pool_hidden(hiddens[-1].detach()).requires_grad_(True)
        logits_fa = _get_head_logits(model, hL_pool)
        loss_fa = F.cross_entropy(logits_fa, y_eval)
        a_credit = torch.autograd.grad(loss_fa, hL_pool)[0].detach()
        for l in range(L - 1, -1, -1):
            g_pool = _pool_hidden(grads[l]).detach()
            bp_cosine.insert(0, cosine_similarity_batch(a_credit, g_pool))
            a_credit = (a_credit @ fa_Bs[l]).detach()
    model.train()
    return {
        'bp_cosine': bp_cosine,
        'bp_grad_norms_per_layer': g_norms,
        'hidden_norms_per_layer': h_norms,
    }


# ─── Main ────────────────────────────────────────────────────────────────────

def main():
    p = argparse.ArgumentParser()
    p.add_argument('--arch', type=str, default='resmlp',
                   choices=['resmlp', 'resmlp_d512_L2', 'vit', 'resnet'])
    p.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'])
    p.add_argument('--methods', nargs='+', default=['bp', 'fa', 'dfa'])
    p.add_argument('--seeds', nargs='+', type=int, default=[42, 123, 456])
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--output_dir', type=str, default='results/reproduce')
    p.add_argument('--penalty_lam', type=float, default=0.0)
    args = p.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader, num_classes = get_data(args.dataset, 128)

    # Fixed eval buffer of 128 test examples for the diagnostics.
    xs, ys = [], []
    for x, y in test_loader:
        xs.append(x)
        ys.append(y)
        if sum(xb.size(0) for xb in xs) >= 128:
            break
    x_eval_raw = torch.cat(xs)[:128].to(device)
    y_eval = torch.cat(ys)[:128].to(device)

    results = {}
    for seed in args.seeds:
        print(f"\n{'='*60}\nSeed {seed}\n{'='*60}", flush=True)
        results[str(seed)] = {}
        for method in args.methods:
            print(f"\n--- {method.upper()} ---", flush=True)
            torch.manual_seed(seed)
            np.random.seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
            model, is_conv = make_model(args.arch, num_classes, device)
            x_eval = x_eval_raw if is_conv else x_eval_raw.view(x_eval_raw.size(0), -1)
            if method == 'bp':
                log = train_bp(model, train_loader, test_loader, device, args.epochs, is_conv)
                diag = compute_diagnostics(model, x_eval, y_eval, device, 'bp', is_conv=is_conv)
                results[str(seed)]['bp'] = {'log': log, 'diagnostics': diag}
            elif method == 'dfa':
                log, Bs = train_dfa(model, train_loader, test_loader, device,
                                    args.epochs, is_conv, num_classes)
                diag = compute_diagnostics(model, x_eval, y_eval, device, 'dfa',
                                           dfa_Bs=Bs, is_conv=is_conv)
                results[str(seed)]['dfa'] = {'log': log, 'diagnostics': diag}
            elif method == 'fa':
                log, Bs = train_fa(model, train_loader, test_loader, device,
                                   args.epochs, is_conv, num_classes)
                diag = compute_diagnostics(model, x_eval, y_eval, device, 'fa',
                                           fa_Bs=Bs, is_conv=is_conv)
                results[str(seed)]['fa'] = {'log': log, 'diagnostics': diag}

    results['config'] = vars(args)
    out_path = os.path.join(args.output_dir, f'results_{args.dataset}.json')
    with open(out_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved: {out_path}", flush=True)


if __name__ == '__main__':
    main()
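# Reading the output (illustrative note). The JSON written above is keyed by
# seed (as a string), then method, with 'log' (per-epoch train_loss / train_acc
# / test_acc) and 'diagnostics' (bp_cosine, bp_grad_norms_per_layer,
# hidden_norms_per_layer), plus a top-level 'config' entry. A minimal sketch for
# summarizing final test accuracy (the path below is a placeholder for whatever
# --output_dir and --dataset were used):
#
#   import json
#   with open('results/main_audit/results_cifar10.json') as f:
#       res = json.load(f)
#   for seed, by_method in res.items():
#       if seed == 'config':
#           continue
#       for method, entry in by_method.items():
#           print(seed, method, entry['log']['test_acc'][-1])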