summaryrefslogtreecommitdiff
path: root/reproduce/frozen_baseline.py
blob: 08368a2a948cbe4c9ac08025d2ac14bfa87e951e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""
Frozen-blocks baseline: train only embed/head with blocks frozen at random init.

Usage:
    python reproduce/frozen_baseline.py --arch resmlp --seeds 42 123 456 --epochs 100
"""
import os, sys, json, argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision, torchvision.transforms as transforms

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from reproduce.train_methods import get_data, evaluate, make_model


def freeze_blocks(model):
    """Freeze every parameter in ``model.blocks`` and put its BatchNorm layers in eval mode.

    Freezing stops gradient updates; switching BatchNorm to eval additionally
    stops running-statistics updates, so the blocks stay exactly at random init.
    """
    for param in model.blocks.parameters():
        param.requires_grad = False
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d)
    for module in model.blocks.modules():
        if isinstance(module, batchnorm_types):
            module.eval()


def train_frozen(model, train_loader, test_loader, device, epochs, is_conv):
    """Train only the unfrozen parameters (embed/head) of ``model``.

    Uses AdamW (lr=1e-3, wd=0.01) with cosine annealing over ``epochs``.
    Flattens inputs when ``is_conv`` is False. Prints test accuracy every
    10 epochs and returns the final test accuracy.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable, lr=1e-3, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    for epoch in range(1, epochs + 1):
        model.train()
        # model.train() flips the frozen blocks' BatchNorm layers back to
        # training mode; force them to eval so running stats never update.
        for module in model.blocks.modules():
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                module.eval()
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            if not is_conv:
                inputs = inputs.view(inputs.size(0), -1)
            optimizer.zero_grad()
            loss = F.cross_entropy(model(inputs), targets)
            loss.backward()
            optimizer.step()
        scheduler.step()
        if epoch % 10 == 0 or epoch == epochs:
            acc = evaluate(model, test_loader, device, is_conv)
            print(f"  [Frozen] ep {epoch}: acc={acc:.4f}", flush=True)
    return evaluate(model, test_loader, device, is_conv)


def main():
    """Run the frozen-blocks baseline over multiple seeds and save results as JSON.

    For each seed: build the model, freeze its blocks, train embed/head only,
    and record final test accuracy. Writes mean/std across seeds plus the run
    config to ``<output_dir>/frozen_<arch>_<dataset>.json``.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--arch', type=str, default='resmlp', choices=['resmlp', 'resmlp_d512_L2', 'vit', 'resnet'])
    p.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'])
    p.add_argument('--seeds', nargs='+', type=int, default=[42, 123, 456])
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--output_dir', type=str, default='results/frozen_baselines')
    args = p.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader, num_classes = get_data(args.dataset, 128)

    results = {}
    for seed in args.seeds:
        print(f"\n--- Frozen baseline seed={seed} ---", flush=True)
        # Seed every RNG source so each run is reproducible.
        torch.manual_seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        model, is_conv = make_model(args.arch, num_classes, device)
        freeze_blocks(model)
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        print(f"  {trainable}/{total} trainable params", flush=True)
        acc = train_frozen(model, train_loader, test_loader, device, args.epochs, is_conv)
        results[f's{seed}'] = acc
        print(f"  FINAL: {acc:.4f}", flush=True)

    accs = [results[f's{s}'] for s in args.seeds]
    results['config'] = vars(args)
    results['mean'] = float(np.mean(accs))
    # Sample std (ddof=1) is undefined for a single seed: np.std would return
    # NaN, which json.dump serializes as non-standard JSON. Report 0.0 instead.
    results['std'] = float(np.std(accs, ddof=1)) if len(accs) > 1 else 0.0
    out_path = os.path.join(args.output_dir, f'frozen_{args.arch}_{args.dataset}.json')
    with open(out_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved: {out_path}")
    print(f"Frozen baseline: {results['mean']:.4f} +/- {results['std']:.4f}")


# Script entry point: only run when executed directly, not on import.
if __name__ == '__main__':
    main()