""" DFA canonical λ=1e-2 training + checkpoint save + fresh-B null calibration. Runs after the main penalty sweep to produce the null calibration on the canonical checkpoint. """ import os, sys, json, argparse import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader import torchvision, torchvision.transforms as transforms sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.residual_mlp import ResidualMLP from metrics.credit_metrics import cosine_similarity_batch def get_data(batch_size=128): tv_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) tv = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2)) def train_dfa_canonical(model, train_loader, device, epochs, lr, wd, penalty_lam): """Canonical DFA from cifar_resmlp.py: no grad clipping, mean reduction.""" d = model.d_hidden L = model.num_blocks C = 10 Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks] embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd) all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]) for epoch in range(1, epochs + 1): model.train() for x, y in train_loader: x = x.view(x.size(0), -1).to(device); y = y.to(device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 hL_det = hiddens[-1].detach() logits_out = model.out_head(model.out_ln(hL_det)) loss_out = F.cross_entropy(logits_out, y) head_opt.zero_grad(); loss_out.backward(); head_opt.step() for l in range(L): h_l = hiddens[l].detach() a_dfa = (e_T @ Bs[l].T).detach() rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 f_l = model.blocks[l](h_l) local_loss = (f_l * (a_dfa / rms)).sum(dim=-1).mean() if penalty_lam > 0: local_loss = local_loss + penalty_lam * (f_l ** 2).sum(dim=-1).mean() block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step() a_0 = (e_T @ Bs[0].T).detach() rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 h0 = model.embed(x) embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean() embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step() for s in all_sch: s.step() if epoch % 10 == 0 or epoch == epochs: print(f" [DFA pen] ep {epoch}", flush=True) return Bs def compute_deep_cosine(model, Bs, x_eval, y_eval, device): """Compute per-layer DFA cosine on eval buffer.""" model.eval() L = model.num_blocks h0 = model.embed(x_eval.detach()) hs = [h0.clone().requires_grad_(True)] for b in model.blocks: hs.append(hs[-1] + b(hs[-1])) logits = model.out_head(model.out_ln(hs[-1])) loss = F.cross_entropy(logits, y_eval) grads = torch.autograd.grad(loss, hs) with torch.no_grad(): e_T = logits.softmax(-1) e_T[torch.arange(x_eval.size(0)), y_eval] -= 1 cos_per_layer = [] for l in range(L): a_dfa = (e_T @ Bs[l].T).detach() cos_per_layer.append(cosine_similarity_batch(a_dfa, grads[l].detach())) acc = (logits.argmax(-1) == y_eval).float().mean().item() g_norms = [g.norm(dim=-1).median().item() for g in grads] h_norms = [h.detach().norm(dim=-1).median().item() for h in hs] return cos_per_layer, acc, g_norms, h_norms def main(): p = argparse.ArgumentParser() p.add_argument('--seed', type=int, default=42) p.add_argument('--output_dir', type=str, default='results/dfa_canonical_freshB') p.add_argument('--n_fresh', type=int, default=20) args = p.parse_args() os.makedirs(args.output_dir, exist_ok=True) device = torch.device('cuda:0') train_loader, test_loader = get_data(128) # Fixed eval buffer xs, ys = [], [] for x, y in test_loader: xs.append(x.view(x.size(0), -1)); ys.append(y) if sum(xb.size(0) for xb in xs) >= 128: break x_eval = torch.cat(xs)[:128].to(device) y_eval = torch.cat(ys)[:128].to(device) L, d, C = 4, 256, 10 # Train DFA with λ=1e-2 print(f"Training DFA canonical λ=0.01, seed={args.seed}", flush=True) torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) model = ResidualMLP(3072, d, C, L).to(device) training_Bs = train_dfa_canonical(model, train_loader, device, 30, 1e-3, 0.01, 0.01) # Save checkpoint ckpt_path = os.path.join(args.output_dir, f'dfa_canonical_lam0.01_s{args.seed}.pt') torch.save({'state_dict': model.state_dict(), 'Bs': [B.cpu() for B in training_Bs], 'seed': args.seed}, ckpt_path) print(f"Saved checkpoint: {ckpt_path}", flush=True) # Compute cosine with training Bs cos_training, acc, g_norms, h_norms = compute_deep_cosine(model, training_Bs, x_eval, y_eval, device) deep_cos_training = float(np.mean(cos_training[1:])) # exclude layer 0 print(f"Training-Bs: acc={acc:.4f}, deep cos={deep_cos_training:+.4f}") print(f" per-layer cos: {[f'{c:+.4f}' for c in cos_training]}") print(f" ||g_l||: {[f'{g:.2e}' for g in g_norms]}") print(f" ||h_l||: {[f'{h:.2e}' for h in h_norms]}") # Fresh-B null calibration print(f"\nFresh-B null calibration ({args.n_fresh} draws)...", flush=True) fresh_deep_cos = [] fresh_per_layer = [] for i in range(args.n_fresh): fresh_Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] cos_fresh, _, _, _ = compute_deep_cosine(model, fresh_Bs, x_eval, y_eval, device) deep_fresh = float(np.mean(cos_fresh[1:])) fresh_deep_cos.append(deep_fresh) fresh_per_layer.append(cos_fresh) fresh_mean = np.mean(fresh_deep_cos) fresh_std_ddof1 = np.std(fresh_deep_cos, ddof=1) print(f"Fresh-Bs deep cos: {fresh_mean:+.4f} ± {fresh_std_ddof1:.4f} (ddof=1)") # Save results out = { 'description': f'Canonical DFA λ=0.01 s={args.seed} + fresh-B null (N={args.n_fresh})', 'training_Bs_deep_cos': deep_cos_training, 'training_Bs_per_layer_cos': cos_training, 'training_Bs_acc': acc, 'training_Bs_g_norms': g_norms, 'training_Bs_h_norms': h_norms, 'fresh_Bs_n_draws': args.n_fresh, 'fresh_Bs_deep_cos_per_draw': fresh_deep_cos, 'fresh_Bs_deep_mean': fresh_mean, 'fresh_Bs_deep_std_ddof1': fresh_std_ddof1, 'fresh_Bs_per_layer_mean': [float(np.mean([fl[l] for fl in fresh_per_layer])) for l in range(L)], } out_path = os.path.join(args.output_dir, f'freshB_null_canonical_s{args.seed}.json') with open(out_path, 'w') as f: json.dump(out, f, indent=2) print(f"Saved: {out_path}", flush=True) if __name__ == '__main__': main()