"""
Frozen-blocks and shallow baselines for a small CIFAR-10 ResNet (BatchNorm, no
LayerNorm) — codex-round-10 control to test whether the DFA "active-harm"
walk-back generalizes from LN-based architectures (ViT-Mini, ResMLP) to a
BN-based residual architecture.

Conditions per seed:
  - BP shallow         (num_blocks=0)
  - BP frozen-blocks   (num_blocks=4 frozen)
  - BP trainable       (num_blocks=4)
  - DFA shallow        (num_blocks=0)
  - DFA frozen-blocks  (num_blocks=4 frozen)
  - DFA trainable      (num_blocks=4, block-level DFA)

If DFA-trainable < DFA-shallow on ResNet too → the claim becomes "FA fails to
train deep blocks across multiple residual architectures, including BN-based
ones" — much harder to dismiss as LN-specific.
If DFA-trainable ≈ or > DFA-shallow on ResNet → "the harmful mode is specific
to LN normalization or terminal-LN architectures" — a narrower but still
useful claim.

Usage:
  CUDA_VISIBLE_DEVICES=2 python experiments/resnet_frozen_blocks_baseline.py --seed 42
"""
import argparse
import os
import sys

# Put the parent of this script's directory on sys.path so that
# `models.small_resnet` can be imported when the script is run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from models.small_resnet import SmallResNet

# Per-channel normalization statistics applied to both CIFAR-10 splits.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)
# BatchNorm layer types whose running statistics are held fixed for frozen blocks.
_BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d)


# ----------------------------------------------------------------------------
# Data / evaluation
# ----------------------------------------------------------------------------

def get_loaders(batch_size=128):
    """Build the CIFAR-10 train and test DataLoaders.

    The training split uses random-crop (padding 4) and horizontal-flip
    augmentation. Both splits are converted to tensors and normalized with the
    CIFAR-10 channel statistics. The dataset is downloaded to ./data if needed.

    Returns:
        (train_loader, test_loader). Only the train loader shuffles.
    """
    normalize = transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD)
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_tf = transforms.Compose([transforms.ToTensor(), normalize])
    train_set = torchvision.datasets.CIFAR10('./data', True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10('./data', False, download=True, transform=test_tf)
    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2),
        DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2),
    )


def evaluate(model, loader, dev):
    """Return the top-1 accuracy of ``model`` over ``loader``.

    Leaves the model in eval mode. The training loops call ``model.train()``
    again at the start of every epoch.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    n = c = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(dev), y.to(dev)
            c += (model(x).argmax(-1) == y).sum().item()
            n += x.size(0)
    if n == 0:
        raise ValueError("evaluate() received an empty loader")
    return c / n


def _maybe_log(model, test_loader, dev, ep, epochs, label):
    """Evaluate and print test accuracy on epoch 1, every 10th epoch, and the last epoch."""
    if ep % 10 == 0 or ep == 1 or ep == epochs:
        acc = evaluate(model, test_loader, dev)
        print(f" [{label}] ep {ep}: test_acc={acc:.4f}", flush=True)


# ----------------------------------------------------------------------------
# Model setup helpers
# ----------------------------------------------------------------------------

def _seed_all(seed):
    """Seed torch (CPU and all CUDA devices) and numpy."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)


def _freeze_bn_stats(module):
    """Put every BatchNorm submodule of ``module`` in eval mode so its running stats stop updating."""
    for m in module.modules():
        if isinstance(m, _BN_TYPES):
            m.eval()


def freeze_blocks(model):
    """Freeze ``model.blocks``: disable their gradients and hold their BN running stats.

    The BN eval mode set here is undone by any later ``model.train()`` call.
    The trainers therefore re-apply it every epoch when ``blocks_frozen=True``.
    """
    for p in model.blocks.parameters():
        p.requires_grad_(False)
    _freeze_bn_stats(model.blocks)


