""" Train ViT-Mini with block-level DFA on CIFAR-10 and SAVE the final checkpoint + the random feedback Bs. The existing snapshot_evolution_vit.py and vit_frozen_blocks_baseline.py scripts do not save model checkpoints, which means the protocol cannot be applied to a trained ViT post-hoc. Output: results/vit_dfa_checkpoints/dfa_vit_s{seed}.pt — state_dict + Bs Run: CUDA_VISIBLE_DEVICES=2 python experiments/train_vit_dfa_save_checkpoint.py --seed 42 --epochs 60 """ import sys, os, argparse sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader import torchvision import torchvision.transforms as transforms import numpy as np from models.vit_mini import ViTMini def get_loaders(batch_size=128): tv_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) tv = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) return ( DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2), ) def evaluate(model, loader, dev): model.eval() n = c = 0 with torch.no_grad(): for x, y in loader: x, y = x.to(dev), y.to(dev) preds = model(x).argmax(-1) c += (preds == y).sum().item() n += x.size(0) return c / n def train_dfa_vit(model, train_loader, test_loader, dev, epochs, lr, wd): d_model = model.d_hidden L = model.num_blocks C = 10 Bs = [torch.randn(d_model, C, device=dev) / np.sqrt(C) for _ in range(L)] block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] embed_opt = optim.AdamW( list(model.patch_embed.parameters()) + [model.cls_token, model.pos_embed], lr=lr, weight_decay=wd) head_opt = optim.AdamW( list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd) scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [ optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs), ] for ep in range(1, epochs + 1): model.train() for x, y in train_loader: x, y = x.to(dev), y.to(dev) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 hL_det = hiddens[-1].detach() h_cls = model.out_ln(hL_det[:, 0]) head_opt.zero_grad() F.cross_entropy(model.out_head(h_cls), y).backward() head_opt.step() for l in range(L): h_l = hiddens[l].detach() a_dfa = (e_T @ Bs[l].T).detach() a_dfa_b = a_dfa.unsqueeze(1).expand_as(h_l) rms = (a_dfa_b ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a_norm = a_dfa_b / rms f_l = model.blocks[l](h_l) local = (f_l * a_norm).sum(dim=-1).mean() block_opts[l].zero_grad() local.backward() torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) block_opts[l].step() a0 = (e_T @ Bs[0].T).detach() rms0 = (a0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 h0 = model.embed(x) a0_b = a0.unsqueeze(1).expand_as(h0) embed_loss = (h0 * (a0_b / rms0.unsqueeze(1))).sum(dim=-1).mean() embed_opt.zero_grad() embed_loss.backward() embed_opt.step() for s in scheds: s.step() if ep % 10 == 0 or ep == 1 or ep == epochs: acc = evaluate(model, test_loader, dev) print(f" ep {ep}: test_acc={acc:.4f}", flush=True) return Bs def main(): p = argparse.ArgumentParser() p.add_argument('--seed', type=int, default=42) p.add_argument('--epochs', type=int, default=60) p.add_argument('--lr', type=float, default=1e-3) p.add_argument('--wd', type=float, default=0.05) p.add_argument('--output_dir', type=str, default='results/vit_dfa_checkpoints') args = p.parse_args() os.makedirs(args.output_dir, exist_ok=True) dev = torch.device('cuda:0') print(f"Train ViT-Mini DFA: seed={args.seed} epochs={args.epochs}", flush=True) train_loader, test_loader = get_loaders(batch_size=128) torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) m = ViTMini(num_blocks=4, d_model=128, n_heads=4).to(dev) Bs = train_dfa_vit(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd) final_acc = evaluate(m, test_loader, dev) print(f"FINAL test acc: {final_acc:.4f}", flush=True) out_path = os.path.join(args.output_dir, f"dfa_vit_s{args.seed}.pt") torch.save({ "state_dict": m.state_dict(), "Bs": [b.cpu() for b in Bs], "config": vars(args), "test_acc": final_acc, }, out_path) print(f"Saved {out_path}") if __name__ == "__main__": main()