1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
|
"""
Frozen-blocks baseline: train only embed/head with blocks frozen at random init.
Usage:
python reproduce/frozen_baseline.py --arch resmlp --seeds 42 123 456 --epochs 100
"""
import os, sys, json, argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision, torchvision.transforms as transforms
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from reproduce.train_methods import get_data, evaluate, make_model
def freeze_blocks(model):
    """Freeze every parameter under ``model.blocks`` and set its BatchNorm layers to eval.

    Only the embed/head parameters remain trainable afterwards; BN layers are
    switched to eval mode so their running statistics stop updating while the
    blocks sit at their random initialization.
    """
    for param in model.blocks.parameters():
        param.requires_grad = False
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d)
    for module in model.blocks.modules():
        if isinstance(module, bn_types):
            module.eval()
def train_frozen(model, train_loader, test_loader, device, epochs, is_conv):
    """Train only the unfrozen parameters of ``model`` and return final test accuracy.

    Uses AdamW (lr=1e-3, weight_decay=0.01) with cosine annealing over
    ``epochs``. Non-conv architectures get their inputs flattened. Frozen-block
    BatchNorm layers are re-forced into eval mode every epoch, because
    ``model.train()`` would otherwise flip them back to training mode.
    """
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=1e-3, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d)
    for ep in range(1, epochs + 1):
        model.train()
        # model.train() re-enables BN stat updates in the frozen blocks; undo that.
        for module in model.blocks.modules():
            if isinstance(module, bn_types):
                module.eval()
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            if not is_conv:
                x = x.view(x.size(0), -1)
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()
        if ep % 10 == 0 or ep == epochs:
            acc = evaluate(model, test_loader, device, is_conv)
            print(f" [Frozen] ep {ep}: acc={acc:.4f}", flush=True)
    return evaluate(model, test_loader, device, is_conv)
def main():
    """Run the frozen-blocks baseline over one or more seeds and save results as JSON.

    For each seed: build the model, freeze its blocks at random init, train
    only embed/head, and record final test accuracy. Writes per-seed
    accuracies plus mean/std to ``<output_dir>/frozen_<arch>_<dataset>.json``.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--arch', type=str, default='resmlp', choices=['resmlp', 'resmlp_d512_L2', 'vit', 'resnet'])
    p.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'])
    p.add_argument('--seeds', nargs='+', type=int, default=[42, 123, 456])
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--output_dir', type=str, default='results/frozen_baselines')
    args = p.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader, num_classes = get_data(args.dataset, 128)

    results = {}
    for seed in args.seeds:
        print(f"\n--- Frozen baseline seed={seed} ---", flush=True)
        # Seed every RNG source so each run is reproducible.
        torch.manual_seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        model, is_conv = make_model(args.arch, num_classes, device)
        freeze_blocks(model)
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        print(f" {trainable}/{total} trainable params", flush=True)
        acc = train_frozen(model, train_loader, test_loader, device, args.epochs, is_conv)
        results[f's{seed}'] = acc
        print(f" FINAL: {acc:.4f}", flush=True)

    results['config'] = vars(args)
    accs = [results[f's{s}'] for s in args.seeds]
    results['mean'] = float(np.mean(accs))
    # BUG FIX: np.std(..., ddof=1) on a single seed divides by zero and yields
    # nan, which json.dump serializes as non-standard `NaN`. Report 0.0 instead.
    results['std'] = float(np.std(accs, ddof=1)) if len(accs) > 1 else 0.0
    out_path = os.path.join(args.output_dir, f'frozen_{args.arch}_{args.dataset}.json')
    with open(out_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved: {out_path}")
    print(f"Frozen baseline: {results['mean']:.4f} +/- {results['std']:.4f}")
# Entry point: run the baseline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|