summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:10:56 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:10:56 -0500
commit89fff0048c04bdc4c8beb6d11f8d5564d75cbb0c (patch)
tree97048ea3de1ae0662b82e87a7363a3a16e11fe60 /experiments
parent7fbbe2c18a08f0a6314dfe22dc8790462252050a (diff)
Add ViT-Mini DFA training script that saves checkpoint + Bs
The existing snapshot_evolution_vit.py and vit_frozen_blocks_baseline.py do not save model checkpoints — they only emit per-epoch JSON logs. This makes it impossible to apply the diagnostic protocol to a trained ViT post-hoc, since the protocol needs an actual model object. This script trains a 4-block d=128 ViT-Mini with block-level DFA on CIFAR-10 (same training rule as snapshot_evolution_vit.py) for 60 epochs and saves: - the final state_dict - the random feedback Bs (so the protocol can also verify bug 4 on this checkpoint) - test_acc and config Output: results/vit_dfa_checkpoints/dfa_vit_s{seed}.pt
Diffstat (limited to 'experiments')
-rw-r--r--experiments/train_vit_dfa_save_checkpoint.py143
1 files changed, 143 insertions, 0 deletions
diff --git a/experiments/train_vit_dfa_save_checkpoint.py b/experiments/train_vit_dfa_save_checkpoint.py
new file mode 100644
index 0000000..048220d
--- /dev/null
+++ b/experiments/train_vit_dfa_save_checkpoint.py
@@ -0,0 +1,143 @@
+"""
+Train ViT-Mini with block-level DFA on CIFAR-10 and SAVE the final checkpoint
++ the random feedback Bs. The existing snapshot_evolution_vit.py and
+vit_frozen_blocks_baseline.py scripts do not save model checkpoints, which
+means the protocol cannot be applied to a trained ViT post-hoc.
+
+Output:
+ results/vit_dfa_checkpoints/dfa_vit_s{seed}.pt — state_dict + Bs
+
+Run:
+ CUDA_VISIBLE_DEVICES=2 python experiments/train_vit_dfa_save_checkpoint.py --seed 42 --epochs 60
+"""
+import sys, os, argparse
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import numpy as np
+
+from models.vit_mini import ViTMini
+
+
+def get_loaders(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (
+ DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2),
+ )
+
+
+def evaluate(model, loader, dev):
+ model.eval()
+ n = c = 0
+ with torch.no_grad():
+ for x, y in loader:
+ x, y = x.to(dev), y.to(dev)
+ preds = model(x).argmax(-1)
+ c += (preds == y).sum().item()
+ n += x.size(0)
+ return c / n
+
+
+def train_dfa_vit(model, train_loader, test_loader, dev, epochs, lr, wd):
+ d_model = model.d_hidden
+ L = model.num_blocks
+ C = 10
+ Bs = [torch.randn(d_model, C, device=dev) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ embed_opt = optim.AdamW(
+ list(model.patch_embed.parameters()) + [model.cls_token, model.pos_embed],
+ lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd)
+ scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [
+ optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs),
+ ]
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x, y = x.to(dev), y.to(dev)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+ hL_det = hiddens[-1].detach()
+ h_cls = model.out_ln(hL_det[:, 0])
+ head_opt.zero_grad()
+ F.cross_entropy(model.out_head(h_cls), y).backward()
+ head_opt.step()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a_dfa = (e_T @ Bs[l].T).detach()
+ a_dfa_b = a_dfa.unsqueeze(1).expand_as(h_l)
+ rms = (a_dfa_b ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a_dfa_b / rms
+ f_l = model.blocks[l](h_l)
+ local = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ a0 = (e_T @ Bs[0].T).detach()
+ rms0 = (a0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ a0_b = a0.unsqueeze(1).expand_as(h0)
+ embed_loss = (h0 * (a0_b / rms0.unsqueeze(1))).sum(dim=-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+ for s in scheds: s.step()
+ if ep % 10 == 0 or ep == 1 or ep == epochs:
+ acc = evaluate(model, test_loader, dev)
+ print(f" ep {ep}: test_acc={acc:.4f}", flush=True)
+ return Bs
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--seed', type=int, default=42)
+ p.add_argument('--epochs', type=int, default=60)
+ p.add_argument('--lr', type=float, default=1e-3)
+ p.add_argument('--wd', type=float, default=0.05)
+ p.add_argument('--output_dir', type=str, default='results/vit_dfa_checkpoints')
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ dev = torch.device('cuda:0')
+ print(f"Train ViT-Mini DFA: seed={args.seed} epochs={args.epochs}", flush=True)
+ train_loader, test_loader = get_loaders(batch_size=128)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ m = ViTMini(num_blocks=4, d_model=128, n_heads=4).to(dev)
+ Bs = train_dfa_vit(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd)
+ final_acc = evaluate(m, test_loader, dev)
+ print(f"FINAL test acc: {final_acc:.4f}", flush=True)
+ out_path = os.path.join(args.output_dir, f"dfa_vit_s{args.seed}.pt")
+ torch.save({
+ "state_dict": m.state_dict(),
+ "Bs": [b.cpu() for b in Bs],
+ "config": vars(args),
+ "test_acc": final_acc,
+ }, out_path)
+ print(f"Saved {out_path}")
+
+
+if __name__ == "__main__":
+ main()