summaryrefslogtreecommitdiff
path: root/experiments/resnet_frozen_blocks_baseline.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/resnet_frozen_blocks_baseline.py')
-rw-r--r--experiments/resnet_frozen_blocks_baseline.py278
1 files changed, 278 insertions, 0 deletions
diff --git a/experiments/resnet_frozen_blocks_baseline.py b/experiments/resnet_frozen_blocks_baseline.py
new file mode 100644
index 0000000..787876d
--- /dev/null
+++ b/experiments/resnet_frozen_blocks_baseline.py
@@ -0,0 +1,278 @@
+"""
+Frozen-blocks and shallow baselines for a small CIFAR-10 ResNet (BatchNorm,
+no LayerNorm) — codex-round-10 control to test whether the DFA "active-harm"
+walk-back generalizes from LN-based architectures (ViT-Mini, ResMLP) to a
+BN-based residual architecture.
+
+Conditions per seed:
+ - BP shallow (num_blocks=0)
+ - BP frozen-blocks (num_blocks=4 frozen)
+ - BP trainable (num_blocks=4)
+ - DFA shallow (num_blocks=0)
+ - DFA frozen-blocks (num_blocks=4 frozen)
+ - DFA trainable (num_blocks=4)
+
+If DFA-trainable < DFA-shallow on ResNet too → claim becomes "FA fails to train
+deep blocks across multiple residual architectures including BN-based" — much
+harder to dismiss as LN-specific.
+If DFA-trainable ≈ or > DFA-shallow on ResNet → "harmful mode is specific to LN
+normalization or terminal-LN architectures" — narrower but still useful claim.
+
+Usage:
+ CUDA_VISIBLE_DEVICES=2 python experiments/resnet_frozen_blocks_baseline.py --seed 42
+"""
+import sys, os, argparse
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import numpy as np
+
+from models.small_resnet import SmallResNet
+
+
+def get_loaders(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (
+ DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
+ )
+
+
+def evaluate(model, loader, dev):
+ model.eval()
+ n = c = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x, y = x.to(dev), y.to(dev)
+ preds = model(x).argmax(-1)
+ c += (preds == y).sum().item()
+ n += x.size(0)
+ return c / n
+
+
+def freeze_blocks(model):
+ for p in model.blocks.parameters():
+ p.requires_grad_(False)
+ # Also keep BN running stats frozen by setting to eval()
+ for m in model.blocks.modules():
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
+ m.eval()
+
+
+def train_bp(model, train_loader, test_loader, dev, epochs, lr, wd, label, blocks_frozen=False):
+ opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ model.train()
+ if blocks_frozen:
+ for m in model.blocks.modules():
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
+ m.eval() # keep BN stats frozen
+ for x, y in train_loader:
+ x, y = x.to(dev), y.to(dev)
+ loss = F.cross_entropy(model(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 10 == 0 or ep == 1 or ep == epochs:
+ acc = evaluate(model, test_loader, dev)
+ print(f" [{label}] ep {ep}: test_acc={acc:.4f}", flush=True)
+ return model
+
+
+def train_dfa(model, train_loader, test_loader, dev, epochs, lr, wd, label, blocks_frozen=False):
+ """DFA on the BN-ResNet:
+ - head trained with true CE on the pooled hidden state
+ - stem (conv + BN) trained via DFA-style local loss with random feedback
+ - blocks (if any) skipped (frozen for blocks_frozen=True; for trainable case, the
+ naive analog would be DFA-style local loss per block, but this script focuses on
+ the frozen/shallow comparison; for trainable comparison use the existing ResMLP
+ experiment as the analogous "trainable" since they share the same ad-hoc DFA pattern).
+ For this experiment we focus on the frozen and shallow conditions.
+ """
+ d_hidden = model.d_hidden
+ L = max(model.num_blocks, 1)
+ C = 10
+ Bs = [torch.randn(d_hidden, C, device=dev) / np.sqrt(C) for _ in range(L)]
+
+ stem_params = list(model.stem_conv.parameters()) + list(model.stem_bn.parameters())
+ stem_opt = optim.AdamW(stem_params, lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
+ sch1 = optim.lr_scheduler.CosineAnnealingLR(stem_opt, T_max=epochs)
+ sch2 = optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)
+
+ for ep in range(1, epochs + 1):
+ model.train()
+ if blocks_frozen:
+ for m in model.blocks.modules():
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
+ m.eval()
+ for x, y in train_loader:
+ x, y = x.to(dev), y.to(dev)
+ with torch.no_grad():
+ logits, hi = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(x.size(0)), y] -= 1
+ hL_det = hi[-1].detach() # (B, d_hidden, 32, 32)
+ # Head update via true CE on pooled cls
+ h_pool = F.adaptive_avg_pool2d(hL_det, 1).flatten(1)
+ head_opt.zero_grad()
+ F.cross_entropy(model.out_head(h_pool), y).backward()
+ head_opt.step()
+ # Stem update via DFA local loss
+ a0 = (e_T @ Bs[0].T).detach() # (B, d_hidden)
+ rms = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.stem(x) # (B, d_hidden, 32, 32)
+ # Broadcast credit across spatial positions: (B, d, 1, 1) -> (B, d, H, W)
+ a0_b = (a0 / rms).unsqueeze(-1).unsqueeze(-1).expand_as(h0)
+ stem_loss = (h0 * a0_b).sum(dim=1).mean() # average over batch and spatial
+ stem_opt.zero_grad()
+ stem_loss.backward()
+ stem_opt.step()
+ sch1.step(); sch2.step()
+ if ep % 10 == 0 or ep == 1 or ep == epochs:
+ acc = evaluate(model, test_loader, dev)
+ print(f" [{label}] ep {ep}: test_acc={acc:.4f}", flush=True)
+ return model
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--epochs', type=int, default=60)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--d_hidden', type=int, default=64)
+ args = parser.parse_args()
+
+ dev = torch.device('cuda:0')
+ print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}", flush=True)
+ train_loader, test_loader = get_loaders(batch_size=128)
+
+ results = {}
+ C = 10
+
+ # Trainable BP (full 4-block ResNet)
+ print(f"\n=== BP trainable (SmallResNet num_blocks=4), seed={args.seed} ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=4).to(dev)
+ print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True)
+ train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-trainable')
+ results['bp_trainable'] = evaluate(m, test_loader, dev)
+ print(f"FINAL BP-trainable: {results['bp_trainable']:.4f}", flush=True)
+
+ # Trainable DFA — block-level DFA on ResNet (each block as a unit)
+ print(f"\n=== DFA trainable (SmallResNet num_blocks=4 block-level DFA), seed={args.seed} ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=4).to(dev)
+ # We use the same approach as ViT/ResMLP: stem trained with DFA, blocks trained
+ # with their own DFA-style local loss per block, head with true CE.
+ # For simplicity reuse train_dfa logic but extend it to also train blocks.
+ # Since this script focuses on frozen/shallow control, we'll do trainable in a
+ # separate inner loop here.
+ d_hidden = m.d_hidden; L = m.num_blocks
+ Bs = [torch.randn(d_hidden, C, device=dev) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=args.wd) for b in m.blocks]
+ stem_params = list(m.stem_conv.parameters()) + list(m.stem_bn.parameters())
+ stem_opt = optim.AdamW(stem_params, lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(m.out_head.parameters(), lr=args.lr, weight_decay=args.wd)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(stem_opt, T_max=args.epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]
+ for ep in range(1, args.epochs + 1):
+ m.train()
+ for x, y in train_loader:
+ x, y = x.to(dev), y.to(dev)
+ with torch.no_grad():
+ logits, hi = m(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(x.size(0)), y] -= 1
+ hL_det = hi[-1].detach()
+ h_pool = F.adaptive_avg_pool2d(hL_det, 1).flatten(1)
+ head_opt.zero_grad()
+ F.cross_entropy(m.out_head(h_pool), y).backward()
+ head_opt.step()
+ for l in range(L):
+ h_l = hi[l].detach()
+ a_l = (e_T @ Bs[l].T).detach()
+ rms = (a_l ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_l_norm = (a_l / rms).unsqueeze(-1).unsqueeze(-1).expand_as(h_l)
+ f_l = m.blocks[l](h_l)
+ local_loss = (f_l * a_l_norm).sum(dim=1).mean()
+ block_opts[l].zero_grad(); local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(m.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ a_0 = (e_T @ Bs[0].T).detach()
+ rms_0 = (a_0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ h0 = m.stem(x)
+ a_0_b = (a_0 / rms_0).unsqueeze(-1).unsqueeze(-1).expand_as(h0)
+ stem_loss = (h0 * a_0_b).sum(dim=1).mean()
+ stem_opt.zero_grad(); stem_loss.backward(); stem_opt.step()
+ for s in all_sch: s.step()
+ if ep % 10 == 0 or ep == 1 or ep == args.epochs:
+ acc = evaluate(m, test_loader, dev)
+ print(f" [DFA-trainable] ep {ep}: test_acc={acc:.4f}", flush=True)
+ results['dfa_trainable'] = evaluate(m, test_loader, dev)
+ print(f"FINAL DFA-trainable: {results['dfa_trainable']:.4f}", flush=True)
+
+ # BP shallow
+ print(f"\n=== BP shallow (SmallResNet num_blocks=0), seed={args.seed} ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=0).to(dev)
+ print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True)
+ train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-shallow')
+ results['bp_shallow'] = evaluate(m, test_loader, dev)
+ print(f"FINAL BP-shallow: {results['bp_shallow']:.4f}", flush=True)
+
+ # BP frozen-blocks
+ print(f"\n=== BP frozen-blocks (SmallResNet num_blocks=4 frozen), seed={args.seed} ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=4).to(dev)
+ freeze_blocks(m)
+ print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True)
+ train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-frozen', blocks_frozen=True)
+ results['bp_frozen'] = evaluate(m, test_loader, dev)
+ print(f"FINAL BP-frozen-blocks: {results['bp_frozen']:.4f}", flush=True)
+
+ # DFA shallow
+ print(f"\n=== DFA shallow (SmallResNet num_blocks=0), seed={args.seed} ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=0).to(dev)
+ train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-shallow')
+ results['dfa_shallow'] = evaluate(m, test_loader, dev)
+ print(f"FINAL DFA-shallow: {results['dfa_shallow']:.4f}", flush=True)
+
+ # DFA frozen-blocks
+ print(f"\n=== DFA frozen-blocks (SmallResNet num_blocks=4 frozen), seed={args.seed} ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = SmallResNet(d_hidden=args.d_hidden, num_classes=C, num_blocks=4).to(dev)
+ freeze_blocks(m)
+ train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-frozen', blocks_frozen=True)
+ results['dfa_frozen'] = evaluate(m, test_loader, dev)
+ print(f"FINAL DFA-frozen-blocks: {results['dfa_frozen']:.4f}", flush=True)
+
+ print(f"\n=== Small ResNet (BatchNorm) frozen/shallow baseline summary, seed={args.seed} ===")
+ for k, v in results.items():
+ print(f" {k}: {v:.4f}")
+ print(f"\nKey gaps (DFA):")
+ if 'dfa_shallow' in results and 'dfa_trainable' in results:
+ print(f" DFA-shallow ({results['dfa_shallow']:.4f}) - DFA-trainable ({results['dfa_trainable']:.4f}) = {results['dfa_shallow']-results['dfa_trainable']:+.4f}")
+ if 'dfa_frozen' in results and 'dfa_trainable' in results:
+ print(f" DFA-frozen ({results['dfa_frozen']:.4f}) - DFA-trainable ({results['dfa_trainable']:.4f}) = {results['dfa_frozen']-results['dfa_trainable']:+.4f}")
+
+
+if __name__ == '__main__':
+ main()