"""
Canonical DFA penalty trajectory: per-epoch ||h_L|| and ||g_L|| for λ ∈ {0, 1e-4, 1e-2}.

3 seeds × 3 λ × 30 epochs (epoch count configurable via --epochs).
Uses canonical cifar_resmlp.py DFA implementation (no clipping, mean reduction).

||h_L|| is the median per-sample L2 norm of the last hidden state and ||g_L||
the median per-sample L2 norm of the backprop gradient of the mean
cross-entropy w.r.t. that state, both measured on a fixed 128-sample test
buffer after every epoch (epoch 0 = at initialization).
"""
import os, sys, json, argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision, torchvision.transforms as transforms

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP

# CIFAR-10 per-channel mean / std used for input normalization.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2470, 0.2435, 0.2616)


def get_data(batch_size=128):
    """Return (train_loader, test_loader) for CIFAR-10 stored under ./data.

    Training images get random-crop (padding 4) + horizontal-flip augmentation;
    both splits are converted to tensors and normalized per channel. The
    training loader shuffles, the test loader keeps dataset order.
    """
    tv_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
    ])
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
    ])
    tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
    te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
    return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2),
            DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2))


def diagnose_quick(model, x_eval, y_eval):
    """Measure last-layer activation / gradient norms and accuracy on a fixed buffer.

    Args:
        model: ResidualMLP exposing ``embed``, ``blocks``, ``out_ln`` and
            ``out_head`` and accepting ``return_hidden=True``.
        x_eval: input batch; flattened per sample before the forward pass.
        y_eval: integer class labels for ``x_eval``.

    Returns:
        (h_L, g_L, acc): median per-sample L2 norm of the last hidden state,
        median per-sample L2 norm of the backprop gradient of the mean
        cross-entropy w.r.t. that state, and top-1 accuracy.

    The model is evaluated in eval mode; its previous train/eval mode is
    restored before returning.
    """
    was_training = model.training
    model.eval()
    x_flat = x_eval.view(x_eval.size(0), -1)
    with torch.no_grad():
        logits, hiddens = model(x_flat, return_hidden=True)
    h_L = hiddens[-1].norm(dim=-1).median().item()
    acc = (logits.argmax(-1) == y_eval).float().mean().item()

    # BP grad at h_L: replay the residual stream by hand so h_L is a graph
    # node we can differentiate w.r.t. enable_grad() keeps this working even
    # if the caller is inside a no_grad() context. The embedding output is
    # detached because only the gradient at h_L is needed.
    with torch.enable_grad():
        h = model.embed(x_flat).detach().requires_grad_(True)
        for block in model.blocks:
            h = h + block(h)
        loss = F.cross_entropy(model.out_head(model.out_ln(h)), y_eval)
        (grad_h_L,) = torch.autograd.grad(loss, h)
    g_L = grad_h_L.norm(dim=-1).median().item()

    model.train(was_training)
    return h_L, g_L, acc


def _dfa_targets(e_T, Bs):
    """Return per-layer DFA feedback e_T @ B_l^T, each row rescaled to unit RMS."""
    targets = []
    for B in Bs:
        a_dfa = (e_T @ B.T).detach()
        rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
        targets.append(a_dfa / rms)
    return targets


def _dfa_step(model, Bs, block_opts, embed_opt, head_opt, x, y, lam):
    """Apply one DFA update on batch (x, y).

    The output head is trained by true BP on detached h_L; every residual
    block and the embedding are trained locally against random-feedback
    projections of the output error.
    """
    with torch.no_grad():
        logits, hiddens = model(x, return_hidden=True)
        # e_T = softmax(logits) - onehot(y): per-sample CE gradient w.r.t. logits.
        e_T = logits.softmax(-1)
        e_T[torch.arange(x.size(0), device=e_T.device), y] -= 1

    # Output head (LayerNorm + linear) via BP on the detached last hidden state.
    logits_out = model.out_head(model.out_ln(hiddens[-1].detach()))
    head_opt.zero_grad()
    F.cross_entropy(logits_out, y).backward()
    head_opt.step()

    targets = _dfa_targets(e_T, Bs)

    # Each block sees its own detached input, so gradients never cross block
    # boundaries; local loss = <f_l, target_l> (+ λ·||f_l||² activity penalty).
    for l, block in enumerate(model.blocks):
        f_l = block(hiddens[l].detach())
        local_loss = (f_l * targets[l]).sum(dim=-1).mean()
        if lam > 0:
            local_loss = local_loss + lam * (f_l ** 2).sum(dim=-1).mean()
        block_opts[l].zero_grad()
        local_loss.backward()
        block_opts[l].step()

    # The embedding reuses the layer-0 feedback matrix.
    h0 = model.embed(x)
    embed_loss = (h0 * targets[0]).sum(dim=-1).mean()
    embed_opt.zero_grad()
    embed_loss.backward()
    embed_opt.step()