def _build_model(seed, d_hidden, num_classes, num_blocks, dev, frozen=False):
    """Reseed all RNGs, then build a SmallResNet on ``dev``, optionally with frozen blocks.

    Reseeding immediately before construction makes each condition
    reproducible for a given seed, independently of the conditions run before it.
    """
    _seed_all(seed)
    model = SmallResNet(d_hidden=d_hidden, num_classes=num_classes, num_blocks=num_blocks).to(dev)
    if frozen:
        freeze_blocks(model)
    return model


def _param_summary(model):
    """Return the ' n_params: <total> (<trainable> trainable)' log line for ``model``."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return f" n_params: {total} ({trainable} trainable)"


# ----------------------------------------------------------------------------
# Backprop
# ----------------------------------------------------------------------------

def train_bp(model, train_loader, test_loader, dev, epochs, lr, wd, label, blocks_frozen=False):
    """Train ``model`` end-to-end with backprop, using AdamW and a cosine LR schedule over ``epochs``.

    Only parameters with ``requires_grad`` are handed to the optimizer, so
    blocks frozen with ``freeze_blocks`` stay fixed. When ``blocks_frozen`` is
    True, the blocks' BN layers are returned to eval mode after each
    ``model.train()`` so their running statistics also stay fixed.

    Returns:
        ``model``, trained in place.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    opt = optim.AdamW(params, lr=lr, weight_decay=wd)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    for ep in range(1, epochs + 1):
        model.train()
        if blocks_frozen:
            _freeze_bn_stats(model.blocks)  # model.train() re-enabled BN stat updates
        for x, y in train_loader:
            x, y = x.to(dev), y.to(dev)
            loss = F.cross_entropy(model(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        sch.step()
        _maybe_log(model, test_loader, dev, ep, epochs, label)
    return model


# ----------------------------------------------------------------------------
# Direct feedback alignment (DFA)
# ----------------------------------------------------------------------------

def _feedback_matrices(d_hidden, num_classes, count, dev):
    """Draw ``count`` fixed random feedback matrices of shape (d_hidden, num_classes), scaled by 1/sqrt(num_classes)."""
    return [torch.randn(d_hidden, num_classes, device=dev) / np.sqrt(num_classes) for _ in range(count)]


def _dfa_output_error(logits, y):
    """Return the cross-entropy gradient w.r.t. the logits: softmax(logits) - one_hot(y)."""
    return logits.softmax(-1) - F.one_hot(y, num_classes=logits.size(-1)).to(logits.dtype)


def _dfa_credit(e_T, B, like):
    """Project the output error through feedback matrix ``B`` and broadcast it to ``like``'s shape.

    ``e_T @ B.T`` has shape (batch, d_hidden). Each row is divided by its RMS
    (plus 1e-6), then two trailing singleton dims are added and expanded, so
    the same channel credit is applied at every spatial position.
    """
    a = (e_T @ B.T).detach()
    rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
    return (a / rms).unsqueeze(-1).unsqueeze(-1).expand_as(like)


def _local_dfa_loss(h, credit):
    """Return the DFA local loss ``(h * credit)`` summed over channels and averaged over the remaining dims.

    Its gradient w.r.t. ``h`` is ``credit`` divided by the number of averaged
    entries. A descent step therefore moves ``h`` against the projected error.
    """
    return (h * credit).sum(dim=1).mean()


def _head_step(model, head_opt, h_last, y):
    """Take one true cross-entropy step on ``model.out_head``, using the detached, average-pooled ``h_last``."""
    h_pool = F.adaptive_avg_pool2d(h_last.detach(), 1).flatten(1)
    head_opt.zero_grad()
    F.cross_entropy(model.out_head(h_pool), y).backward()
    head_opt.step()


def _stem_step(model, stem_opt, x, e_T, B):
    """Take one DFA step on the stem: re-run ``model.stem(x)`` and descend the local loss built from ``B``."""
    h0 = model.stem(x)
    loss = _local_dfa_loss(h0, _dfa_credit(e_T, B, h0))
    stem_opt.zero_grad()
    loss.backward()
    stem_opt.step()


def train_dfa(model, train_loader, test_loader, dev, epochs, lr, wd, label, blocks_frozen=False):
    """Train the stem and head with DFA and leave the residual blocks (if any) untouched.

    This covers the DFA-shallow (``num_blocks=0``) and DFA-frozen (blocks
    frozen with ``freeze_blocks``, ``blocks_frozen=True``) conditions. The
    DFA-trainable condition, where every block also receives a DFA local loss,
    is handled by ``_train_dfa_trainable``.

    Each batch does the following:
      - A no-grad forward pass produces the logits and the hidden states ``hi``.
        The output error is e = softmax(logits) - one_hot(y).
      - The head is trained with true cross-entropy on the average-pooled,
        detached ``hi[-1]``.
      - The stem (``stem_conv`` + ``stem_bn`` parameters) is trained with a DFA
        local loss. The error is projected through a fixed random feedback
        matrix, RMS-normalized per sample, broadcast over spatial positions,
        and dotted with a fresh ``model.stem(x)`` output.

    NOTE(review): the head is trained on ``avg_pool(hi[-1])``, while
    ``evaluate`` uses ``model(x)``. These match only if SmallResNet.forward
    computes ``out_head(avg_pool(hi[-1]))``. Confirm in models/small_resnet.py.

    NOTE(review): the model is in train mode for both the no-grad forward pass
    and the ``model.stem(x)`` re-run. Any BN layer in the stem therefore
    updates its running stats twice per batch, unlike the BP path.

    Returns:
        ``model``, trained in place.
    """
    num_classes = 10
    # Only Bs[0] (used by the stem) is consumed. One matrix per block is still
    # drawn, so the RNG draws match earlier runs of this script.
    Bs = _feedback_matrices(model.d_hidden, num_classes, max(model.num_blocks, 1), dev)

    stem_params = list(model.stem_conv.parameters()) + list(model.stem_bn.parameters())
    stem_opt = optim.AdamW(stem_params, lr=lr, weight_decay=wd)
    head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
    stem_sch = optim.lr_scheduler.CosineAnnealingLR(stem_opt, T_max=epochs)
    head_sch = optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)

    for ep in range(1, epochs + 1):
        model.train()
        if blocks_frozen:
            _freeze_bn_stats(model.blocks)  # model.train() re-enabled BN stat updates
        for x, y in train_loader:
            x, y = x.to(dev), y.to(dev)
            with torch.no_grad():
                logits, hi = model(x, return_hidden=True)
                e_T = _dfa_output_error(logits, y)
            _head_step(model, head_opt, hi[-1], y)
            _stem_step(model, stem_opt, x, e_T, Bs[0])
        stem_sch.step()
        head_sch.step()
        _maybe_log(model, test_loader, dev, ep, epochs, label)
    return model


def _train_dfa_trainable(model, train_loader, test_loader, dev, epochs, lr, wd, label, num_classes=10):
    """Run block-level DFA on the BN-ResNet with trainable blocks (the DFA-trainable condition).

    The head and stem are trained as in ``train_dfa``. In addition, each
    residual block ``l`` gets its own local loss: the block is re-run on the
    detached ``hi[l]``, and its output is dotted with the RMS-normalized
    projection of the output error through feedback matrix ``Bs[l]``. Block
    gradients are clipped to norm 1.0 before each step.

    Every update in a batch uses the error from a single forward pass made
    before any parameter changes. The stem reuses ``Bs[0]``, the same feedback
    matrix as block 0.

    NOTE(review): presumably ``hi[l]`` is the input to block ``l`` and
    ``hi[-1]`` the final hidden state. Confirm against SmallResNet.forward.

    Returns:
        ``model``, trained in place.

    Raises:
        ValueError: if the model has no residual blocks.
    """
    num_blocks = model.num_blocks
    if num_blocks < 1:
        raise ValueError("block-level DFA requires num_blocks >= 1")
    Bs = _feedback_matrices(model.d_hidden, num_classes, num_blocks, dev)

    block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
    stem_params = list(model.stem_conv.parameters()) + list(model.stem_bn.parameters())
    stem_opt = optim.AdamW(stem_params, lr=lr, weight_decay=wd)
    head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
    schedulers = [
        optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs)
        for o in (*block_opts, stem_opt, head_opt)
    ]

    for ep in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x, y = x.to(dev), y.to(dev)
            with torch.no_grad():
                logits, hi = model(x, return_hidden=True)
                e_T = _dfa_output_error(logits, y)
            _head_step(model, head_opt, hi[-1], y)
            for l, (block, opt) in enumerate(zip(model.blocks, block_opts)):
                h_l = hi[l].detach()
                loss = _local_dfa_loss(block(h_l), _dfa_credit(e_T, Bs[l], h_l))
                opt.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(block.parameters(), 1.0)
                opt.step()
            _stem_step(model, stem_opt, x, e_T, Bs[0])
        for s in schedulers:
            s.step()
        _maybe_log(model, test_loader, dev, ep, epochs, label)
    return model


