"""
Frozen-blocks baseline: train only embed/head with blocks frozen at random init.

Usage:
    python reproduce/frozen_baseline.py --arch resmlp --seeds 42 123 456 --epochs 100
"""
import os
import sys
import json
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision, torchvision.transforms as transforms

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from reproduce.train_methods import get_data, evaluate, make_model

# BatchNorm variants whose running statistics are pinned while blocks are frozen.
_BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d)


def _set_block_bn_eval(model):
    """Put every BatchNorm layer inside ``model.blocks`` into eval mode.

    In train mode BatchNorm keeps updating its running mean/var buffers even
    when its affine parameters do not require grad; eval mode keeps those
    buffers fixed so the frozen blocks really stay at their initial state.
    """
    for m in model.blocks.modules():
        if isinstance(m, _BN_TYPES):
            m.eval()


def freeze_blocks(model):
    """Freeze ``model.blocks``: disable gradients and pin BatchNorm statistics.

    ``model`` must expose a ``blocks`` submodule; everything outside it
    (embedding / head) stays trainable.
    """
    for p in model.blocks.parameters():
        p.requires_grad_(False)
    _set_block_bn_eval(model)


def train_frozen(model, train_loader, test_loader, device, epochs, is_conv):
    """Train only the parameters that still require grad and return test accuracy.

    Optimizer is AdamW (lr=1e-3, weight_decay=0.01) with a cosine-annealed
    learning rate stepped once per epoch over ``epochs`` epochs. Test accuracy
    is printed every 10 epochs and after the final epoch.

    Args:
        model: model already passed through ``freeze_blocks``.
        train_loader: iterable of ``(x, y)`` training batches.
        test_loader: loader forwarded to ``evaluate``.
        device: device the batches are moved to.
        epochs: number of training epochs.
        is_conv: when False, each batch is flattened to ``(batch, -1)`` before
            the forward pass (for MLP-style models).

    Returns:
        The accuracy reported by ``evaluate`` after the final epoch.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = optim.AdamW(trainable, lr=1e-3, weight_decay=0.01)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    acc = None
    for ep in range(1, epochs + 1):
        # model.train() flips the frozen blocks' BatchNorm layers back to
        # train mode too, so re-pin them at the start of every epoch.
        model.train()
        _set_block_bn_eval(model)
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            if not is_conv:
                x = x.flatten(1)
            loss = F.cross_entropy(model(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        sch.step()
        if ep % 10 == 0 or ep == epochs:
            acc = evaluate(model, test_loader, device, is_conv)
            print(f" [Frozen] ep {ep}: acc={acc:.4f}", flush=True)
    # The last epoch always evaluates; only evaluate here when no epoch ran,
    # instead of running a redundant second pass over the test set.
    if acc is None:
        acc = evaluate(model, test_loader, device, is_conv)
    return acc


def main():
    """Run the frozen-blocks baseline for each seed and save a JSON summary.

    The JSON maps ``s<seed>`` to final accuracy and also holds ``config``
    (the parsed CLI arguments), ``mean`` and ``std`` across seeds.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--arch', type=str, default='resmlp',
                        choices=['resmlp', 'resmlp_d512_L2', 'vit', 'resnet'])
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100'])
    parser.add_argument('--seeds', nargs='+', type=int, default=[42, 123, 456])
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--output_dir', type=str, default='results/frozen_baselines')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader, num_classes = get_data(args.dataset, 128)

    results = {}
    for seed in args.seeds:
        print(f"\n--- Frozen baseline seed={seed} ---", flush=True)
        torch.manual_seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        model, is_conv = make_model(args.arch, num_classes, device)
        freeze_blocks(model)
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        print(f" {trainable}/{total} trainable params", flush=True)
        # float() keeps the value JSON-serialisable whatever numeric type
        # evaluate() returns (Python/NumPy float or a 0-dim tensor).
        acc = float(train_frozen(model, train_loader, test_loader, device,
                                 args.epochs, is_conv))
        results[f's{seed}'] = acc
        print(f" FINAL: {acc:.4f}", flush=True)

    accs = [results[f's{s}'] for s in args.seeds]
    results['config'] = vars(args)
    results['mean'] = float(np.mean(accs))
    # Sample std (ddof=1) is undefined for a single seed: np.std returns NaN,
    # which json.dump would write as the non-standard token NaN. Use 0.0.
    results['std'] = float(np.std(accs, ddof=1)) if len(accs) > 1 else 0.0

    out_path = os.path.join(args.output_dir, f'frozen_{args.arch}_{args.dataset}.json')
    with open(out_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved: {out_path}")
    print(f"Frozen baseline: {results['mean']:.4f} +/- {results['std']:.4f}")


if __name__ == '__main__':
    main()