summaryrefslogtreecommitdiff
path: root/reproduce/frozen_baseline.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-05-04 19:50:45 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-05-04 19:50:45 -0500
commitb480d0cdc21f944e4adccf6e81cc939b0450c5e9 (patch)
treef0e6afb5b3d448d1d6c35d9622d22d63073ca9a7 /reproduce/frozen_baseline.py
Initial submission code: FA evaluation protocol + reproduction scripts
Reference implementation of the three-diagnostic FA evaluation protocol (scale stability, reference validity, depth utility) from the NeurIPS 2026 E&D track paper. Includes models, metrics, and full reproduction pipeline. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'reproduce/frozen_baseline.py')
-rw-r--r--reproduce/frozen_baseline.py86
1 files changed, 86 insertions, 0 deletions
diff --git a/reproduce/frozen_baseline.py b/reproduce/frozen_baseline.py
new file mode 100644
index 0000000..08368a2
--- /dev/null
+++ b/reproduce/frozen_baseline.py
@@ -0,0 +1,86 @@
+"""
+Frozen-blocks baseline: train only embed/head with blocks frozen at random init.
+
+Usage:
+ python reproduce/frozen_baseline.py --arch resmlp --seeds 42 123 456 --epochs 100
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision, torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from reproduce.train_methods import get_data, evaluate, make_model
+
+
+def freeze_blocks(model):
+ for p in model.blocks.parameters():
+ p.requires_grad_(False)
+ for m in model.blocks.modules():
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
+ m.eval()
+
+
+def train_frozen(model, train_loader, test_loader, device, epochs, is_conv):
+ opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3, weight_decay=0.01)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for m in model.blocks.modules():
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
+ m.eval()
+ for x, y in train_loader:
+ x, y = x.to(device), y.to(device)
+ if not is_conv: x = x.view(x.size(0), -1)
+ loss = F.cross_entropy(model(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 10 == 0 or ep == epochs:
+ acc = evaluate(model, test_loader, device, is_conv)
+ print(f" [Frozen] ep {ep}: acc={acc:.4f}", flush=True)
+ return evaluate(model, test_loader, device, is_conv)
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--arch', type=str, default='resmlp', choices=['resmlp', 'resmlp_d512_L2', 'vit', 'resnet'])
+ p.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'])
+ p.add_argument('--seeds', nargs='+', type=int, default=[42, 123, 456])
+ p.add_argument('--epochs', type=int, default=100)
+ p.add_argument('--gpu', type=int, default=0)
+ p.add_argument('--output_dir', type=str, default='results/frozen_baselines')
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ train_loader, test_loader, num_classes = get_data(args.dataset, 128)
+
+ results = {}
+ for seed in args.seeds:
+ print(f"\n--- Frozen baseline seed={seed} ---", flush=True)
+ torch.manual_seed(seed); np.random.seed(seed)
+ if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
+ model, is_conv = make_model(args.arch, num_classes, device)
+ freeze_blocks(model)
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ total = sum(p.numel() for p in model.parameters())
+ print(f" {trainable}/{total} trainable params", flush=True)
+ acc = train_frozen(model, train_loader, test_loader, device, args.epochs, is_conv)
+ results[f's{seed}'] = acc
+ print(f" FINAL: {acc:.4f}", flush=True)
+
+ results['config'] = vars(args)
+ results['mean'] = float(np.mean([results[f's{s}'] for s in args.seeds]))
+ results['std'] = float(np.std([results[f's{s}'] for s in args.seeds], ddof=1))
+ out_path = os.path.join(args.output_dir, f'frozen_{args.arch}_{args.dataset}.json')
+ with open(out_path, 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nSaved: {out_path}")
+ print(f"Frozen baseline: {results['mean']:.4f} +/- {results['std']:.4f}")
+
+
+if __name__ == '__main__':
+ main()