# Source: experiments/resnet_frozen_blocks_baseline.py
# (blob 787876d31a34edcf9424852ca8d66ddca752e2d8; repository web-viewer
# navigation text and line-number gutter removed)
"""
Frozen-blocks and shallow baselines for a small CIFAR-10 ResNet (BatchNorm,
no LayerNorm) — codex-round-10 control to test whether the DFA "active-harm"
walk-back generalizes from LN-based architectures (ViT-Mini, ResMLP) to a
BN-based residual architecture.

Conditions per seed:
  - BP shallow (num_blocks=0)
  - BP frozen-blocks (num_blocks=4 frozen)
  - BP trainable (num_blocks=4)
  - DFA shallow (num_blocks=0)
  - DFA frozen-blocks (num_blocks=4 frozen)
  - DFA trainable (num_blocks=4)

If DFA-trainable < DFA-shallow on ResNet too → claim becomes "FA fails to train
deep blocks across multiple residual architectures including BN-based" — much
harder to dismiss as LN-specific.
If DFA-trainable ≈ or > DFA-shallow on ResNet → "harmful mode is specific to LN
normalization or terminal-LN architectures" — narrower but still useful claim.

Usage:
    CUDA_VISIBLE_DEVICES=2 python experiments/resnet_frozen_blocks_baseline.py --seed 42
"""
import sys, os, argparse
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np

from models.small_resnet import SmallResNet


def get_loaders(batch_size=128):
    """Build the CIFAR-10 train and test ``DataLoader`` pair.

    The training split is augmented with a padded 32x32 random crop and a
    random horizontal flip; both splits are converted to tensors and
    normalized with the same per-channel mean/std constants. The dataset is
    read from ``./data`` and requested with ``download=True``.

    Args:
        batch_size: Mini-batch size used by both loaders.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles;
        both use two worker processes.
    """
    # Per-channel normalization constants shared by both pipelines.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2470, 0.2435, 0.2616)

    augmented = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    plain = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    train_set = torchvision.datasets.CIFAR10(
        './data', train=True, download=True, transform=augmented)
    test_set = torchvision.datasets.CIFAR10(
        './data', train=False, download=True, transform=plain)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, test_loader


