Diffstat (limited to 'reproduce/train_methods.py')
-rw-r--r--  reproduce/train_methods.py  376
1 file changed, 376 insertions(+), 0 deletions(-)
diff --git a/reproduce/train_methods.py b/reproduce/train_methods.py
new file mode 100644
index 0000000..c430b90
--- /dev/null
+++ b/reproduce/train_methods.py
@@ -0,0 +1,376 @@
+"""
+Train BP/FA/DFA on a specified architecture and compute protocol diagnostics.
+
+Usage:
+ python reproduce/train_methods.py --arch resmlp --methods bp fa dfa \
+ --seeds 42 123 456 --epochs 100 --gpu 0 --output_dir results/main_audit
+
+Architectures: resmlp (d=256, L=4), resmlp_d512_L2 (d=512, L=2), vit, resnet
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision, torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+from models.vit_mini import ViTMini
+from models.small_resnet import SmallResNet
+from metrics.credit_metrics import cosine_similarity_batch
+
+
+# ─── Data ────────────────────────────────────────────────────────────────
+
+def get_data(dataset='cifar10', batch_size=128):
+ if dataset == 'cifar10':
+ mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
+ Dataset = torchvision.datasets.CIFAR10
+ num_classes = 10
+ else:
+ mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
+ Dataset = torchvision.datasets.CIFAR100
+ num_classes = 100
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(), transforms.Normalize(mean, std)])
+ tv_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
+ tr = Dataset('./data', True, download=True, transform=tv_train)
+ te = Dataset('./data', False, download=True, transform=tv_test)
+ return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
+ num_classes)
+
+
+def evaluate(model, loader, device, is_conv=False):
+ model.eval()
+ c = n = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x, y = x.to(device), y.to(device)
+ if not is_conv:
+ x = x.view(x.size(0), -1)
+ c += (model(x).argmax(-1) == y).sum().item()
+ n += x.size(0)
+ return c / n
+
+
+# ─── Model construction ─────────────────────────────────────────────────
+
+def make_model(arch, num_classes, device):
+ if arch == 'resmlp':
+ return ResidualMLP(3072, 256, num_classes, 4).to(device), False
+ elif arch == 'resmlp_d512_L2':
+ return ResidualMLP(3072, 512, num_classes, 2).to(device), False
+ elif arch == 'vit':
+ return ViTMini(d_model=128, n_heads=4, num_blocks=4, num_classes=num_classes).to(device), True
+ elif arch == 'resnet':
+ return SmallResNet(64, num_classes, 4).to(device), True
+ else:
+ raise ValueError(f"Unknown arch: {arch}")
+
+
+# ─── Training functions ─────────────────────────────────────────────────
+
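+# All three trainers share one optimizer protocol (AdamW, lr 1e-3, weight decay
+# 0.01, cosine annealing over the run), so accuracy gaps between BP/FA/DFA
+# reflect the credit-assignment rule rather than tuning differences.
+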
+def train_bp(model, train_loader, test_loader, device, epochs, is_conv):
+ opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+ for ep in range(1, epochs + 1):
+ model.train()
+ tl, tc, tn = 0, 0, 0
+ for x, y in train_loader:
+ x, y = x.to(device), y.to(device)
+ if not is_conv: x = x.view(x.size(0), -1)
+ logits = model(x)
+ loss = F.cross_entropy(logits, y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ tl += loss.item() * x.size(0); tc += (logits.argmax(1) == y).sum().item(); tn += x.size(0)
+ sch.step()
+ log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn)
+ log['test_acc'].append(evaluate(model, test_loader, device, is_conv))
+ if ep % 10 == 0 or ep == epochs:
+ print(f" [BP] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True)
+ return log
+
+
+def _get_embed_head_params(model, is_conv):
+ """Get embed and head parameter groups."""
+ if is_conv and hasattr(model, 'stem_conv'):
+ embed_params = list(model.stem_conv.parameters()) + list(model.stem_bn.parameters())
+ head_params = list(model.out_head.parameters())
+ elif hasattr(model, 'patch_embed'): # ViT
+ embed_params = list(model.patch_embed.parameters()) + [model.cls_token, model.pos_embed]
+ head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
+ else: # ResMLP
+ embed_params = list(model.embed.parameters())
+ head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
+ return embed_params, head_params
+
+
+def _pool_hidden(h):
+ if h.dim() == 4: return F.adaptive_avg_pool2d(h, 1).flatten(1)
+ if h.dim() == 3: return h[:, 0] # cls token
+ return h
+
+
+def _get_head_logits(model, h_pool):
+ if hasattr(model, 'out_ln'):
+ return model.out_head(model.out_ln(h_pool))
+ return model.out_head(h_pool)
+
+
+def _block_residual(model, block, h_l, is_conv):
+ """Compute block residual f_l = block(h_l) - h_l for blocks with internal skip."""
+ out = block(h_l)
+ if is_conv or hasattr(block, 'attn'): # ResNet/ViT blocks include skip internally
+ return out - h_l
+ return out # ResMLP blocks return f_l only
+
+
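+# Direct Feedback Alignment (Nøkland, 2016): each block receives the output
+# error e = softmax(logits) - onehot(y) projected through its own fixed random
+# matrix B_l, in place of the backpropagated gradient. The block minimizes the
+# local surrogate loss
+#     L_l = mean_batch <f_l(h_l), B_l e>,
+# whose gradient w.r.t. the block parameters equals the DFA update, since
+# B_l e is constant w.r.t. them. The per-sample RMS normalization of the
+# teaching signal below is a choice of this implementation, not part of
+# vanilla DFA.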
+def train_dfa(model, train_loader, test_loader, device, epochs, is_conv, num_classes):
+ d = model.d_hidden if hasattr(model, 'd_hidden') else model.d_model
+ L = model.num_blocks
+ C = num_classes
+ Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks]
+ embed_params, head_params = _get_embed_head_params(model, is_conv)
+ embed_opt = optim.AdamW(embed_params, lr=1e-3, weight_decay=0.01)
+ head_opt = optim.AdamW(head_params, lr=1e-3, weight_decay=0.01)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+ for ep in range(1, epochs + 1):
+ model.train()
+ tl, tc, tn = 0, 0, 0
+ for x, y in train_loader:
+ x, y = x.to(device), y.to(device)
+ if not is_conv: x = x.view(x.size(0), -1)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+            e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1  # error signal e = p - onehot(y)
+ h_pool = _pool_hidden(hiddens[-1].detach())
+ head_opt.zero_grad()
+ F.cross_entropy(_get_head_logits(model, h_pool), y).backward()
+ head_opt.step()
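+            # The head is fit with a local CE loss on the detached pooled
+            # feature, so this step sends no gradient into the blocks.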
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = (e_T @ Bs[l].T).detach()
+ rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+ f_l = _block_residual(model, model.blocks[l], h_l, is_conv)
+                if f_l.dim() == 4:    # conv maps: broadcast credit over spatial dims
+                    a_b = a_norm.unsqueeze(-1).unsqueeze(-1).expand_as(f_l)
+                    local_loss = (f_l * a_b).sum(dim=1).mean()
+                elif f_l.dim() == 3:  # token sequences: broadcast credit over tokens
+                    a_b = a_norm.unsqueeze(1).expand_as(f_l)
+                    local_loss = (f_l * a_b).sum(dim=-1).mean()
+                else:
+                    local_loss = (f_l * a_norm).sum(-1).mean()
+ block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step()
+ # Embed
+ a0 = (e_T @ Bs[0].T).detach()
+ rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+            h0 = model.embed(x) if hasattr(model, 'embed') else model.stem(x)
+ a0_n = a0 / rms0
+            if h0.dim() == 4:    # conv maps
+                a0_b = a0_n.unsqueeze(-1).unsqueeze(-1).expand_as(h0)
+                embed_loss = (h0 * a0_b).sum(dim=1).mean()
+            elif h0.dim() == 3:  # token sequences
+                a0_b = a0_n.unsqueeze(1).expand_as(h0)
+                embed_loss = (h0 * a0_b).sum(dim=-1).mean()
+            else:
+                embed_loss = (h0 * a0_n).sum(-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+            tl += loss_val.item() * batch; tc += (logits.argmax(1) == y).sum().item(); tn += batch
+        for s in all_sch: s.step()  # step schedulers once per epoch, matching BP
+ log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn)
+ log['test_acc'].append(evaluate(model, test_loader, device, is_conv))
+ if ep % 10 == 0 or ep == epochs:
+ print(f" [DFA] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True)
+ return log, Bs
+
+
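+# Feedback Alignment (Lillicrap et al., 2016), block-level variant: the credit
+# signal starts as the true gradient at the pooled final hidden state and is
+# carried top-down, crossing block l through a fixed random d x d matrix
+# (a_l = a_{l+1} B_l) in place of the transposed forward weights. As written,
+# the blocks' own Jacobians are not applied while descending, a simplification
+# of this script relative to canonical per-layer FA.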
+def train_fa(model, train_loader, test_loader, device, epochs, is_conv, num_classes):
+ d = model.d_hidden if hasattr(model, 'd_hidden') else model.d_model
+ L = model.num_blocks
+ Bs = [torch.randn(d, d, device=device) / np.sqrt(d) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks]
+ embed_params, head_params = _get_embed_head_params(model, is_conv)
+ embed_opt = optim.AdamW(embed_params, lr=1e-3, weight_decay=0.01)
+ head_opt = optim.AdamW(head_params, lr=1e-3, weight_decay=0.01)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+ for ep in range(1, epochs + 1):
+ model.train()
+ tl, tc, tn = 0, 0, 0
+ for x, y in train_loader:
+ x, y = x.to(device), y.to(device)
+ if not is_conv: x = x.view(x.size(0), -1)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+            # Head: capture the gradient w.r.t. the pooled hidden before the optimizer step
+ h_pool = _pool_hidden(hiddens[-1].detach()).requires_grad_(True)
+ logits_out = _get_head_logits(model, h_pool)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad(); loss_out.backward()
+ a_credit = h_pool.grad.detach()
+ head_opt.step()
+ # Top-down blocks
+ for l in range(L - 1, -1, -1):
+ h_l = hiddens[l].detach()
+ rms = (a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a_credit / rms
+ f_l = _block_residual(model, model.blocks[l], h_l, is_conv)
+                if f_l.dim() == 4:    # conv maps: broadcast credit over spatial dims
+                    a_b = a_norm.unsqueeze(-1).unsqueeze(-1).expand_as(f_l)
+                    local_loss = (f_l * a_b).sum(dim=1).mean()
+                elif f_l.dim() == 3:  # token sequences: broadcast credit over tokens
+                    a_b = a_norm.unsqueeze(1).expand_as(f_l)
+                    local_loss = (f_l * a_b).sum(dim=-1).mean()
+                else:
+                    local_loss = (f_l * a_norm).sum(-1).mean()
+ block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step()
+ a_credit = (a_credit @ Bs[l]).detach()
+ # Embed
+ rms0 = (a_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+            h0 = model.embed(x) if hasattr(model, 'embed') else model.stem(x)
+ a0_n = a_credit / rms0
+            if h0.dim() == 4:    # conv maps
+                a0_b = a0_n.unsqueeze(-1).unsqueeze(-1).expand_as(h0)
+                embed_loss = (h0 * a0_b).sum(dim=1).mean()
+            elif h0.dim() == 3:  # token sequences
+                a0_b = a0_n.unsqueeze(1).expand_as(h0)
+                embed_loss = (h0 * a0_b).sum(dim=-1).mean()
+            else:
+                embed_loss = (h0 * a0_n).sum(-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+            tl += loss_val.item() * batch; tc += (logits.argmax(1) == y).sum().item(); tn += batch
+        for s in all_sch: s.step()  # step schedulers once per epoch, matching BP
+ log['train_loss'].append(tl / tn); log['train_acc'].append(tc / tn)
+ log['test_acc'].append(evaluate(model, test_loader, device, is_conv))
+ if ep % 10 == 0 or ep == epochs:
+ print(f" [FA] ep {ep}: acc={log['test_acc'][-1]:.4f}", flush=True)
+ return log, Bs
+
+
+# ─── Diagnostics ─────────────────────────────────────────────────────────
+
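+# The alignment diagnostic compares each method's credit signal a_l with the
+# true BP gradient g_l = dL/dh_l at each block input: cos(a_l, g_l) near 1
+# means the method effectively follows the gradient; near 0 means the feedback
+# is uninformative. cosine_similarity_batch (metrics.credit_metrics) is assumed
+# to return the batch-mean per-sample cosine of two (B, d) tensors.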
+def compute_diagnostics(model, x_eval, y_eval, device, method_name, dfa_Bs=None, fa_Bs=None, is_conv=False):
+ """Compute per-layer cosine, ||g_l||, ||h_l|| and nudging."""
+ model.eval()
+ L = model.num_blocks
+
+ with torch.no_grad():
+ logits, hiddens = model(x_eval, return_hidden=True)
+
+ h_norms = [float(_pool_hidden(h).norm(dim=-1).median().item()) for h in hiddens]
+
+ # BP grads
+ h0 = model.embed(x_eval) if hasattr(model, 'embed') else model.stem(x_eval)
+    hs = [h0.detach().clone().requires_grad_(True)]  # detach: requires_grad_ needs a leaf tensor
+    for block in model.blocks:
+        out = block(hs[-1])
+        if not (is_conv or hasattr(block, 'attn')):
+            out = hs[-1] + out  # ResMLP blocks return only f_l; re-add the skip
+        hs.append(out)
+ h_final = _pool_hidden(hs[-1])
+ if hasattr(model, 'out_ln'):
+ h_final = model.out_ln(h_final)
+ out_logits = model.out_head(h_final)
+ loss = F.cross_entropy(out_logits, y_eval)
+ grads = torch.autograd.grad(loss, hs)
+ g_norms = [float(_pool_hidden(g).norm(dim=-1).median().item()) for g in grads]
+
+ # Per-layer cosine
+ with torch.no_grad():
+ e_T = out_logits.softmax(-1)
+ e_T[torch.arange(x_eval.size(0)), y_eval] -= 1
+
+ bp_cosine = []
+ if method_name == 'bp':
+ bp_cosine = [1.0] * L
+ elif method_name == 'dfa' and dfa_Bs is not None:
+ for l in range(L):
+ a = (e_T @ dfa_Bs[l].T).detach()
+ g_pool = _pool_hidden(grads[l]).detach()
+ bp_cosine.append(cosine_similarity_batch(a, g_pool))
+ elif method_name == 'fa' and fa_Bs is not None:
+ hL_pool = _pool_hidden(hiddens[-1].detach()).requires_grad_(True)
+ logits_fa = _get_head_logits(model, hL_pool)
+ loss_fa = F.cross_entropy(logits_fa, y_eval)
+ a_credit = torch.autograd.grad(loss_fa, hL_pool)[0].detach()
+ for l in range(L - 1, -1, -1):
+ g_pool = _pool_hidden(grads[l]).detach()
+ bp_cosine.insert(0, cosine_similarity_batch(a_credit, g_pool))
+ a_credit = (a_credit @ fa_Bs[l]).detach()
+
+ model.train()
+ return {
+ 'bp_cosine': bp_cosine,
+ 'bp_grad_norms_per_layer': g_norms,
+ 'hidden_norms_per_layer': h_norms,
+ }
+
+
+# ─── Main ────────────────────────────────────────────────────────────────
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--arch', type=str, default='resmlp', choices=['resmlp', 'resmlp_d512_L2', 'vit', 'resnet'])
+ p.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'])
+ p.add_argument('--methods', nargs='+', default=['bp', 'fa', 'dfa'])
+ p.add_argument('--seeds', nargs='+', type=int, default=[42, 123, 456])
+ p.add_argument('--epochs', type=int, default=100)
+ p.add_argument('--gpu', type=int, default=0)
+ p.add_argument('--output_dir', type=str, default='results/reproduce')
+    p.add_argument('--penalty_lam', type=float, default=0.0)  # unused here; recorded in the saved config
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ train_loader, test_loader, num_classes = get_data(args.dataset, 128)
+
+ # Eval buffer
+ xs, ys = [], []
+ for x, y in test_loader:
+ xs.append(x); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= 128: break
+ x_eval_raw = torch.cat(xs)[:128].to(device)
+ y_eval = torch.cat(ys)[:128].to(device)
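+    # Diagnostics use this fixed 128-example test batch so that per-layer
+    # statistics are comparable across methods and seeds.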
+
+ results = {}
+ for seed in args.seeds:
+ print(f"\n{'='*60}\nSeed {seed}\n{'='*60}", flush=True)
+ results[str(seed)] = {}
+
+ for method in args.methods:
+ print(f"\n--- {method.upper()} ---", flush=True)
+ torch.manual_seed(seed); np.random.seed(seed)
+ if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
+ model, is_conv = make_model(args.arch, num_classes, device)
+ x_eval = x_eval_raw if is_conv else x_eval_raw.view(x_eval_raw.size(0), -1)
+
+ if method == 'bp':
+ log = train_bp(model, train_loader, test_loader, device, args.epochs, is_conv)
+ diag = compute_diagnostics(model, x_eval, y_eval, device, 'bp', is_conv=is_conv)
+ results[str(seed)]['bp'] = {'log': log, 'diagnostics': diag}
+ elif method == 'dfa':
+ log, Bs = train_dfa(model, train_loader, test_loader, device, args.epochs, is_conv, num_classes)
+ diag = compute_diagnostics(model, x_eval, y_eval, device, 'dfa', dfa_Bs=Bs, is_conv=is_conv)
+ results[str(seed)]['dfa'] = {'log': log, 'diagnostics': diag}
+ elif method == 'fa':
+ log, Bs = train_fa(model, train_loader, test_loader, device, args.epochs, is_conv, num_classes)
+ diag = compute_diagnostics(model, x_eval, y_eval, device, 'fa', fa_Bs=Bs, is_conv=is_conv)
+ results[str(seed)]['fa'] = {'log': log, 'diagnostics': diag}
+
+ results['config'] = vars(args)
+ out_path = os.path.join(args.output_dir, f'results_{args.dataset}.json')
+ with open(out_path, 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nSaved: {out_path}", flush=True)
+
+
+if __name__ == '__main__':
+ main()