def train_dfa_trajectory(seed, train_loader, x_eval, y_eval, device, epochs, lam):
    """Train a 4-block ResidualMLP with DFA and log diagnostics after every epoch.

    Args:
        seed: seeds torch (CPU + CUDA) and numpy before model construction.
        train_loader: iterable of (images, labels) batches.
        x_eval, y_eval: fixed diagnostic buffer passed to ``diagnose_quick``.
        device: torch device for the model, feedback matrices and batches.
        epochs: number of training epochs (also the cosine-schedule length).
        lam: coefficient of the λ·||f_l||² penalty on block outputs; 0 disables it.

    Returns:
        List of ``epochs + 1`` dicts with keys 'epoch', 'h_L', 'g_L', 'acc';
        the epoch-0 entry is measured at initialization.
    """
    L, d, C = 4, 256, 10
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    model = ResidualMLP(3072, d, C, L).to(device)
    # Fixed random feedback matrices B_l of shape (d, C), scaled by 1/sqrt(C).
    Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]

    def adamw(params):
        return optim.AdamW(params, lr=1e-3, weight_decay=0.01)

    block_opts = [adamw(block.parameters()) for block in model.blocks]
    embed_opt = adamw(model.embed.parameters())
    head_opt = adamw(list(model.out_head.parameters()) + list(model.out_ln.parameters()))
    schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs)
                  for o in block_opts + [embed_opt, head_opt]]

    log = []

    def record(epoch):
        h_L, g_L, acc = diagnose_quick(model, x_eval, y_eval)
        log.append({'epoch': epoch, 'h_L': h_L, 'g_L': g_L, 'acc': acc})
        return h_L, g_L, acc

    record(0)
    for epoch in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            _dfa_step(model, Bs, block_opts, embed_opt, head_opt, x, y, lam)
        for s in schedulers:
            s.step()
        h_L, g_L, acc = record(epoch)
        if epoch % 10 == 0 or epoch == epochs:
            print(f" [lam={lam}] s={seed} ep {epoch}: ||h_L||={h_L:.3e} ||g_L||={g_L:.3e} acc={acc:.4f}", flush=True)
    return log


def _save_json(obj, path):
    """Write obj as indented JSON to path atomically (tmp file + os.replace)."""
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(obj, f, indent=2)
    os.replace(tmp_path, path)


def main():
    """Run every (λ, seed) trajectory and save the logs as JSON to --output."""
    p = argparse.ArgumentParser()
    p.add_argument('--output', type=str, default='results/dfa_canonical_penalty_trajectory.json')
    p.add_argument('--epochs', type=int, default=30)
    p.add_argument('--device', type=str, default=None,
                   help="torch device (default: cuda:0 if available, else cpu)")
    args = p.parse_args()
    device = torch.device(args.device or ('cuda:0' if torch.cuda.is_available() else 'cpu'))

    # Create the output directory before hours of training, not after.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    train_loader, test_loader = get_data(128)
    # Fixed 128-sample eval buffer (consistent with cifar_resmlp.py compute_diagnostics)
    xs, ys, n_seen = [], [], 0
    for x, y in test_loader:
        xs.append(x)
        ys.append(y)
        n_seen += x.size(0)
        if n_seen >= 128:
            break
    x_eval = torch.cat(xs)[:128].to(device)
    y_eval = torch.cat(ys)[:128].to(device)

    results = {}
    for lam in [0.0, 1e-4, 1e-2]:
        lam_key = f'lam_{lam}'
        results[lam_key] = {}
        for seed in [42, 123, 456]:
            print(f"\n=== λ={lam}, seed={seed} ===", flush=True)
            log = train_dfa_trajectory(seed, train_loader, x_eval, y_eval, device, args.epochs, lam)
            results[lam_key][str(seed)] = log
            # Checkpoint after every run so a later crash keeps finished runs.
            _save_json(results, args.output)
    print(f"\nSaved: {args.output}", flush=True)


if __name__ == '__main__':
    main()