def evaluate(model, loader, dev):
    """Return the top-1 accuracy of ``model`` over all batches in ``loader``.

    The model is switched to eval mode and is left in that mode on return;
    callers that continue training must call ``model.train()`` themselves.

    Args:
        model: Callable module mapping an input batch to per-class scores
            (classes along the last dimension).
        loader: Iterable of ``(inputs, targets)`` batches.
        dev: Device the batches are moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples as a Python float.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(dev)
            targets = targets.to(dev)
            predicted = model(inputs).argmax(-1)
            hits += (predicted == targets).sum().item()
            seen += inputs.size(0)
    return hits / seen


def freeze_blocks(model):
    """Freeze ``model.blocks`` in place: no gradients, BatchNorm in eval mode.

    Every parameter under ``model.blocks`` has ``requires_grad`` turned off,
    and every ``BatchNorm1d``/``BatchNorm2d`` layer inside it is put in eval
    mode so its running statistics stop being updated. Modules outside
    ``model.blocks`` are not touched.

    Note that a later ``model.train()`` call puts those BatchNorm layers back
    into training mode, so training loops have to re-apply ``eval()`` on them
    after each ``model.train()``.

    Args:
        model: Module exposing a ``blocks`` submodule.
    """
    residual_blocks = model.blocks
    for param in residual_blocks.parameters():
        param.requires_grad_(False)

    # Keep the BatchNorm running statistics fixed as well.
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d)
    for layer in residual_blocks.modules():
        if isinstance(layer, bn_types):
            layer.eval()


def train_bp(model, train_loader, test_loader, dev, epochs, lr, wd, label, blocks_frozen=False):
    """Train ``model`` end-to-end with backpropagation.

    Uses AdamW over the parameters that currently have ``requires_grad`` set,
    with a cosine-annealed learning rate stepped once per epoch
    (``T_max=epochs``). Test accuracy is printed after the first epoch, after
    every tenth epoch, and after the last epoch.

    Args:
        model: Network to train; modified in place.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        test_loader: Iterable of ``(inputs, targets)`` batches for reporting.
        dev: Device the batches are moved to.
        epochs: Number of passes over ``train_loader``.
        lr: Initial AdamW learning rate.
        wd: AdamW weight decay.
        label: Tag included in the progress lines.
        blocks_frozen: When true, BatchNorm layers under ``model.blocks`` are
            kept in eval mode during training.

    Returns:
        The same ``model`` object.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable, lr=lr, weight_decay=wd)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d)

    for epoch in range(1, epochs + 1):
        model.train()
        if blocks_frozen:
            # model.train() just re-enabled training mode everywhere, so the
            # frozen blocks' BatchNorm layers must be put back into eval mode
            # to keep their running statistics fixed.
            for layer in model.blocks.modules():
                if isinstance(layer, bn_types):
                    layer.eval()

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(dev), targets.to(dev)
            logits = model(inputs)
            loss = F.cross_entropy(logits, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()

        should_report = epoch == 1 or epoch == epochs or epoch % 10 == 0
        if should_report:
            accuracy = evaluate(model, test_loader, dev)
            print(f"  [{label}] ep {epoch}: test_acc={accuracy:.4f}", flush=True)
    return model


def train_dfa(model, train_loader, test_loader, dev, epochs, lr, wd, label, blocks_frozen=False):
    """Train only the stem and the output head with a DFA-style update rule.

    Intended for the shallow and frozen-blocks conditions. Per mini-batch:

      1. A gradient-free forward pass yields the logits and the list of hidden
         states; the output error is ``softmax(logits) - one_hot(y)``.
      2. ``model.out_head`` takes a true cross-entropy step on the spatially
         average-pooled, detached last hidden state (computed in step 1, i.e.
         before this step's stem update).
      3. The stem (``stem_conv`` + ``stem_bn``) takes a step on a local loss
         ``(h0 * a).sum(dim=1).mean()``, where ``a`` is the output error
         projected through a fixed random feedback matrix, RMS-normalized per
         sample and broadcast across spatial positions.

    ``model.blocks`` is given no optimizer here, so its parameters are never
    updated by this function. ``max(model.num_blocks, 1)`` feedback matrices
    are drawn up front, but only the first one is read.

    Args:
        model: Network exposing ``d_hidden``, ``num_blocks``, ``stem_conv``,
            ``stem_bn``, ``stem(x)``, ``blocks``, ``out_head`` and a forward
            accepting ``return_hidden=True``; modified in place.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        test_loader: Iterable of ``(inputs, targets)`` batches for reporting.
        dev: Device for the batches and the feedback matrices.
        epochs: Number of passes over ``train_loader``.
        lr: Initial AdamW learning rate for both optimizers.
        wd: AdamW weight decay for both optimizers.
        label: Tag included in the progress lines.
        blocks_frozen: When true, BatchNorm layers under ``model.blocks`` are
            kept in eval mode during training.

    Returns:
        The same ``model`` object.
    """
    C = 10  # number of output classes (width of the error vector)
    width = model.d_hidden
    n_feedback = max(model.num_blocks, 1)
    # Fixed random feedback matrices, each of shape (d_hidden, C).
    Bs = [torch.randn(width, C, device=dev) / np.sqrt(C) for _ in range(n_feedback)]

    stem_params = [*model.stem_conv.parameters(), *model.stem_bn.parameters()]
    stem_opt = optim.AdamW(stem_params, lr=lr, weight_decay=wd)
    head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
    stem_sched = optim.lr_scheduler.CosineAnnealingLR(stem_opt, T_max=epochs)
    head_sched = optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d)

    for epoch in range(1, epochs + 1):
        model.train()
        if blocks_frozen:
            # Undo model.train() for the frozen blocks' BatchNorm layers.
            for layer in model.blocks.modules():
                if isinstance(layer, bn_types):
                    layer.eval()

        for x, y in train_loader:
            x, y = x.to(dev), y.to(dev)

            # Gradient-free forward pass: output error and hidden states.
            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                err = logits.softmax(-1)
                err[torch.arange(x.size(0)), y] -= 1
            # Last hidden state; assumed (B, d_hidden, H, W) - TODO confirm
            # against SmallResNet.forward.
            last_hidden = hiddens[-1].detach()

            # Head: true cross-entropy on the pooled features.
            pooled = F.adaptive_avg_pool2d(last_hidden, 1).flatten(1)
            head_opt.zero_grad()
            head_loss = F.cross_entropy(model.out_head(pooled), y)
            head_loss.backward()
            head_opt.step()

            # Stem: DFA-style local loss driven by the projected error.
            credit = (err @ Bs[0].T).detach()  # (B, d_hidden)
            credit_rms = (credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
            stem_out = model.stem(x)
            # (B, d) -> (B, d, 1, 1) -> same shape as the stem output.
            credit_map = (credit / credit_rms).unsqueeze(-1).unsqueeze(-1).expand_as(stem_out)
            stem_loss = (stem_out * credit_map).sum(dim=1).mean()
            stem_opt.zero_grad()
            stem_loss.backward()
            stem_opt.step()

        stem_sched.step()
        head_sched.step()

        should_report = epoch == 1 or epoch == epochs or epoch % 10 == 0
        if should_report:
            accuracy = evaluate(model, test_loader, dev)
            print(f"  [{label}] ep {epoch}: test_acc={accuracy:.4f}", flush=True)
    return model


def _seed_everything(seed):
    """Reseed the torch (CPU and all CUDA devices) and NumPy generators."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)


def _print_param_counts(model):
    """Print the total and trainable (``requires_grad``) parameter counts."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"  n_params: {total} ({trainable} trainable)", flush=True)


def _train_dfa_blockwise(model, train_loader, test_loader, dev, epochs, lr, wd, num_classes):
    """Train stem, every block and the head, each with its own local update.

    Per mini-batch: a gradient-free forward pass gives the output error
    ``softmax(logits) - one_hot(y)``; the head takes a true cross-entropy step
    on the pooled, detached last hidden state; block ``l`` takes a step on a
    local loss that correlates ``blocks[l](hi[l])`` with the error projected
    through its fixed random feedback matrix ``Bs[l]``; finally the stem takes
    the same kind of step using ``Bs[0]``. Each part has its own AdamW
    optimizer with a cosine schedule stepped once per epoch.

    Test accuracy is printed after the first epoch, every tenth epoch and the
    last epoch, tagged ``DFA-trainable``. Requires ``model.num_blocks >= 1``
    (the stem update reads ``Bs[0]``).
    """
    d_hidden = model.d_hidden
    num_blocks = model.num_blocks
    # One fixed random feedback matrix of shape (d_hidden, num_classes) per block.
    Bs = [torch.randn(d_hidden, num_classes, device=dev) / np.sqrt(num_classes)
          for _ in range(num_blocks)]

    block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
    stem_params = list(model.stem_conv.parameters()) + list(model.stem_bn.parameters())
    stem_opt = optim.AdamW(stem_params, lr=lr, weight_decay=wd)
    head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
    schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs)
                  for o in block_opts + [stem_opt, head_opt]]

    for ep in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x, y = x.to(dev), y.to(dev)

            # Gradient-free forward pass: output error and hidden states.
            with torch.no_grad():
                logits, hi = model(x, return_hidden=True)
                e_T = logits.softmax(-1)
                e_T[torch.arange(x.size(0)), y] -= 1

            # Head: true cross-entropy on pooled, detached final features.
            hL_det = hi[-1].detach()
            h_pool = F.adaptive_avg_pool2d(hL_det, 1).flatten(1)
            head_opt.zero_grad()
            F.cross_entropy(model.out_head(h_pool), y).backward()
            head_opt.step()

            # Blocks: one local DFA-style step each.
            for l in range(num_blocks):
                # presumably hi[l] is the input of block l (hi[0] = stem
                # output) - TODO confirm against SmallResNet.forward
                h_l = hi[l].detach()
                a_l = (e_T @ Bs[l].T).detach()
                rms = (a_l ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
                a_l_norm = (a_l / rms).unsqueeze(-1).unsqueeze(-1).expand_as(h_l)
                f_l = model.blocks[l](h_l)
                local_loss = (f_l * a_l_norm).sum(dim=1).mean()
                block_opts[l].zero_grad()
                local_loss.backward()
                # Only block gradients are clipped; stem and head are not.
                torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
                block_opts[l].step()

            # Stem: same local loss form.
            # NOTE(review): the stem reuses Bs[0], i.e. it shares its feedback
            # matrix with block 0 - confirm this is intended.
            a_0 = (e_T @ Bs[0].T).detach()
            rms_0 = (a_0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
            h0 = model.stem(x)
            a_0_b = (a_0 / rms_0).unsqueeze(-1).unsqueeze(-1).expand_as(h0)
            stem_loss = (h0 * a_0_b).sum(dim=1).mean()
            stem_opt.zero_grad()
            stem_loss.backward()
            stem_opt.step()

        for sched in schedulers:
            sched.step()
        if ep % 10 == 0 or ep == 1 or ep == epochs:
            acc = evaluate(model, test_loader, dev)
            print(f"  [DFA-trainable] ep {ep}: test_acc={acc:.4f}", flush=True)
    return model


def main():
    """Run all six BP/DFA x trainable/shallow/frozen conditions for one seed.

    Command-line options: ``--seed``, ``--epochs``, ``--lr``, ``--wd`` and
    ``--d_hidden``. Every condition reseeds the RNGs with the same seed right
    before building its model, then trains and records the final test
    accuracy. A summary and the DFA shallow/frozen-vs-trainable gaps are
    printed at the end.

    Runs on ``cuda:0`` when CUDA is available and falls back to the CPU
    otherwise (the chosen device is printed on the first line).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--epochs', type=int, default=60)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--wd', type=float, default=0.01)
    parser.add_argument('--d_hidden', type=int, default=64)
    args = parser.parse_args()

    # Fall back to CPU instead of crashing on hosts without a visible GPU.
    dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}", flush=True)
    train_loader, test_loader = get_loaders(batch_size=128)

    results = {}
    num_classes = 10

    def build_model(num_blocks):
        # Reseed first so every condition starts from the same RNG state.
        _seed_everything(args.seed)
        return SmallResNet(d_hidden=args.d_hidden, num_classes=num_classes,
                           num_blocks=num_blocks).to(dev)

    # Trainable BP (full 4-block ResNet)
    print(f"\n=== BP trainable (SmallResNet num_blocks=4), seed={args.seed} ===", flush=True)
    m = build_model(4)
    _print_param_counts(m)
    train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-trainable')
    results['bp_trainable'] = evaluate(m, test_loader, dev)
    print(f"FINAL BP-trainable: {results['bp_trainable']:.4f}", flush=True)

    # Trainable DFA: stem, each block and head trained with their own local
    # updates (train_dfa only covers the stem + head case).
    print(f"\n=== DFA trainable (SmallResNet num_blocks=4 block-level DFA), seed={args.seed} ===", flush=True)
    m = build_model(4)
    _train_dfa_blockwise(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, num_classes)
    results['dfa_trainable'] = evaluate(m, test_loader, dev)
    print(f"FINAL DFA-trainable: {results['dfa_trainable']:.4f}", flush=True)

    # BP shallow
    print(f"\n=== BP shallow (SmallResNet num_blocks=0), seed={args.seed} ===", flush=True)
    m = build_model(0)
    _print_param_counts(m)
    train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-shallow')
    results['bp_shallow'] = evaluate(m, test_loader, dev)
    print(f"FINAL BP-shallow: {results['bp_shallow']:.4f}", flush=True)

    # BP frozen-blocks
    print(f"\n=== BP frozen-blocks (SmallResNet num_blocks=4 frozen), seed={args.seed} ===", flush=True)
    m = build_model(4)
    freeze_blocks(m)
    _print_param_counts(m)
    train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-frozen', blocks_frozen=True)
    results['bp_frozen'] = evaluate(m, test_loader, dev)
    print(f"FINAL BP-frozen-blocks: {results['bp_frozen']:.4f}", flush=True)

    # DFA shallow
    print(f"\n=== DFA shallow (SmallResNet num_blocks=0), seed={args.seed} ===", flush=True)
    m = build_model(0)
    train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-shallow')
    results['dfa_shallow'] = evaluate(m, test_loader, dev)
    print(f"FINAL DFA-shallow: {results['dfa_shallow']:.4f}", flush=True)

    # DFA frozen-blocks
    print(f"\n=== DFA frozen-blocks (SmallResNet num_blocks=4 frozen), seed={args.seed} ===", flush=True)
    m = build_model(4)
    freeze_blocks(m)
    train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-frozen', blocks_frozen=True)
    results['dfa_frozen'] = evaluate(m, test_loader, dev)
    print(f"FINAL DFA-frozen-blocks: {results['dfa_frozen']:.4f}", flush=True)

    print(f"\n=== Small ResNet (BatchNorm) frozen/shallow baseline summary, seed={args.seed} ===")
    for k, v in results.items():
        print(f"  {k}: {v:.4f}")

    # Positive gap = the condition without trained blocks beats DFA-trainable.
    print("\nKey gaps (DFA):")
    dfa_trainable = results['dfa_trainable']
    dfa_shallow = results['dfa_shallow']
    dfa_frozen = results['dfa_frozen']
    print(f"  DFA-shallow ({dfa_shallow:.4f}) - DFA-trainable ({dfa_trainable:.4f}) = {dfa_shallow-dfa_trainable:+.4f}")
    print(f"  DFA-frozen ({dfa_frozen:.4f}) - DFA-trainable ({dfa_trainable:.4f}) = {dfa_frozen-dfa_trainable:+.4f}")


# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()