From b480d0cdc21f944e4adccf6e81cc939b0450c5e9 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Mon, 4 May 2026 19:50:45 -0500 Subject: Initial submission code: FA evaluation protocol + reproduction scripts Reference implementation of the three-diagnostic FA evaluation protocol (scale stability, reference validity, depth utility) from the NeurIPS 2026 E&D track paper. Includes models, metrics, and full reproduction pipeline. Co-Authored-By: Claude Opus 4.6 (1M context) --- reproduce/train_methods.py | 376 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 376 insertions(+) create mode 100644 reproduce/train_methods.py (limited to 'reproduce/train_methods.py') diff --git a/reproduce/train_methods.py b/reproduce/train_methods.py new file mode 100644 index 0000000..c430b90 --- /dev/null +++ b/reproduce/train_methods.py @@ -0,0 +1,376 @@ +""" +Train BP/FA/DFA on a specified architecture and compute protocol diagnostics. + +Usage: + python reproduce/train_methods.py --arch resmlp --methods bp fa dfa \ + --seeds 42 123 456 --epochs 100 --gpu 0 --output_dir results/main_audit + +Architectures: resmlp (d=256 L=4), resmlp_d512_L2, vit, resnet +""" +import os, sys, json, argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision, torchvision.transforms as transforms + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.residual_mlp import ResidualMLP +from models.vit_mini import ViTMini +from models.small_resnet import SmallResNet +from metrics.credit_metrics import cosine_similarity_batch, nudging_test + + +# ─── Data ──────────────────────────────────────────────────────────────── + +def get_data(dataset='cifar10', batch_size=128): + if dataset == 'cifar10': + mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616) + Dataset = torchvision.datasets.CIFAR10 + num_classes = 10 + else: + mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761) + Dataset = torchvision.datasets.CIFAR100 + num_classes = 100 + tv_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), + transforms.ToTensor(), transforms.Normalize(mean, std)]) + tv_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) + tr = Dataset('./data', True, download=True, transform=tv_train) + te = Dataset('./data', False, download=True, transform=tv_test) + return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), + DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2), + num_classes) + + +def evaluate(model, loader, device, is_conv=False): + model.eval() + c = n = 0 + with torch.no_grad(): + for x, y in loader: + x, y = x.to(device), y.to(device) + if not is_conv: + x = x.view(x.size(0), -1) + c += (model(x).argmax(-1) == y).sum().item() + n += x.size(0) + return c / n + + +# ─── Model construction ───────────────────────────────────────────────── + +def make_model(arch, num_classes, device): + if arch == 'resmlp': + return ResidualMLP(3072, 256, num_classes, 4).to(device), False + elif arch == 'resmlp_d512_L2': + return ResidualMLP(3072, 512, num_classes, 2).to(device), False + elif arch == 'vit': + return ViTMini(d_model=128, n_heads=4, num_blocks=4, num_classes=num_classes).to(device), True + elif arch == 'resnet': + return SmallResNet(64, num_classes, 4).to(device), True + else: + raise ValueError(f"Unknown arch: {arch}") + + +# ─── Training functions ───────────────────────────────────────────────── + +def train_bp(model, train_loader, test_loader, device, epochs, is_conv): + opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + for ep in range(1, epochs + 1): + model.train() + tl, tc, tn = 0, 0, 0 + for x, y in train_loader: + x, y = x.to(device), y.to(device) + if not is_conv: x = x.view(x.size(0), -1) + logits = model(x) + loss = F.cross_entropy(logits, y) + opt.zero_grad(); loss.backward(); opt.step() + tl += loss.item() * x.size(0); tc += (logits.argmax(1) == y).sum().item(); tn += x.size(0) + sch.step() + log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn) + log['test_acc'].append(evaluate(model, test_loader, device, is_conv)) + if ep % 10 == 0 or ep == epochs: + print(f" [BP] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True) + return log + + +def _get_embed_head_params(model, is_conv): + """Get embed and head parameter groups.""" + if is_conv and hasattr(model, 'stem_conv'): + embed_params = list(model.stem_conv.parameters()) + list(model.stem_bn.parameters()) + head_params = list(model.out_head.parameters()) + elif hasattr(model, 'patch_embed'): # ViT + embed_params = list(model.patch_embed.parameters()) + [model.cls_token, model.pos_embed] + head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters()) + else: # ResMLP + embed_params = list(model.embed.parameters()) + head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters()) + return embed_params, head_params + + +def _pool_hidden(h): + if h.dim() == 4: return F.adaptive_avg_pool2d(h, 1).flatten(1) + if h.dim() == 3: return h[:, 0] # cls token + return h + + +def _get_head_logits(model, h_pool): + if hasattr(model, 'out_ln'): + return model.out_head(model.out_ln(h_pool)) + return model.out_head(h_pool) + + +def _block_residual(model, block, h_l, is_conv): + """Compute block residual f_l = block(h_l) - h_l for blocks with internal skip.""" + out = block(h_l) + if is_conv or hasattr(block, 'attn'): # ResNet/ViT blocks include skip internally + return out - h_l + return out # ResMLP blocks return f_l only + + +def train_dfa(model, train_loader, test_loader, device, epochs, is_conv, num_classes): + d = model.d_hidden if hasattr(model, 'd_hidden') else model.d_model + L = model.num_blocks + C = num_classes + Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] + block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks] + embed_params, head_params = _get_embed_head_params(model, is_conv) + embed_opt = optim.AdamW(embed_params, lr=1e-3, weight_decay=0.01) + head_opt = optim.AdamW(head_params, lr=1e-3, weight_decay=0.01) + all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)] + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + for ep in range(1, epochs + 1): + model.train() + tl, tc, tn = 0, 0, 0 + for x, y in train_loader: + x, y = x.to(device), y.to(device) + if not is_conv: x = x.view(x.size(0), -1) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 + h_pool = _pool_hidden(hiddens[-1].detach()) + head_opt.zero_grad() + F.cross_entropy(_get_head_logits(model, h_pool), y).backward() + head_opt.step() + for l in range(L): + h_l = hiddens[l].detach() + a = (e_T @ Bs[l].T).detach() + rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + a_norm = a / rms + f_l = _block_residual(model, model.blocks[l], h_l, is_conv) + if f_l.dim() > 2: + a_b = a_norm.unsqueeze(-1).unsqueeze(-1).expand_as(f_l) + local_loss = (f_l * a_b).sum(dim=1).mean() + else: + local_loss = (f_l * a_norm).sum(-1).mean() + block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step() + # Embed + a0 = (e_T @ Bs[0].T).detach() + rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + if is_conv: + h0 = model.embed(x) if hasattr(model, 'embed') else model.stem(x) + else: + h0 = model.embed(x) + a0_n = a0 / rms0 + if h0.dim() > 2: + a0_b = a0_n.unsqueeze(-1).unsqueeze(-1).expand_as(h0) + embed_loss = (h0 * a0_b).sum(dim=1).mean() + else: + embed_loss = (h0 * a0_n).sum(-1).mean() + embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step() + for s in all_sch: s.step() + tl += loss_val.item() * batch; tc += (logits.argmax(1) == y).sum().item(); tn += batch + log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn) + log['test_acc'].append(evaluate(model, test_loader, device, is_conv)) + if ep % 10 == 0 or ep == epochs: + print(f" [DFA] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True) + return log, Bs + + +def train_fa(model, train_loader, test_loader, device, epochs, is_conv, num_classes): + d = model.d_hidden if hasattr(model, 'd_hidden') else model.d_model + L = model.num_blocks + Bs = [torch.randn(d, d, device=device) / np.sqrt(d) for _ in range(L)] + block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks] + embed_params, head_params = _get_embed_head_params(model, is_conv) + embed_opt = optim.AdamW(embed_params, lr=1e-3, weight_decay=0.01) + head_opt = optim.AdamW(head_params, lr=1e-3, weight_decay=0.01) + all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)] + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + for ep in range(1, epochs + 1): + model.train() + tl, tc, tn = 0, 0, 0 + for x, y in train_loader: + x, y = x.to(device), y.to(device) + if not is_conv: x = x.view(x.size(0), -1) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + # Head — grad before step + h_pool = _pool_hidden(hiddens[-1].detach()).requires_grad_(True) + logits_out = _get_head_logits(model, h_pool) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad(); loss_out.backward() + a_credit = h_pool.grad.detach() + head_opt.step() + # Top-down blocks + for l in range(L - 1, -1, -1): + h_l = hiddens[l].detach() + rms = (a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + a_norm = a_credit / rms + f_l = _block_residual(model, model.blocks[l], h_l, is_conv) + if f_l.dim() > 2: + a_b = a_norm.unsqueeze(-1).unsqueeze(-1).expand_as(f_l) + local_loss = (f_l * a_b).sum(dim=1).mean() + else: + local_loss = (f_l * a_norm).sum(-1).mean() + block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step() + a_credit = (a_credit @ Bs[l]).detach() + # Embed + rms0 = (a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + if is_conv: + h0 = model.embed(x) if hasattr(model, 'embed') else model.stem(x) + else: + h0 = model.embed(x) + a0_n = a_credit / rms0 + if h0.dim() > 2: + a0_b = a0_n.unsqueeze(-1).unsqueeze(-1).expand_as(h0) + embed_loss = (h0 * a0_b).sum(dim=1).mean() + else: + embed_loss = (h0 * a0_n).sum(-1).mean() + embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step() + for s in all_sch: s.step() + tl += loss_val.item() * batch; tc += (logits.argmax(1) == y).sum().item(); tn += batch + log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn) + log['test_acc'].append(evaluate(model, test_loader, device, is_conv)) + if ep % 10 == 0 or ep == epochs: + print(f" [FA] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True) + return log, Bs + + +# ─── Diagnostics ───────────────────────────────────────────────────────── + +def compute_diagnostics(model, x_eval, y_eval, device, method_name, dfa_Bs=None, fa_Bs=None, is_conv=False): + """Compute per-layer cosine, ||g_l||, ||h_l|| and nudging.""" + model.eval() + L = model.num_blocks + + with torch.no_grad(): + logits, hiddens = model(x_eval, return_hidden=True) + + h_norms = [float(_pool_hidden(h).norm(dim=-1).median().item()) for h in hiddens] + + # BP grads + h0 = model.embed(x_eval) if hasattr(model, 'embed') else model.stem(x_eval) + hs = [h0.clone().requires_grad_(True)] + for block in model.blocks: + hs.append(block(hs[-1])) + h_final = _pool_hidden(hs[-1]) + if hasattr(model, 'out_ln'): + h_final = model.out_ln(h_final) + out_logits = model.out_head(h_final) + loss = F.cross_entropy(out_logits, y_eval) + grads = torch.autograd.grad(loss, hs) + g_norms = [float(_pool_hidden(g).norm(dim=-1).median().item()) for g in grads] + + # Per-layer cosine + with torch.no_grad(): + e_T = out_logits.softmax(-1) + e_T[torch.arange(x_eval.size(0)), y_eval] -= 1 + + bp_cosine = [] + if method_name == 'bp': + bp_cosine = [1.0] * L + elif method_name == 'dfa' and dfa_Bs is not None: + for l in range(L): + a = (e_T @ dfa_Bs[l].T).detach() + g_pool = _pool_hidden(grads[l]).detach() + bp_cosine.append(cosine_similarity_batch(a, g_pool)) + elif method_name == 'fa' and fa_Bs is not None: + hL_pool = _pool_hidden(hiddens[-1].detach()).requires_grad_(True) + logits_fa = _get_head_logits(model, hL_pool) + loss_fa = F.cross_entropy(logits_fa, y_eval) + a_credit = torch.autograd.grad(loss_fa, hL_pool)[0].detach() + for l in range(L - 1, -1, -1): + g_pool = _pool_hidden(grads[l]).detach() + bp_cosine.insert(0, cosine_similarity_batch(a_credit, g_pool)) + a_credit = (a_credit @ fa_Bs[l]).detach() + + model.train() + return { + 'bp_cosine': bp_cosine, + 'bp_grad_norms_per_layer': g_norms, + 'hidden_norms_per_layer': h_norms, + } + + +# ─── Main ──────────────────────────────────────────────────────────────── + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--arch', type=str, default='resmlp', choices=['resmlp', 'resmlp_d512_L2', 'vit', 'resnet']) + p.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100']) + p.add_argument('--methods', nargs='+', default=['bp', 'fa', 'dfa']) + p.add_argument('--seeds', nargs='+', type=int, default=[42, 123, 456]) + p.add_argument('--epochs', type=int, default=100) + p.add_argument('--gpu', type=int, default=0) + p.add_argument('--output_dir', type=str, default='results/reproduce') + p.add_argument('--penalty_lam', type=float, default=0.0) + args = p.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + train_loader, test_loader, num_classes = get_data(args.dataset, 128) + + # Eval buffer + xs, ys = [], [] + for x, y in test_loader: + xs.append(x); ys.append(y) + if sum(xb.size(0) for xb in xs) >= 128: break + x_eval_raw = torch.cat(xs)[:128].to(device) + y_eval = torch.cat(ys)[:128].to(device) + + results = {} + for seed in args.seeds: + print(f"\n{'='*60}\nSeed {seed}\n{'='*60}", flush=True) + results[str(seed)] = {} + + for method in args.methods: + print(f"\n--- {method.upper()} ---", flush=True) + torch.manual_seed(seed); np.random.seed(seed) + if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) + model, is_conv = make_model(args.arch, num_classes, device) + x_eval = x_eval_raw if is_conv else x_eval_raw.view(x_eval_raw.size(0), -1) + + if method == 'bp': + log = train_bp(model, train_loader, test_loader, device, args.epochs, is_conv) + diag = compute_diagnostics(model, x_eval, y_eval, device, 'bp', is_conv=is_conv) + results[str(seed)]['bp'] = {'log': log, 'diagnostics': diag} + elif method == 'dfa': + log, Bs = train_dfa(model, train_loader, test_loader, device, args.epochs, is_conv, num_classes) + diag = compute_diagnostics(model, x_eval, y_eval, device, 'dfa', dfa_Bs=Bs, is_conv=is_conv) + results[str(seed)]['dfa'] = {'log': log, 'diagnostics': diag} + elif method == 'fa': + log, Bs = train_fa(model, train_loader, test_loader, device, args.epochs, is_conv, num_classes) + diag = compute_diagnostics(model, x_eval, y_eval, device, 'fa', fa_Bs=Bs, is_conv=is_conv) + results[str(seed)]['fa'] = {'log': log, 'diagnostics': diag} + + results['config'] = vars(args) + out_path = os.path.join(args.output_dir, f'results_{args.dataset}.json') + with open(out_path, 'w') as f: + json.dump(results, f, indent=2) + print(f"\nSaved: {out_path}", flush=True) + + +if __name__ == '__main__': + main() -- cgit v1.2.3