"""
Frozen-blocks baselines for ViT-Mini and StudentNet.

Trains only the parameters outside ``model.blocks`` (embedding / head / LN)
while the blocks stay frozen at their random initialisation. Also trains a
shallow variant (``num_blocks=0``) of each model for comparison.

Final test accuracies are written as JSON to ``--output``. The file is
rewritten after every seed, so a crash late in the run does not lose the
seeds that already finished.
"""
import os
import sys
import json
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as transforms

# Make the repository root importable so the project-local modules resolve.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.vit_mini import ViTMini  # noqa: E402
from experiments.confirmatory_paper_experiments import (  # noqa: E402
    StudentNet, TeacherNet, generate_synth_dataset, set_seed,
)

# Per-channel CIFAR-10 normalisation statistics (shared by train/test transforms).
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)


def get_cifar10(batch_size=128):
    """Return ``(train_loader, test_loader)`` for CIFAR-10 stored under ``./data``.

    The training split uses random-crop (padding 4) + horizontal-flip
    augmentation. Both splits are converted to tensors and normalised with the
    statistics above. The datasets are downloaded if missing.
    """
    tv_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
    ])
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
    ])
    tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
    te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
    return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
            DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2))


def evaluate(model, loader, device, is_vit=False):
    """Return top-1 accuracy of ``model`` over ``loader``.

    For non-ViT models, 4-D inputs are flattened to ``(batch, -1)`` before the
    forward pass; inputs of any other rank are passed through unchanged.

    Raises:
        ValueError: if ``loader`` yields no samples (accuracy is undefined).
    """
    model.eval()
    c = n = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            if not is_vit and x.dim() == 4:
                x = x.view(x.size(0), -1)
            preds = model(x).argmax(-1)
            c += (preds == y).sum().item()
            n += x.size(0)
    if n == 0:
        raise ValueError("evaluate() received an empty loader")
    return c / n


def freeze_blocks(model):
    """Disable gradients for every parameter under ``model.blocks``."""
    for p in model.blocks.parameters():
        p.requires_grad_(False)


