diff options
Diffstat (limited to 'experiments/snapshot_evolution_vit.py')
| -rw-r--r-- | experiments/snapshot_evolution_vit.py | 244 |
1 files changed, 244 insertions, 0 deletions
diff --git a/experiments/snapshot_evolution_vit.py b/experiments/snapshot_evolution_vit.py new file mode 100644 index 0000000..ce4c090 --- /dev/null +++ b/experiments/snapshot_evolution_vit.py @@ -0,0 +1,244 @@ +""" +Snapshot evolution on a ViT-Mini (modern transformer-style architecture) trained +with BP and block-level DFA on CIFAR-10. Logs ||h_l||, ||BP grad||, Γ per epoch. + +This is the P4 generalization test: does the residual-stream pathology + LayerNorm +gradient collapse mechanism (verified on pre-LN ResMLP with terminal LN) also +appear on an actual transformer architecture? If yes → strong P4 in modern setting. + +Block-level DFA: each TransformerBlock is a "layer". The DFA credit +`a_l = e_T @ B_l^T` is broadcast across all tokens at that block's input. The +local block loss is `<block_l(h_l), broadcast(a_l)>` summed over tokens. + +Usage: + CUDA_VISIBLE_DEVICES=2 nohup python experiments/snapshot_evolution_vit.py \ + --output_dir results/snapshot_vit_v1 --epochs 60 --seed 42 \ + > results/snapshot_vit_v1/run_s42.log 2>&1 & +""" +import os, sys, json, argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.vit_mini import ViTMini, TransformerBlock +from metrics.credit_metrics import cosine_similarity_batch + + +def get_cifar10(batch_size=128): + tv_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) + te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) + return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), + DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2)) + + +def fixed_eval_buffer(test_loader, device, n_samples=1024): + xs, ys = [], [] + for x, y in test_loader: + xs.append(x); ys.append(y) + if sum(xb.size(0) for xb in xs) >= n_samples: + break + return torch.cat(xs)[:n_samples].to(device), torch.cat(ys)[:n_samples].to(device) + + +def diagnose(model, x_eval, y_eval, dfa_Bs=None): + """Compute per-block ||h_l|| and ||BP grad at h_l||, plus optional Γ vs DFA credit.""" + was_training = model.training + model.eval() + L = model.num_blocks + + # Hidden states (no grad) + with torch.no_grad(): + _, hiddens = model(x_eval, return_hidden=True) + # hiddens[l] is shape (B, n_tokens, d_model) + # Reduce to per-sample by taking the cls-token norm OR by flattening across tokens + # We'll report cls-token norm (the one that actually flows to the head) + hidden_norms_cls = [h[:, 0].norm(dim=-1).median().item() for h in hiddens] + hidden_norms_avg = [h.norm(dim=-1).mean().item() for h in hiddens] # avg across tokens then over batch + + # BP gradients + h0 = model.embed(x_eval.detach()) + hs = [h0.clone().requires_grad_(True)] + for b in model.blocks: + hs.append(b(hs[-1])) + h_cls = model.out_ln(hs[-1][:, 0]) + logits = model.out_head(h_cls) + loss = F.cross_entropy(logits, y_eval) + grads = torch.autograd.grad(loss, hs) + # grads[l] is shape (B, n_tokens, d_model) + # Per-sample L2 norm: take Frobenius over tokens × d_model + bp_grad_per_sample_l2 = [g.flatten(1).norm(dim=-1).median().item() for g in grads] + bp_grad_F = [g.norm().item() for g in grads] + bp_full = [g.detach() for g in grads] + + acc = (logits.argmax(-1) == y_eval).float().mean().item() + loss_val = loss.item() + + gamma_dfa = float('nan'); per_layer_gamma = [] + if dfa_Bs is not None: + with torch.no_grad(): + e_T = logits.softmax(-1); e_T[torch.arange(x_eval.size(0)), y_eval] -= 1 + for l in range(L): + # Block-level DFA credit: per-sample (B, d_model), broadcast to (B, n_tokens, d_model) + a_dfa_per_sample = (e_T @ dfa_Bs[l].T).detach() # (B, d_model) + a_dfa_broadcast = a_dfa_per_sample.unsqueeze(1).expand_as(bp_full[l]) # (B, n_tokens, d_model) + # Cosine using flattened (per-sample) representation + per_layer_gamma.append(cosine_similarity_batch( + a_dfa_broadcast.flatten(1), bp_full[l].flatten(1))) + gamma_dfa = float(np.mean(per_layer_gamma)) + + if was_training: + model.train() + + return { + 'hidden_norms_cls': hidden_norms_cls, + 'hidden_norms_avg': hidden_norms_avg, + 'bp_grad_per_sample_l2_med': bp_grad_per_sample_l2, + 'bp_grad_F': bp_grad_F, + 'gamma_dfa': gamma_dfa, + 'gamma_dfa_per_layer': per_layer_gamma, + 'acc_eval': acc, + 'loss_eval': loss_val, + } + + +def train_bp(model, train_loader, x_eval, y_eval, device, epochs, lr, wd): + opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) + log = [] + d0 = diagnose(model, x_eval, y_eval); d0['epoch'] = 0; log.append(d0) + print(f" [BP-vit] Ep 0: ||h_L_cls||={d0['hidden_norms_cls'][-1]:.3e} ||g_2||={d0['bp_grad_per_sample_l2_med'][2]:.3e} acc={d0['acc_eval']:.4f}", flush=True) + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.to(device); y = y.to(device) + logits = model(x); loss = F.cross_entropy(logits, y) + opt.zero_grad(); loss.backward(); opt.step() + sch.step() + d = diagnose(model, x_eval, y_eval); d['epoch'] = ep; log.append(d) + if ep % 5 == 0 or ep == 1 or ep == epochs: + print(f" [BP-vit] Ep {ep}: ||h_L_cls||={d['hidden_norms_cls'][-1]:.3e} ||g_2||={d['bp_grad_per_sample_l2_med'][2]:.3e} acc={d['acc_eval']:.4f}", flush=True) + return log + + +def train_dfa_block_level(model, train_loader, x_eval, y_eval, device, epochs, lr, wd): + """Block-level DFA on ViT. Each TransformerBlock is treated as a unit; DFA credit + is broadcast across all tokens at the block's input. + """ + d_model = model.d_hidden + L = model.num_blocks + C = 10 + Bs = [torch.randn(d_model, C, device=device) / np.sqrt(C) for _ in range(L)] + + block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] + embed_opt = optim.AdamW( + list(model.patch_embed.parameters()) + [model.cls_token, model.pos_embed], + lr=lr, weight_decay=wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=lr, weight_decay=wd) + all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)] + log = [] + d0 = diagnose(model, x_eval, y_eval, dfa_Bs=Bs); d0['epoch'] = 0; log.append(d0) + print(f" [DFA-vit] Ep 0: ||h_L_cls||={d0['hidden_norms_cls'][-1]:.3e} ||g_2||={d0['bp_grad_per_sample_l2_med'][2]:.3e} acc={d0['acc_eval']:.4f}", flush=True) + for ep in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.to(device); y = y.to(device) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 + hL_det = hiddens[-1].detach() + # Head update via direct CE on cls token + h_cls = model.out_ln(hL_det[:, 0]) + logits_out = model.out_head(h_cls) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad(); loss_out.backward(); head_opt.step() + # Block updates: each block's local loss = <block(h_l), a_dfa_broadcast> + for l in range(L): + h_l = hiddens[l].detach() # (B, n_tokens, d) + a_dfa = (e_T @ Bs[l].T).detach() # (B, d) + a_dfa_broadcast = a_dfa.unsqueeze(1).expand_as(h_l) # (B, n_tokens, d) + rms = (a_dfa_broadcast ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a_dfa_broadcast / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + block_opts[l].zero_grad(); local_loss.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + # Embed update (patch embed + cls + pos) + a_0 = (e_T @ Bs[0].T).detach() + rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + h0 = model.embed(x) # (B, n_tokens, d) + a_0_broadcast = a_0.unsqueeze(1).expand_as(h0) + embed_loss = (h0 * (a_0_broadcast / rms_0.unsqueeze(1))).sum(dim=-1).mean() + embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step() + for s in all_sch: s.step() + d = diagnose(model, x_eval, y_eval, dfa_Bs=Bs); d['epoch'] = ep; log.append(d) + if ep % 5 == 0 or ep == 1 or ep == epochs: + print(f" [DFA-vit] Ep {ep}: ||h_L_cls||={d['hidden_norms_cls'][-1]:.3e} ||g_2||={d['bp_grad_per_sample_l2_med'][2]:.3e} acc={d['acc_eval']:.4f} γ={d['gamma_dfa']:.4f}", flush=True) + return log + + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--output_dir', type=str, default='results/snapshot_vit_v1') + p.add_argument('--epochs', type=int, default=60) + p.add_argument('--lr', type=float, default=1e-3) + p.add_argument('--wd', type=float, default=0.05) + p.add_argument('--seed', type=int, default=42) + p.add_argument('--depth', type=int, default=4) + p.add_argument('--d_model', type=int, default=128) + p.add_argument('--n_heads', type=int, default=4) + args = p.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + device = torch.device('cuda:0') + print(f"ViT-MINI: depth={args.depth}, d_model={args.d_model}, n_heads={args.n_heads}, " + f"epochs={args.epochs}, seed={args.seed}", flush=True) + + train_loader, test_loader = get_cifar10(batch_size=128) + x_eval, y_eval = fixed_eval_buffer(test_loader, device, n_samples=1024) + + print("\n=== BP training (ViT-Mini) ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + bp_model = ViTMini(num_blocks=args.depth, d_model=args.d_model, n_heads=args.n_heads).to(device) + print(f" n_params={sum(p.numel() for p in bp_model.parameters())}", flush=True) + bp_log = train_bp(bp_model, train_loader, x_eval, y_eval, device, args.epochs, args.lr, args.wd) + + print("\n=== DFA training (ViT-Mini, block-level DFA) ===", flush=True) + torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) + dfa_model = ViTMini(num_blocks=args.depth, d_model=args.d_model, n_heads=args.n_heads).to(device) + dfa_log = train_dfa_block_level(dfa_model, train_loader, x_eval, y_eval, device, args.epochs, args.lr, args.wd) + + out = { + 'config': vars(args), 'depth': args.depth, 'd_model': args.d_model, + 'architecture': 'ViTMini', 'bp_log': bp_log, 'dfa_log': dfa_log, + } + out_path = os.path.join(args.output_dir, f'snapshot_vit_s{args.seed}.json') + with open(out_path, 'w') as f: + json.dump(out, f, indent=2) + print(f"\nSaved {out_path}", flush=True) + + +if __name__ == '__main__': + main() |
