summaryrefslogtreecommitdiff
path: root/experiments/vit_frozen_blocks_baseline.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/vit_frozen_blocks_baseline.py')
-rw-r--r--experiments/vit_frozen_blocks_baseline.py177
1 files changed, 177 insertions, 0 deletions
diff --git a/experiments/vit_frozen_blocks_baseline.py b/experiments/vit_frozen_blocks_baseline.py
new file mode 100644
index 0000000..8b53198
--- /dev/null
+++ b/experiments/vit_frozen_blocks_baseline.py
@@ -0,0 +1,177 @@
+"""
+Frozen-random-blocks baseline for ViT-Mini: train BP and DFA where the 4
+transformer blocks are randomly initialized and FROZEN (no parameter updates).
+Only patch_embed + cls_token + pos_embed + out_ln + out_head are trainable.
+
+This is the codex-round-6 control for the "DFA actually trains the transformer
+blocks" claim. If frozen-blocks DFA gets ≈ 24% (matching the trainable-blocks
+4-block ViT-Mini DFA acc), then the blocks are passengers — DFA's "24%" is
+coming from patch_embed + head learning routed via untrained block mixing.
+If frozen-blocks DFA stays much lower than 24%, then the trainable blocks
+are doing learned work.
+
+Usage:
+ CUDA_VISIBLE_DEVICES=2 python experiments/vit_frozen_blocks_baseline.py
+"""
+import sys, os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import numpy as np
+
+from models.vit_mini import ViTMini
+
+
+def get_loaders(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (
+ DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
+ )
+
+
+def evaluate(model, loader, dev):
+ model.eval()
+ n = c = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x, y = x.to(dev), y.to(dev)
+ preds = model(x).argmax(-1)
+ c += (preds == y).sum().item()
+ n += x.size(0)
+ return c / n
+
+
+def freeze_blocks(model):
+ for p in model.blocks.parameters():
+ p.requires_grad_(False)
+ model.blocks.eval()
+
+
+def train_bp_frozen(train_loader, test_loader, dev, epochs=30, seed=42, lr=1e-3, wd=0.05):
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ m = ViTMini(num_blocks=4, d_model=128, n_heads=4).to(dev)
+ freeze_blocks(m)
+ n_trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
+ n_total = sum(p.numel() for p in m.parameters())
+ print(f"BP-frozen-blocks: {n_trainable}/{n_total} params trainable", flush=True)
+ opt = optim.AdamW(filter(lambda p: p.requires_grad, m.parameters()), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ m.train()
+ m.blocks.eval() # keep blocks in eval mode (no dropout etc)
+ for x, y in train_loader:
+ x = x.to(dev); y = y.to(dev)
+ loss = F.cross_entropy(m(x), y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ if ep % 5 == 0 or ep == 1 or ep == epochs:
+ acc = evaluate(m, test_loader, dev)
+ print(f" BP-frozen ep {ep}: test_acc={acc:.4f}", flush=True)
+ return m
+
+
+def train_dfa_frozen(train_loader, test_loader, dev, epochs=30, seed=42, lr=1e-3, wd=0.05):
+ """4 transformer blocks frozen at random init.
+ Trainable: patch_embed, cls_token, pos_embed, out_ln, out_head.
+ DFA-style: head with true CE on cls token; embed (patch+cls+pos) with random feedback."""
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ m = ViTMini(num_blocks=4, d_model=128, n_heads=4).to(dev)
+ freeze_blocks(m)
+ n_trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
+ n_total = sum(p.numel() for p in m.parameters())
+ print(f"DFA-frozen-blocks: {n_trainable}/{n_total} params trainable", flush=True)
+
+ d_model, C = 128, 10
+ B0 = torch.randn(d_model, C, device=dev) / np.sqrt(C)
+ embed_opt = optim.AdamW(
+ list(m.patch_embed.parameters()) + [m.cls_token, m.pos_embed],
+ lr=lr, weight_decay=wd
+ )
+ head_opt = optim.AdamW(
+ list(m.out_head.parameters()) + list(m.out_ln.parameters()),
+ lr=lr, weight_decay=wd
+ )
+ sch1 = optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs)
+ sch2 = optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ m.train()
+ m.blocks.eval()
+ for x, y in train_loader:
+ x = x.to(dev); y = y.to(dev)
+ with torch.no_grad():
+ logits, hi = m(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(x.size(0)), y] -= 1
+ hL_det = hi[-1].detach()
+ # Head update via true CE on cls token
+ h_cls = m.out_ln(hL_det[:, 0])
+ head_opt.zero_grad()
+ F.cross_entropy(m.out_head(h_cls), y).backward()
+ head_opt.step()
+ # Embed update via DFA feedback
+ a0 = (e_T @ B0.T).detach()
+ rms = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ h0 = m.embed(x)
+ a0_b = a0.unsqueeze(1).expand_as(h0)
+ embed_loss = (h0 * (a0_b / rms.unsqueeze(1))).sum(-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+ sch1.step(); sch2.step()
+ if ep % 5 == 0 or ep == 1 or ep == epochs:
+ acc = evaluate(m, test_loader, dev)
+ print(f" DFA-frozen ep {ep}: test_acc={acc:.4f}", flush=True)
+ return m
+
+
+def main():
+ import argparse
+ p = argparse.ArgumentParser()
+ p.add_argument('--seed', type=int, default=42)
+ p.add_argument('--epochs', type=int, default=30)
+ args = p.parse_args()
+
+ dev = torch.device('cuda:0')
+ print(f"Device: {dev}, seed={args.seed}, epochs={args.epochs}", flush=True)
+ train_loader, test_loader = get_loaders(batch_size=128)
+
+ print(f"\n=== BP frozen-blocks baseline (4 random-init transformer blocks, frozen), seed={args.seed} ===", flush=True)
+ mb = train_bp_frozen(train_loader, test_loader, dev, epochs=args.epochs, seed=args.seed)
+ bp_acc = evaluate(mb, test_loader, dev)
+ print(f"FINAL BP-frozen-blocks acc: {bp_acc:.4f}", flush=True)
+
+ print(f"\n=== DFA frozen-blocks baseline, seed={args.seed} ===", flush=True)
+ md = train_dfa_frozen(train_loader, test_loader, dev, epochs=args.epochs, seed=args.seed)
+ dfa_acc = evaluate(md, test_loader, dev)
+ print(f"FINAL DFA-frozen-blocks acc: {dfa_acc:.4f}", flush=True)
+
+ print(f"\n=== Summary ===")
+ print(f"BP-frozen-blocks: {bp_acc:.4f} (chance=0.10)")
+ print(f"DFA-frozen-blocks: {dfa_acc:.4f}")
+ print(f"Compare to ViT-Mini 4-block trainable (3-seed avg): BP=0.792, DFA=0.237")
+ print(f"Compare to ViT-Mini 0-block (shallow baseline): BP=0.10, DFA=0.10")
+ print()
+ print("Interpretation:")
+ print(" If DFA-frozen-blocks ≈ 0.237: blocks are passengers, DFA is just learning patch_embed+head")
+ print(" If DFA-frozen-blocks << 0.237: trainable blocks ARE doing learned work")
+ print(" If DFA-frozen-blocks ~ 0.10: untrained blocks add no useful mixing (less informative)")
+
+
+if __name__ == '__main__':
+ main()