def _fit(model, params, tag, seed, train_loader, test_loader, device,
         epochs, lr, wd, is_vit=False):
    """Train ``params`` of ``model`` with AdamW + cosine LR; return final test acc.

    Shared loop for all four baselines. Test accuracy is logged every 10
    epochs and at the last epoch. The last logged value is returned instead of
    running a second, redundant full test-set pass after training.
    """
    opt = optim.AdamW(params, lr=lr, weight_decay=wd)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    acc = None
    for ep in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            loss = F.cross_entropy(model(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        sch.step()
        if ep % 10 == 0 or ep == epochs:
            acc = evaluate(model, test_loader, device, is_vit=is_vit)
            print(f" [{tag}] s={seed} ep {ep}: acc={acc:.4f}", flush=True)
    if acc is None:
        # epochs <= 0: the loop never evaluated, so measure the untrained model.
        acc = evaluate(model, test_loader, device, is_vit=is_vit)
    return acc


def _seed_torch(seed):
    """Seed torch (CPU + all CUDA devices) and numpy RNGs."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)


# ─── ViT-Mini frozen/shallow ────────────────────────────────────────────

def train_vit_frozen(seed, train_loader, test_loader, device, epochs, lr, wd):
    """ViT-Mini (4 blocks) with blocks frozen at init; return final test accuracy."""
    _seed_torch(seed)
    model = ViTMini(d_model=128, n_heads=4, num_blocks=4, num_classes=10).to(device)
    freeze_blocks(model)
    params = [p for p in model.parameters() if p.requires_grad]
    trainable = sum(p.numel() for p in params)
    total = sum(p.numel() for p in model.parameters())
    print(f" ViT-Mini frozen: {trainable}/{total} trainable params", flush=True)
    return _fit(model, params, "ViT-frozen", seed, train_loader, test_loader,
                device, epochs, lr, wd, is_vit=True)


def train_vit_shallow(seed, train_loader, test_loader, device, epochs, lr, wd):
    """ViT with num_blocks=0: just patch_embed + cls + pos + LN + head."""
    _seed_torch(seed)
    model = ViTMini(d_model=128, n_heads=4, num_blocks=0, num_classes=10).to(device)
    trainable = sum(p.numel() for p in model.parameters())
    print(f" ViT-Mini shallow: {trainable} params (no blocks)", flush=True)
    return _fit(model, list(model.parameters()), "ViT-shallow", seed,
                train_loader, test_loader, device, epochs, lr, wd, is_vit=True)


# ─── StudentNet frozen/shallow ──────────────────────────────────────────

def train_student_frozen(seed, train_loader, test_loader, device, epochs, lr, wd,
                         alpha=1.0):
    """StudentNet (4 blocks) with blocks frozen at init; return final test accuracy."""
    set_seed(seed)
    model = StudentNet(128, 10, 4, alpha).to(device)
    freeze_blocks(model)
    params = [p for p in model.parameters() if p.requires_grad]
    trainable = sum(p.numel() for p in params)
    total = sum(p.numel() for p in model.parameters())
    print(f" StudentNet frozen: {trainable}/{total} trainable params", flush=True)
    return _fit(model, params, "Student-frozen", seed, train_loader, test_loader,
                device, epochs, lr, wd)


def train_student_shallow(seed, train_loader, test_loader, device, epochs, lr, wd,
                          alpha=1.0):
    """StudentNet with num_blocks=0: just out_head (input is d_hidden already)."""
    set_seed(seed)
    model = StudentNet(128, 10, 0, alpha).to(device)
    trainable = sum(p.numel() for p in model.parameters())
    print(f" StudentNet shallow: {trainable} params (no blocks)", flush=True)
    return _fit(model, list(model.parameters()), "Student-shallow", seed,
                train_loader, test_loader, device, epochs, lr, wd)


def _save_results(results, path):
    """Atomically write ``results`` as indented JSON to ``path``.

    Writing to a temp file and then calling ``os.replace`` means a crash mid-write
    never leaves a truncated results file behind.
    """
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(results, f, indent=2)
    os.replace(tmp_path, path)


def main():
    """Run the ViT-Mini (CIFAR-10) and StudentNet (synthetic) baselines for 3 seeds."""
    p = argparse.ArgumentParser()
    p.add_argument('--output', type=str,
                   default='results/frozen_baselines_crossarch.json')
    args = p.parse_args()

    # Create the output directory before training, so a missing directory
    # fails immediately rather than after every run has finished.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Fall back to CPU instead of crashing on hosts without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    results = {}

    # ── ViT-Mini (CIFAR-10, 60 epochs) ──
    print("\n=== ViT-Mini frozen baselines ===", flush=True)
    train_loader, test_loader = get_cifar10(128)
    for seed in [42, 123, 456]:
        print(f"\n--- ViT-Mini seed={seed} ---", flush=True)
        frozen_acc = train_vit_frozen(seed, train_loader, test_loader, device, 60, 1e-3, 0.05)
        shallow_acc = train_vit_shallow(seed, train_loader, test_loader, device, 60, 1e-3, 0.05)
        results[f'vit_frozen_s{seed}'] = frozen_acc
        results[f'vit_shallow_s{seed}'] = shallow_acc
        print(f" FINAL ViT s={seed}: frozen={frozen_acc:.4f}, shallow={shallow_acc:.4f}",
              flush=True)
        _save_results(results, args.output)  # checkpoint after every seed

    # ── StudentNet (synthetic, 80 epochs) ──
    print("\n=== StudentNet frozen baselines ===", flush=True)
    L, d, C, alpha = 4, 128, 10, 1.0
    for seed in [42, 123, 456]:
        print(f"\n--- StudentNet seed={seed} ---", flush=True)
        set_seed(seed)
        # The teacher is fixed (seed=0) across student seeds; only the data draw varies.
        teacher = TeacherNet(d, L, C, alpha, seed=0).to(device)
        X_tr, Y_tr = generate_synth_dataset(teacher, 50 * 256, d, device, seed=seed)
        X_te, Y_te = generate_synth_dataset(teacher, 2000, d, device, seed=seed + 10000)
        s_train = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=256, shuffle=True)
        s_test = DataLoader(TensorDataset(X_te, Y_te), batch_size=256, shuffle=False)
        frozen_acc = train_student_frozen(seed, s_train, s_test, device, 80, 1e-3, 0.01, alpha)
        shallow_acc = train_student_shallow(seed, s_train, s_test, device, 80, 1e-3, 0.01, alpha)
        results[f'student_frozen_s{seed}'] = frozen_acc
        results[f'student_shallow_s{seed}'] = shallow_acc
        print(f" FINAL Student s={seed}: frozen={frozen_acc:.4f}, shallow={shallow_acc:.4f}",
              flush=True)
        _save_results(results, args.output)  # checkpoint after every seed

    _save_results(results, args.output)
    print(f"\nSaved: {args.output}", flush=True)


if __name__ == '__main__':
    main()