summaryrefslogtreecommitdiff
path: root/experiments/frozen_baselines_crossarch.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/frozen_baselines_crossarch.py')
-rw-r--r--experiments/frozen_baselines_crossarch.py191
1 files changed, 191 insertions, 0 deletions
diff --git a/experiments/frozen_baselines_crossarch.py b/experiments/frozen_baselines_crossarch.py
new file mode 100644
index 0000000..a3dd76c
--- /dev/null
+++ b/experiments/frozen_baselines_crossarch.py
@@ -0,0 +1,191 @@
+"""
+Frozen-blocks baselines for ViT-Mini and StudentNet.
+Trains only embed/head/LN with blocks frozen at random init.
+Also trains shallow (no blocks) variant for comparison.
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset
+import torchvision, torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.vit_mini import ViTMini
+from experiments.confirmatory_paper_experiments import (
+ StudentNet, TeacherNet, generate_synth_dataset, set_seed
+)
+
+
+def get_cifar10(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2))
+
+
+def evaluate(model, loader, device, is_vit=False):
+ model.eval()
+ c = n = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x = x.to(device); y = y.to(device)
+ if not is_vit:
+ x = x.view(x.size(0), -1) if x.dim() == 4 else x
+ preds = model(x).argmax(-1)
+ c += (preds == y).sum().item()
+ n += x.size(0)
+ return c / n
+
+
+def freeze_blocks(model):
+ for p in model.blocks.parameters():
+ p.requires_grad_(False)
+
+
+# ─── ViT-Mini frozen/shallow ────────────────────────────────────────────
+
+def train_vit_frozen(seed, train_loader, test_loader, device, epochs, lr, wd):
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ model = ViTMini(d_model=128, n_heads=4, num_blocks=4, num_classes=10).to(device)
+ freeze_blocks(model)
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ total = sum(p.numel() for p in model.parameters())
+ print(f" ViT-Mini frozen: {trainable}/{total} trainable params", flush=True)
+ opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.to(device); y = y.to(device)
+ loss = F.cross_entropy(model(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 10 == 0 or ep == epochs:
+ acc = evaluate(model, test_loader, device, is_vit=True)
+ print(f" [ViT-frozen] s={seed} ep {ep}: acc={acc:.4f}", flush=True)
+ return evaluate(model, test_loader, device, is_vit=True)
+
+
+def train_vit_shallow(seed, train_loader, test_loader, device, epochs, lr, wd):
+ """ViT with num_blocks=0: just patch_embed + cls + pos + LN + head."""
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ model = ViTMini(d_model=128, n_heads=4, num_blocks=0, num_classes=10).to(device)
+ trainable = sum(p.numel() for p in model.parameters())
+ print(f" ViT-Mini shallow: {trainable} params (no blocks)", flush=True)
+ opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.to(device); y = y.to(device)
+ loss = F.cross_entropy(model(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 10 == 0 or ep == epochs:
+ acc = evaluate(model, test_loader, device, is_vit=True)
+ print(f" [ViT-shallow] s={seed} ep {ep}: acc={acc:.4f}", flush=True)
+ return evaluate(model, test_loader, device, is_vit=True)
+
+
+# ─── StudentNet frozen/shallow ──────────────────────────────────────────
+
+def train_student_frozen(seed, train_loader, test_loader, device, epochs, lr, wd, alpha=1.0):
+ set_seed(seed)
+ model = StudentNet(128, 10, 4, alpha).to(device)
+ freeze_blocks(model)
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ total = sum(p.numel() for p in model.parameters())
+ print(f" StudentNet frozen: {trainable}/{total} trainable params", flush=True)
+ opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.to(device); y = y.to(device)
+ loss = F.cross_entropy(model(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 10 == 0 or ep == epochs:
+ acc = evaluate(model, test_loader, device)
+ print(f" [Student-frozen] s={seed} ep {ep}: acc={acc:.4f}", flush=True)
+ return evaluate(model, test_loader, device)
+
+
+def train_student_shallow(seed, train_loader, test_loader, device, epochs, lr, wd, alpha=1.0):
+ """StudentNet with num_blocks=0: just out_head (input is d_hidden already)."""
+ set_seed(seed)
+ model = StudentNet(128, 10, 0, alpha).to(device)
+ trainable = sum(p.numel() for p in model.parameters())
+ print(f" StudentNet shallow: {trainable} params (no blocks)", flush=True)
+ opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.to(device); y = y.to(device)
+ loss = F.cross_entropy(model(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 10 == 0 or ep == epochs:
+ acc = evaluate(model, test_loader, device)
+ print(f" [Student-shallow] s={seed} ep {ep}: acc={acc:.4f}", flush=True)
+ return evaluate(model, test_loader, device)
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--output', type=str, default='results/frozen_baselines_crossarch.json')
+ args = p.parse_args()
+
+ device = torch.device('cuda:0')
+
+ results = {}
+
+ # ── ViT-Mini (CIFAR-10, 60 epochs) ──
+ print("\n=== ViT-Mini frozen baselines ===", flush=True)
+ train_loader, test_loader = get_cifar10(128)
+ for seed in [42, 123, 456]:
+ print(f"\n--- ViT-Mini seed={seed} ---", flush=True)
+ frozen_acc = train_vit_frozen(seed, train_loader, test_loader, device, 60, 1e-3, 0.05)
+ shallow_acc = train_vit_shallow(seed, train_loader, test_loader, device, 60, 1e-3, 0.05)
+ results[f'vit_frozen_s{seed}'] = frozen_acc
+ results[f'vit_shallow_s{seed}'] = shallow_acc
+ print(f" FINAL ViT s={seed}: frozen={frozen_acc:.4f}, shallow={shallow_acc:.4f}", flush=True)
+
+ # ── StudentNet (synthetic, 80 epochs) ──
+ print("\n=== StudentNet frozen baselines ===", flush=True)
+ L, d, C, alpha = 4, 128, 10, 1.0
+ for seed in [42, 123, 456]:
+ print(f"\n--- StudentNet seed={seed} ---", flush=True)
+ set_seed(seed)
+ teacher = TeacherNet(d, L, C, alpha, seed=0).to(device)
+ X_tr, Y_tr = generate_synth_dataset(teacher, 50*256, d, device, seed=seed)
+ X_te, Y_te = generate_synth_dataset(teacher, 2000, d, device, seed=seed+10000)
+ s_train = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=256, shuffle=True)
+ s_test = DataLoader(TensorDataset(X_te, Y_te), batch_size=256, shuffle=False)
+
+ frozen_acc = train_student_frozen(seed, s_train, s_test, device, 80, 1e-3, 0.01, alpha)
+ shallow_acc = train_student_shallow(seed, s_train, s_test, device, 80, 1e-3, 0.01, alpha)
+ results[f'student_frozen_s{seed}'] = frozen_acc
+ results[f'student_shallow_s{seed}'] = shallow_acc
+ print(f" FINAL Student s={seed}: frozen={frozen_acc:.4f}, shallow={shallow_acc:.4f}", flush=True)
+
+ with open(args.output, 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nSaved: {args.output}", flush=True)
+
+
+if __name__ == '__main__':
+ main()