summaryrefslogtreecommitdiff
path: root/experiments/snapshot_fa_crossarch.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/snapshot_fa_crossarch.py')
-rw-r--r--experiments/snapshot_fa_crossarch.py243
1 files changed, 243 insertions, 0 deletions
diff --git a/experiments/snapshot_fa_crossarch.py b/experiments/snapshot_fa_crossarch.py
new file mode 100644
index 0000000..8fa9e71
--- /dev/null
+++ b/experiments/snapshot_fa_crossarch.py
@@ -0,0 +1,243 @@
+"""
+FA-only snapshot evolution for ViT-Mini and ResMLP-no-outLN.
+Produces per-epoch ||h_L||, ||g_L||, acc for FA training.
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision, torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+from models.vit_mini import ViTMini
+
+
+def get_cifar10(batch_size=128):
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2))
+
+
+def fixed_eval_buffer(loader, device, n=1024):
+ xs, ys = [], []
+ for x, y in loader:
+ xs.append(x); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= n:
+ break
+ return torch.cat(xs)[:n].to(device), torch.cat(ys)[:n].to(device)
+
+
+# ─── Diagnose (works for both ViT and ResMLP) ───────────────────────────
+
+def diagnose_resmlp(model, x_eval, y_eval):
+ model.eval()
+ x_flat = x_eval.view(x_eval.size(0), -1)
+ with torch.no_grad():
+ _, hiddens = model(x_flat, return_hidden=True)
+ hidden_norms = [h.norm(dim=-1).median().item() for h in hiddens]
+ # BP grads
+ h0 = model.embed(x_flat.detach())
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks:
+ hs.append(hs[-1] + b(hs[-1]))
+ # Handle both with and without out_ln
+ if hasattr(model, 'out_ln'):
+ logits = model.out_head(model.out_ln(hs[-1]))
+ else:
+ logits = model.out_head(hs[-1])
+ loss = F.cross_entropy(logits, y_eval)
+ grads = torch.autograd.grad(loss, hs)
+ g_norms = [g.norm(dim=-1).median().item() for g in grads]
+ acc = (logits.argmax(-1) == y_eval).float().mean().item()
+ model.train()
+ return {'hidden_norms': hidden_norms, 'bp_grad_norms_per_sample_med': g_norms, 'acc_eval': acc}
+
+
+def diagnose_vit(model, x_eval, y_eval):
+ model.eval()
+ with torch.no_grad():
+ _, hiddens = model(x_eval, return_hidden=True)
+ h_cls_norms = [h[:, 0].norm(dim=-1).median().item() for h in hiddens]
+ # BP grads via manual forward
+ h0 = model.embed(x_eval.detach())
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks:
+ hs.append(hs[-1] + b(hs[-1]))
+ h_cls = model.out_ln(hs[-1][:, 0])
+ logits = model.out_head(h_cls)
+ loss = F.cross_entropy(logits, y_eval)
+ grads = torch.autograd.grad(loss, hs)
+ g_cls_norms = [g[:, 0].norm(dim=-1).median().item() for g in grads]
+ acc = (logits.argmax(-1) == y_eval).float().mean().item()
+ model.train()
+ return {'hidden_norms_cls': h_cls_norms, 'bp_grad_per_sample_l2_med': g_cls_norms, 'acc_eval': acc}
+
+
+# ─── FA training ─────────────────────────────────────────────────────────
+
+def train_fa_resmlp(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, no_outln=False):
+ d_hidden = model.d_hidden
+ L = model.num_blocks
+ Bs = [torch.randn(d_hidden, d_hidden, device=device) / np.sqrt(d_hidden) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_params = list(model.out_head.parameters())
+ if hasattr(model, 'out_ln') and model.out_ln is not None:
+ head_params += list(model.out_ln.parameters())
+ head_opt = optim.AdamW(head_params, lr=lr, weight_decay=wd)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+ log = []
+ d0 = diagnose_resmlp(model, x_eval, y_eval); d0['epoch'] = 0; log.append(d0)
+ print(f" [FA] Ep 0: acc={d0['acc_eval']:.4f}", flush=True)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ hL_det = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL_det)) if hasattr(model, 'out_ln') else model.out_head(hL_det)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+ # FA credits
+ hL_req = hiddens[-1].detach().requires_grad_(True)
+ logits_fa = model.out_head(model.out_ln(hL_req)) if hasattr(model, 'out_ln') else model.out_head(hL_req)
+ loss_fa = F.cross_entropy(logits_fa, y, reduction='sum')
+ a_L = torch.autograd.grad(loss_fa, hL_req)[0].detach()
+ credits = [None] * L
+ credits[L-1] = a_L
+ for ll in range(L-2, -1, -1):
+ credits[ll] = (credits[ll+1] @ Bs[ll+1]).detach()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a_l = credits[l]
+ rms = (a_l**2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * (a_l / rms)).sum(dim=-1).mean()
+ block_opts[l].zero_grad(); local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ a_0 = credits[0]
+ rms_0 = (a_0**2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+ for s in all_sch: s.step()
+ d = diagnose_resmlp(model, x_eval, y_eval); d['epoch'] = ep; log.append(d)
+ if ep % 10 == 0 or ep == 1 or ep == epochs:
+ print(f" [FA] Ep {ep}: ||h_L||={d['hidden_norms'][-1]:.3e} "
+ f"||g_L||={d['bp_grad_norms_per_sample_med'][-1]:.3e} "
+ f"acc={d['acc_eval']:.4f}", flush=True)
+ return log
+
+
+def train_fa_vit(model, train_loader, x_eval, y_eval, device, epochs, lr, wd):
+ """Canonical FA for ViT: mean reduction, grad before step, no clipping, top-down."""
+ d_model = model.d_hidden
+ L = model.num_blocks
+ Bs = [torch.randn(d_model, d_model, device=device) / np.sqrt(d_model) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ embed_opt = optim.AdamW(
+ list(model.patch_embed.parameters()) + [model.cls_token, model.pos_embed],
+ lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+ log = []
+ d0 = diagnose_vit(model, x_eval, y_eval); d0['epoch'] = 0; log.append(d0)
+ print(f" [FA-vit] Ep 0: acc={d0['acc_eval']:.4f}", flush=True)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.to(device); y = y.to(device)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ # Head update — grad BEFORE step (old head)
+ hL_det = hiddens[-1].detach().requires_grad_(True)
+ h_cls = model.out_ln(hL_det[:, 0])
+ logits_out = model.out_head(h_cls)
+ loss_out = F.cross_entropy(logits_out, y) # mean reduction
+ head_opt.zero_grad()
+ loss_out.backward()
+ a_L_full = hL_det.grad.detach() # (B, n_tokens, d)
+ head_opt.step()
+ # Use mean over tokens for the backward signal
+ a_credit = a_L_full.mean(dim=1) # (B, d)
+ # Top-down block updates, propagate credit after each
+ for l in range(L - 1, -1, -1):
+ h_l = hiddens[l].detach()
+ a_broadcast = a_credit.unsqueeze(1).expand_as(h_l)
+ rms = (a_broadcast ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * (a_broadcast / rms)).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ block_opts[l].step() # no clipping
+ a_credit = (a_credit @ Bs[l]).detach()
+ # Embed update with final propagated credit
+ a_0_broadcast = a_credit.unsqueeze(1)
+ rms_0 = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_loss = (h0 * (a_0_broadcast / rms_0.unsqueeze(1))).sum(dim=-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+ for s in all_sch: s.step()
+ d = diagnose_vit(model, x_eval, y_eval); d['epoch'] = ep; log.append(d)
+ if ep % 5 == 0 or ep == 1 or ep == epochs:
+ print(f" [FA-vit] Ep {ep}: ||h_L||={d['hidden_norms_cls'][-1]:.3e} "
+ f"||g_L||={d['bp_grad_per_sample_l2_med'][-1]:.3e} "
+ f"acc={d['acc_eval']:.4f}", flush=True)
+ return log
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--arch', choices=['vit', 'resmlp_noln'], required=True)
+ p.add_argument('--output', type=str, required=True)
+ p.add_argument('--epochs', type=int, default=100)
+ p.add_argument('--seed', type=int, default=42)
+ args = p.parse_args()
+
+ device = torch.device('cuda:0')
+ train_loader, test_loader = get_cifar10(128)
+ x_eval, y_eval = fixed_eval_buffer(test_loader, device, 1024)
+
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+
+ if args.arch == 'vit':
+ # Match ViT snapshot params
+ model = ViTMini(d_model=128, n_heads=4, num_blocks=4, num_classes=10).to(device)
+ fa_log = train_fa_vit(model, train_loader, x_eval, y_eval, device,
+ args.epochs, lr=1e-3, wd=0.05)
+ else:
+ # ResMLP without terminal LN — use the same class as the original no-outln experiment
+ from experiments.snapshot_evolution_no_outln import ResidualMLP_NoOutLN
+ model = ResidualMLP_NoOutLN(3072, 256, 10, 4).to(device)
+ fa_log = train_fa_resmlp(model, train_loader, x_eval, y_eval, device,
+ args.epochs, lr=1e-3, wd=0.01, no_outln=True)
+
+ with open(args.output, 'w') as f:
+ json.dump({'fa_log': fa_log, 'arch': args.arch, 'seed': args.seed}, f, indent=2)
+ print(f"Saved: {args.output}", flush=True)
+
+
+if __name__ == '__main__':
+ main()