diff options
Diffstat (limited to 'experiments/vit_shallow_baseline.py')
| -rw-r--r-- | experiments/vit_shallow_baseline.py | 147 |
1 files changed, 147 insertions, 0 deletions
diff --git a/experiments/vit_shallow_baseline.py b/experiments/vit_shallow_baseline.py new file mode 100644 index 0000000..c030d74 --- /dev/null +++ b/experiments/vit_shallow_baseline.py @@ -0,0 +1,147 @@ +""" +Shallow baseline for ViT-Mini: train BP and DFA on a 0-block ViT (just patch_embed ++ cls + pos + out_ln + out_head), to test whether the DFA accuracy on the full +ViT is just exploiting the patch embedder + head. + +This is the codex-round-5 control for the "DFA actually trains the transformer +blocks" claim. If shallow DFA acc ≈ 24% (matching the 4-block ViT-Mini DFA acc), +then the blocks are passengers and the claim is too strong. If shallow DFA acc +is much lower, then the blocks are doing real work. + +Usage: + CUDA_VISIBLE_DEVICES=2 python experiments/vit_shallow_baseline.py +""" +import sys, os +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms +import numpy as np + +from models.vit_mini import ViTMini + + +def get_loaders(batch_size=128): + tv_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) + te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) + return ( + DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), + DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2), + ) + + +def evaluate(model, loader, dev): + model.eval() + n = c = 0 + with torch.no_grad(): + for x, y in loader: + x, y = x.to(dev), y.to(dev) + preds = model(x).argmax(-1) + c += (preds == y).sum().item() + n += x.size(0) + return c / n + + +def train_bp_shallow(train_loader, test_loader, dev, epochs=30, seed=42, lr=1e-3, wd=0.05): + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + m = ViTMini(num_blocks=0, d_model=128, n_heads=4).to(dev) + print(f"BP-shallow: n_params={sum(p.numel() for p in m.parameters())}", flush=True) + opt = optim.AdamW(m.parameters(), lr=lr, weight_decay=wd) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) + for ep in range(1, epochs + 1): + m.train() + for x, y in train_loader: + x = x.to(dev); y = y.to(dev) + loss = F.cross_entropy(m(x), y) + opt.zero_grad(); loss.backward(); opt.step() + sch.step() + if ep % 5 == 0 or ep == 1 or ep == epochs: + acc = evaluate(m, test_loader, dev) + print(f" BP-shallow ep {ep}: test_acc={acc:.4f}", flush=True) + return m + + +def train_dfa_shallow(train_loader, test_loader, dev, epochs=30, seed=42, lr=1e-3, wd=0.05): + """0-block ViT trained DFA-style: head with true CE on cls token, + embed (patch_embed + cls + pos) with random feedback `e_T @ B^T` from the head.""" + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + m = ViTMini(num_blocks=0, d_model=128, n_heads=4).to(dev) + print(f"DFA-shallow: n_params={sum(p.numel() for p in m.parameters())}", flush=True) + d_model, C = 128, 10 + B0 = torch.randn(d_model, C, device=dev) / np.sqrt(C) + embed_opt = optim.AdamW( + list(m.patch_embed.parameters()) + [m.cls_token, m.pos_embed], + lr=lr, weight_decay=wd + ) + head_opt = optim.AdamW( + list(m.out_head.parameters()) + list(m.out_ln.parameters()), + lr=lr, weight_decay=wd + ) + sch1 = optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs) + sch2 = optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs) + for ep in range(1, epochs + 1): + m.train() + for x, y in train_loader: + x = x.to(dev); y = y.to(dev) + with torch.no_grad(): + logits, hi = m(x, return_hidden=True) + e_T = logits.softmax(-1); e_T[torch.arange(x.size(0)), y] -= 1 + hL_det = hi[-1].detach() + # Head update via true CE on cls token + h_cls = m.out_ln(hL_det[:, 0]) + head_opt.zero_grad() + F.cross_entropy(m.out_head(h_cls), y).backward() + head_opt.step() + # Embed update via DFA-style local loss + a0 = (e_T @ B0.T).detach() + rms = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + h0 = m.embed(x) # (B, 65, d_model) + a0_b = a0.unsqueeze(1).expand_as(h0) + embed_loss = (h0 * (a0_b / rms.unsqueeze(1))).sum(-1).mean() + embed_opt.zero_grad() + embed_loss.backward() + embed_opt.step() + sch1.step(); sch2.step() + if ep % 5 == 0 or ep == 1 or ep == epochs: + acc = evaluate(m, test_loader, dev) + print(f" DFA-shallow ep {ep}: test_acc={acc:.4f}", flush=True) + return m + + +def main(): + dev = torch.device('cuda:0') + print(f"Device: {dev}", flush=True) + train_loader, test_loader = get_loaders(batch_size=128) + + print("\n=== BP shallow baseline (ViT-Mini num_blocks=0) ===", flush=True) + mb = train_bp_shallow(train_loader, test_loader, dev, epochs=30, seed=42) + bp_acc = evaluate(mb, test_loader, dev) + print(f"FINAL BP-shallow acc: {bp_acc:.4f}", flush=True) + + print("\n=== DFA shallow baseline (ViT-Mini num_blocks=0) ===", flush=True) + md = train_dfa_shallow(train_loader, test_loader, dev, epochs=30, seed=42) + dfa_acc = evaluate(md, test_loader, dev) + print(f"FINAL DFA-shallow acc: {dfa_acc:.4f}", flush=True) + + print(f"\n=== Summary ===") + print(f"BP-shallow: {bp_acc:.4f} (chance=0.10)") + print(f"DFA-shallow: {dfa_acc:.4f}") + print(f"Compare to ViT-Mini 4-block (3-seed avg): BP=0.792, DFA=0.237") + + +if __name__ == '__main__': + main() |
