diff options
Diffstat (limited to 'experiments/resnet_frozen_blocks_baseline.py')
| -rw-r--r-- | experiments/resnet_frozen_blocks_baseline.py | 278 |
1 files changed, 278 insertions, 0 deletions
diff --git a/experiments/resnet_frozen_blocks_baseline.py b/experiments/resnet_frozen_blocks_baseline.py new file mode 100644 index 0000000..787876d --- /dev/null +++ b/experiments/resnet_frozen_blocks_baseline.py @@ -0,0 +1,278 @@ +""" +Frozen-blocks and shallow baselines for a small CIFAR-10 ResNet (BatchNorm, +no LayerNorm) — codex-round-10 control to test whether the DFA "active-harm" +walk-back generalizes from LN-based architectures (ViT-Mini, ResMLP) to a +BN-based residual architecture. + +Conditions per seed: + - BP shallow (num_blocks=0) + - BP frozen-blocks (num_blocks=4 frozen) + - BP trainable (num_blocks=4) + - DFA shallow (num_blocks=0) + - DFA frozen-blocks (num_blocks=4 frozen) + - DFA trainable (num_blocks=4) + +If DFA-trainable < DFA-shallow on ResNet too → claim becomes "FA fails to train +deep blocks across multiple residual architectures including BN-based" — much +harder to dismiss as LN-specific. +If DFA-trainable ≈ or > DFA-shallow on ResNet → "harmful mode is specific to LN +normalization or terminal-LN architectures" — narrower but still useful claim. + +Usage: + CUDA_VISIBLE_DEVICES=2 python experiments/resnet_frozen_blocks_baseline.py --seed 42 +""" +import sys, os, argparse +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms +import numpy as np + +from models.small_resnet import SmallResNet + + +def get_loaders(batch_size=128): + tv_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) + te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) + return ( + DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), + DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2), + ) + + +def evaluate(model, loader, dev): + model.eval() + n = c = 0 + with torch.no_grad(): + for x, y in loader: + x, y = x.to(dev), y.to(dev) + preds = model(x).argmax(-1) + c += (preds == y).sum().item() + n += x.size(0) + return c / n + + +def freeze_blocks(model): + for p in model.blocks.parameters(): + p.requires_grad_(False) + # Also keep BN running stats frozen by setting to eval() + for m in model.blocks.modules(): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): + m.eval() + + +def train_bp(model, train_loader, test_loader, dev, epochs, lr, wd, label, blocks_frozen=False): + opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) + for ep in range(1, epochs + 1): + model.train() + if blocks_frozen: + for m in model.blocks.modules(): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): + m.eval() # keep BN stats frozen + for x, y in train_loader: + x, y = x.to(dev), y.to(dev) + loss = F.cross_entropy(model(x), y) + opt.zero_grad(); loss.backward(); opt.step() + sch.step() + if ep % 10 == 0 or ep == 1 or ep == epochs: + acc = evaluate(model, test_loader, dev) + print(f" [{label}] ep {ep}: test_acc={acc:.4f}", flush=True) + return model + + +def train_dfa(model, train_loader, test_loader, dev, epochs, lr, wd, label, blocks_frozen=False): + """DFA on the BN-ResNet: + - head trained with true CE on the pooled hidden state + - stem (conv + BN) trained via DFA-style local loss with random feedback + - blocks (if any) skipped (frozen for blocks_frozen=True; for trainable case, the + naive analog would be DFA-style local loss per block, but this script focuses on + the frozen/shallow comparison; for trainable comparison use the existing ResMLP + experiment as the analogous "trainable" since they share the same ad-hoc DFA pattern). + For this experiment we focus on the frozen and shallow conditions. + """ + d_hidden = model.d_hidden + L = max(model.num_blocks, 1) + C = 10 + Bs = [torch.randn(d_hidden, C, device=dev) / np.sqrt(C) for _ in range(L)] + + stem_params = list(model.stem_conv.parameters()) + list(model.stem_bn.parameters()) + stem_opt = optim.AdamW(stem_params, lr=lr, weight_decay=wd) + head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd) + sch1 = optim.lr_scheduler.CosineAnnealingLR(stem_opt, T_max=epochs) + sch2 = optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs) + + for ep in range(1, epochs + 1): + model.train() + if blocks_frozen: + for m in model.blocks.modules(): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): + m.eval() + for x, y in train_loader: + x, y = x.to(dev), y.to(dev) + with torch.no_grad(): + logits, hi = model(x, return_hidden=True) + e_T = logits.softmax(-1); e_T[torch.arange(x.size(0)), y] -= 1 + hL_det = hi[-1].detach() # (B, d_hidden, 32, 32) + # Head update via true CE on pooled cls + h_pool = F.adaptive_avg_pool2d(hL_det, 1).flatten(1) + head_opt.zero_grad() + F.cross_entropy(model.out_head(h_pool), y).backward() + head_opt.step() + # Stem update via DFA local loss + a0 = (e_T @ Bs[0].T).detach() # (B, d_hidden) + rms = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + h0 = model.stem(x) # (B, d_hidden, 32, 32) + # Broadcast credit across spatial positions: (B, d, 1, 1) -> (B, d, H, W) + a0_b = (a0 / rms).unsqueeze(-1).unsqueeze(-1).expand_as(h0) + stem_loss = (h0 * a0_b).sum(dim=1).mean() # average over batch and spatial + stem_opt.zero_grad() + stem_loss.backward() + stem_opt.step() + sch1.step(); sch2.step() + if ep % 10 == 0 or ep == 1 or ep == epochs: + acc = evaluate(model, test_loader, dev) + print(f" [{label}] ep {ep}: test_acc={acc:.4f}", flush=True) + return model + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--epochs', type=int, default=60) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--wd', type=float, default=0.01) + parser.add_argument('--d_hidden', type=int, default=64) + args = parser.parse_args() + + dev = torch.device('cuda:0') + print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}", flush=True) + train_loader, test_loader = get_loaders(batch_size=128) + + results = {} + C = 10 + + # Trainable BP (full 4-block ResNet) + print(f"\n=== BP trainable (SmallResNet num_blocks=4), seed={args.seed} ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=4).to(dev) + print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True) + train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-trainable') + results['bp_trainable'] = evaluate(m, test_loader, dev) + print(f"FINAL BP-trainable: {results['bp_trainable']:.4f}", flush=True) + + # Trainable DFA — block-level DFA on ResNet (each block as a unit) + print(f"\n=== DFA trainable (SmallResNet num_blocks=4 block-level DFA), seed={args.seed} ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=4).to(dev) + # We use the same approach as ViT/ResMLP: stem trained with DFA, blocks trained + # with their own DFA-style local loss per block, head with true CE. + # For simplicity reuse train_dfa logic but extend it to also train blocks. + # Since this script focuses on frozen/shallow control, we'll do trainable in a + # separate inner loop here. + d_hidden = m.d_hidden; L = m.num_blocks + Bs = [torch.randn(d_hidden, C, device=dev) / np.sqrt(C) for _ in range(L)] + block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=args.wd) for b in m.blocks] + stem_params = list(m.stem_conv.parameters()) + list(m.stem_bn.parameters()) + stem_opt = optim.AdamW(stem_params, lr=args.lr, weight_decay=args.wd) + head_opt = optim.AdamW(m.out_head.parameters(), lr=args.lr, weight_decay=args.wd) + all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + \ + [optim.lr_scheduler.CosineAnnealingLR(stem_opt, T_max=args.epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)] + for ep in range(1, args.epochs + 1): + m.train() + for x, y in train_loader: + x, y = x.to(dev), y.to(dev) + with torch.no_grad(): + logits, hi = m(x, return_hidden=True) + e_T = logits.softmax(-1); e_T[torch.arange(x.size(0)), y] -= 1 + hL_det = hi[-1].detach() + h_pool = F.adaptive_avg_pool2d(hL_det, 1).flatten(1) + head_opt.zero_grad() + F.cross_entropy(m.out_head(h_pool), y).backward() + head_opt.step() + for l in range(L): + h_l = hi[l].detach() + a_l = (e_T @ Bs[l].T).detach() + rms = (a_l ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + a_l_norm = (a_l / rms).unsqueeze(-1).unsqueeze(-1).expand_as(h_l) + f_l = m.blocks[l](h_l) + local_loss = (f_l * a_l_norm).sum(dim=1).mean() + block_opts[l].zero_grad(); local_loss.backward() + torch.nn.utils.clip_grad_norm_(m.blocks[l].parameters(), 1.0) + block_opts[l].step() + a_0 = (e_T @ Bs[0].T).detach() + rms_0 = (a_0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + h0 = m.stem(x) + a_0_b = (a_0 / rms_0).unsqueeze(-1).unsqueeze(-1).expand_as(h0) + stem_loss = (h0 * a_0_b).sum(dim=1).mean() + stem_opt.zero_grad(); stem_loss.backward(); stem_opt.step() + for s in all_sch: s.step() + if ep % 10 == 0 or ep == 1 or ep == args.epochs: + acc = evaluate(m, test_loader, dev) + print(f" [DFA-trainable] ep {ep}: test_acc={acc:.4f}", flush=True) + results['dfa_trainable'] = evaluate(m, test_loader, dev) + print(f"FINAL DFA-trainable: {results['dfa_trainable']:.4f}", flush=True) + + # BP shallow + print(f"\n=== BP shallow (SmallResNet num_blocks=0), seed={args.seed} ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=0).to(dev) + print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True) + train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-shallow') + results['bp_shallow'] = evaluate(m, test_loader, dev) + print(f"FINAL BP-shallow: {results['bp_shallow']:.4f}", flush=True) + + # BP frozen-blocks + print(f"\n=== BP frozen-blocks (SmallResNet num_blocks=4 frozen), seed={args.seed} ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=4).to(dev) + freeze_blocks(m) + print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True) + train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-frozen', blocks_frozen=True) + results['bp_frozen'] = evaluate(m, test_loader, dev) + print(f"FINAL BP-frozen-blocks: {results['bp_frozen']:.4f}", flush=True) + + # DFA shallow + print(f"\n=== DFA shallow (SmallResNet num_blocks=0), seed={args.seed} ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=0).to(dev) + train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-shallow') + results['dfa_shallow'] = evaluate(m, test_loader, dev) + print(f"FINAL DFA-shallow: {results['dfa_shallow']:.4f}", flush=True) + + # DFA frozen-blocks + print(f"\n=== DFA frozen-blocks (SmallResNet num_blocks=4 frozen), seed={args.seed} ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=4).to(dev) + freeze_blocks(m) + train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-frozen', blocks_frozen=True) + results['dfa_frozen'] = evaluate(m, test_loader, dev) + print(f"FINAL DFA-frozen-blocks: {results['dfa_frozen']:.4f}", flush=True) + + print(f"\n=== Small ResNet (BatchNorm) frozen/shallow baseline summary, seed={args.seed} ===") + for k, v in results.items(): + print(f" {k}: {v:.4f}") + print(f"\nKey gaps (DFA):") + if 'dfa_shallow' in results and 'dfa_trainable' in results: + print(f" DFA-shallow ({results['dfa_shallow']:.4f}) - DFA-trainable ({results['dfa_trainable']:.4f}) = {results['dfa_shallow']-results['dfa_trainable']:+.4f}") + if 'dfa_frozen' in results and 'dfa_trainable' in results: + print(f" DFA-frozen ({results['dfa_frozen']:.4f}) - DFA-trainable ({results['dfa_trainable']:.4f}) = {results['dfa_frozen']-results['dfa_trainable']:+.4f}") + + +if __name__ == '__main__': + main() |
