diff options
Diffstat (limited to 'reproduce/frozen_baseline.py')
| -rw-r--r-- | reproduce/frozen_baseline.py | 86 |
1 file changed, 86 insertions, 0 deletions
"""
Frozen-blocks baseline: train only embed/head with blocks frozen at random init.

Usage:
    python reproduce/frozen_baseline.py --arch resmlp --seeds 42 123 456 --epochs 100
"""
import argparse
import json
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Make the repository root importable when running this file as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from reproduce.train_methods import get_data, evaluate, make_model


def freeze_blocks(model):
    """Freeze every parameter in ``model.blocks`` and put its BatchNorm layers in eval mode.

    BatchNorm running statistics would still update under ``model.train()`` even
    with frozen weights, so the BN layers are switched to eval mode to keep the
    random-init statistics fixed.
    """
    for p in model.blocks.parameters():
        p.requires_grad_(False)
    for m in model.blocks.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            m.eval()


def train_frozen(model, train_loader, test_loader, device, epochs, is_conv):
    """Train only the still-trainable parameters (embed/head); return final test accuracy.

    Args:
        model: network whose ``blocks`` submodule has been frozen.
        train_loader / test_loader: data iterators of (x, y) batches.
        device: torch device to run on.
        epochs: number of training epochs (also the cosine schedule horizon).
        is_conv: if False, inputs are flattened to vectors for MLP-style models.
    """
    # Optimize only parameters that survived freeze_blocks().
    opt = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=1e-3, weight_decay=0.01,
    )
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    for ep in range(1, epochs + 1):
        model.train()
        # model.train() re-enables training mode on every submodule, so the
        # frozen blocks' BatchNorm layers must be forced back to eval each epoch.
        for m in model.blocks.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                m.eval()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            if not is_conv:
                x = x.view(x.size(0), -1)
            loss = F.cross_entropy(model(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        sch.step()
        if ep % 10 == 0 or ep == epochs:
            acc = evaluate(model, test_loader, device, is_conv)
            print(f" [Frozen] ep {ep}: acc={acc:.4f}", flush=True)
    return evaluate(model, test_loader, device, is_conv)


def main():
    p = argparse.ArgumentParser()
    p.add_argument('--arch', type=str, default='resmlp',
                   choices=['resmlp', 'resmlp_d512_L2', 'vit', 'resnet'])
    p.add_argument('--dataset', type=str, default='cifar10',
                   choices=['cifar10', 'cifar100'])
    p.add_argument('--seeds', nargs='+', type=int, default=[42, 123, 456])
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--output_dir', type=str, default='results/frozen_baselines')
    args = p.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader, num_classes = get_data(args.dataset, 128)

    results = {}
    for seed in args.seeds:
        print(f"\n--- Frozen baseline seed={seed} ---", flush=True)
        # Seed all RNGs so each per-seed run is reproducible.
        torch.manual_seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        model, is_conv = make_model(args.arch, num_classes, device)
        freeze_blocks(model)
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        print(f" {trainable}/{total} trainable params", flush=True)
        acc = train_frozen(model, train_loader, test_loader, device, args.epochs, is_conv)
        results[f's{seed}'] = acc
        print(f" FINAL: {acc:.4f}", flush=True)

    accs = [results[f's{s}'] for s in args.seeds]
    results['config'] = vars(args)
    results['mean'] = float(np.mean(accs))
    # FIX: np.std(..., ddof=1) is NaN (and warns) for a single seed; the sample
    # standard deviation is undefined there, so report 0.0 instead.
    results['std'] = float(np.std(accs, ddof=1)) if len(accs) > 1 else 0.0
    out_path = os.path.join(args.output_dir, f'frozen_{args.arch}_{args.dataset}.json')
    with open(out_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved: {out_path}")
    print(f"Frozen baseline: {results['mean']:.4f} +/- {results['std']:.4f}")


if __name__ == '__main__':
    main()