# ----------------------------------------------------------------------------
# Entry point
# ----------------------------------------------------------------------------

def main():
    """Run all six BP/DFA conditions for one seed and print a summary of the DFA gaps."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--epochs', type=int, default=60)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--wd', type=float, default=0.01)
    parser.add_argument('--d_hidden', type=int, default=64)
    args = parser.parse_args()

    # CUDA_VISIBLE_DEVICES selects the physical GPU. Fall back to CPU when none is visible.
    dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}", flush=True)
    train_loader, test_loader = get_loaders(batch_size=128)
    results = {}
    C = 10

    def run_condition(key, header, final_name, label, num_blocks, trainer,
                      frozen=False, show_params=False, **trainer_kwargs):
        """Seed, build, train and evaluate one condition, and store its final accuracy in ``results[key]``."""
        print(f"\n=== {header}, seed={args.seed} ===", flush=True)
        model = _build_model(args.seed, args.d_hidden, C, num_blocks, dev, frozen=frozen)
        if show_params:
            print(_param_summary(model), flush=True)
        if frozen:
            trainer_kwargs['blocks_frozen'] = True
        trainer(model, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, label,
                **trainer_kwargs)
        results[key] = evaluate(model, test_loader, dev)
        print(f"FINAL {final_name}: {results[key]:.4f}", flush=True)

    run_condition('bp_trainable', 'BP trainable (SmallResNet num_blocks=4)',
                  'BP-trainable', 'BP-trainable', 4, train_bp, show_params=True)
    run_condition('dfa_trainable', 'DFA trainable (SmallResNet num_blocks=4 block-level DFA)',
                  'DFA-trainable', 'DFA-trainable', 4, _train_dfa_trainable, num_classes=C)
    run_condition('bp_shallow', 'BP shallow (SmallResNet num_blocks=0)',
                  'BP-shallow', 'BP-shallow', 0, train_bp, show_params=True)
    run_condition('bp_frozen', 'BP frozen-blocks (SmallResNet num_blocks=4 frozen)',
                  'BP-frozen-blocks', 'BP-frozen', 4, train_bp, frozen=True, show_params=True)
    run_condition('dfa_shallow', 'DFA shallow (SmallResNet num_blocks=0)',
                  'DFA-shallow', 'DFA-shallow', 0, train_dfa)
    run_condition('dfa_frozen', 'DFA frozen-blocks (SmallResNet num_blocks=4 frozen)',
                  'DFA-frozen-blocks', 'DFA-frozen', 4, train_dfa, frozen=True)

    print(f"\n=== Small ResNet (BatchNorm) frozen/shallow baseline summary, seed={args.seed} ===")
    for k, v in results.items():
        print(f" {k}: {v:.4f}")
    print("\nKey gaps (DFA):")
    for key, name in (('dfa_shallow', 'DFA-shallow'), ('dfa_frozen', 'DFA-frozen')):
        if key in results and 'dfa_trainable' in results:
            base, trained = results[key], results['dfa_trainable']
            print(f" {name} ({base:.4f}) - DFA-trainable ({trained:.4f}) = {base - trained:+.4f}")


if __name__ == '__main__':
    main()