diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:10:56 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:10:56 -0500 |
| commit | 89fff0048c04bdc4c8beb6d11f8d5564d75cbb0c (patch) | |
| tree | 97048ea3de1ae0662b82e87a7363a3a16e11fe60 /experiments | |
| parent | 7fbbe2c18a08f0a6314dfe22dc8790462252050a (diff) | |
Add ViT-Mini DFA training script that saves checkpoint + Bs
The existing snapshot_evolution_vit.py and vit_frozen_blocks_baseline.py
do not save model checkpoints — they only emit per-epoch JSON logs. This
makes it impossible to apply the diagnostic protocol to a trained ViT
post-hoc, since the protocol needs an actual model object.
This script trains a 4-block d=128 ViT-Mini with block-level DFA on
CIFAR-10 (same training rule as snapshot_evolution_vit.py) for 60 epochs
and saves:
- the final state_dict
- the random feedback Bs (so the protocol can also verify bug 4 on
this checkpoint)
- test_acc and config
Output: results/vit_dfa_checkpoints/dfa_vit_s{seed}.pt
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/train_vit_dfa_save_checkpoint.py | 143 |
1 files changed, 143 insertions, 0 deletions
diff --git a/experiments/train_vit_dfa_save_checkpoint.py b/experiments/train_vit_dfa_save_checkpoint.py new file mode 100644 index 0000000..048220d --- /dev/null +++ b/experiments/train_vit_dfa_save_checkpoint.py @@ -0,0 +1,143 @@ +""" +Train ViT-Mini with block-level DFA on CIFAR-10 and SAVE the final checkpoint ++ the random feedback Bs. The existing snapshot_evolution_vit.py and +vit_frozen_blocks_baseline.py scripts do not save model checkpoints, which +means the protocol cannot be applied to a trained ViT post-hoc. + +Output: + results/vit_dfa_checkpoints/dfa_vit_s{seed}.pt — state_dict + Bs + +Run: + CUDA_VISIBLE_DEVICES=2 python experiments/train_vit_dfa_save_checkpoint.py --seed 42 --epochs 60 +""" +import sys, os, argparse +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms +import numpy as np + +from models.vit_mini import ViTMini + + +def get_loaders(batch_size=128): + tv_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) + te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) + return ( + DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), + DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2), + ) + + +def evaluate(model, loader, dev): + model.eval() + n = c = 0 + with torch.no_grad(): + for x, y in loader: + x, y = x.to(dev), y.to(dev) + preds = model(x).argmax(-1) + c += (preds == y).sum().item() + n += x.size(0) + return c / n + + +def train_dfa_vit(model, train_loader, test_loader, dev, epochs, lr, wd): + d_model = model.d_hidden + L = model.num_blocks + C = 10 + Bs = [torch.randn(d_model, C, device=dev) / np.sqrt(C) for _ in range(L)] + block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + embed_opt = optim.AdamW( + list(model.patch_embed.parameters()) + [model.cls_token, model.pos_embed], + lr=lr, weight_decay=wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=lr, weight_decay=wd) + scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [ + optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs), + ] + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x, y = x.to(dev), y.to(dev) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 + hL_det = hiddens[-1].detach() + h_cls = model.out_ln(hL_det[:, 0]) + head_opt.zero_grad() + F.cross_entropy(model.out_head(h_cls), y).backward() + head_opt.step() + for l in range(L): + h_l = hiddens[l].detach() + a_dfa = (e_T @ Bs[l].T).detach() + a_dfa_b = a_dfa.unsqueeze(1).expand_as(h_l) + rms = (a_dfa_b ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a_dfa_b / rms + f_l = model.blocks[l](h_l) + local = (f_l * a_norm).sum(dim=-1).mean() + block_opts[l].zero_grad() + local.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + a0 = (e_T @ Bs[0].T).detach() + rms0 = (a0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + h0 = model.embed(x) + a0_b = a0.unsqueeze(1).expand_as(h0) + embed_loss = (h0 * (a0_b / rms0.unsqueeze(1))).sum(dim=-1).mean() + embed_opt.zero_grad() + embed_loss.backward() + embed_opt.step() + for s in scheds: s.step() + if ep % 10 == 0 or ep == 1 or ep == epochs: + acc = evaluate(model, test_loader, dev) + print(f" ep {ep}: test_acc={acc:.4f}", flush=True) + return Bs + + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--seed', type=int, default=42) + p.add_argument('--epochs', type=int, default=60) + p.add_argument('--lr', type=float, default=1e-3) + p.add_argument('--wd', type=float, default=0.05) + p.add_argument('--output_dir', type=str, default='results/vit_dfa_checkpoints') + args = p.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + dev = torch.device('cuda:0') + print(f"Train ViT-Mini DFA: seed={args.seed} epochs={args.epochs}", flush=True) + train_loader, test_loader = get_loaders(batch_size=128) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + m = ViTMini(num_blocks=4, d_model=128, n_heads=4).to(dev) + Bs = train_dfa_vit(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd) + final_acc = evaluate(m, test_loader, dev) + print(f"FINAL test acc: {final_acc:.4f}", flush=True) + out_path = os.path.join(args.output_dir, f"dfa_vit_s{args.seed}.pt") + torch.save({ + "state_dict": m.state_dict(), + "Bs": [b.cpu() for b in Bs], + "config": vars(args), + "test_acc": final_acc, + }, out_path) + print(f"Saved {out_path}") + + +if __name__ == "__main__": + main() |
