""" Train vanilla DFA (no penalty) on the standard 4-block d=256 ResMLP and save checkpoints at the early epochs (1, 2, 3) BEFORE ‖g_L‖ has collapsed to the numerical floor. Codex round 19's #3 priority experiment to disambiguate: - Hypothesis A: deep-layer alignment was always present in vanilla DFA but hidden by the post-collapse measurement degeneracy. Penalty just made the measurement interpretable. - Hypothesis B: deep-layer alignment was created by the penalty intervention. Vanilla DFA at any epoch has zero deep alignment. Test: measure deep-layer cos at vanilla checkpoints from ep 1, 2, 3 (when ‖g_L‖ should still be in the meaningful regime). Run: CUDA_VISIBLE_DEVICES=2 python experiments/vanilla_dfa_early_ckpt.py --seed 42 """ import os import sys import argparse import json import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import torchvision import torchvision.transforms as transforms from torch.utils.data import DataLoader sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.residual_mlp import ResidualMLP def get_loaders(batch_size=128): tv_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) tv = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) return ( DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=2), DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=2), ) def evaluate(model, loader, dev): model.eval() n = c = 0 with torch.no_grad(): for x, y in loader: x = x.view(x.size(0), -1).to(dev); y = y.to(dev) preds = model(x).argmax(-1) c += (preds == y).sum().item() n += x.size(0) return c / n def diagnose_norms(model, x_eval, y_eval, dev): model.eval() with torch.no_grad(): _, hi = model(x_eval, return_hidden=True) h_norms = [h.norm(dim=-1).median().item() for h in hi] h0 = model.embed(x_eval.detach()) hs = [h0.clone().requires_grad_(True)] for b in model.blocks: hs.append(hs[-1] + b(hs[-1])) lo = model.out_head(model.out_ln(hs[-1])) loss = F.cross_entropy(lo, y_eval) gs = torch.autograd.grad(loss, hs) g_norms = [g.norm(dim=-1).median().item() for g in gs] return h_norms, g_norms def train_vanilla_dfa(model, train_loader, dev, max_epoch, lr, wd, Bs, x_eval, y_eval, save_at, output_dir, seed): L = model.num_blocks block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) head_opt = optim.AdamW( list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd ) log = [] h0_norms, g0_norms = diagnose_norms(model, x_eval, y_eval, dev) log.append({"epoch": 0, "h_norms": h0_norms, "g_norms": g0_norms}) print(f" ep 0: h_norms={[f'{h:.2e}' for h in h0_norms]}, g_norms={[f'{g:.2e}' for g in g0_norms]}", flush=True) for ep in range(1, max_epoch + 1): model.train() for x, y in train_loader: x = x.view(x.size(0), -1).to(dev); y = y.to(dev) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 hL_det = hiddens[-1].detach() head_opt.zero_grad() F.cross_entropy(model.out_head(model.out_ln(hL_det)), y).backward() head_opt.step() for l in range(L): h_l = hiddens[l].detach() a = (e_T @ Bs[l].T).detach() rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 f = model.blocks[l](h_l) loss = (f * (a / rms)).sum(-1).mean() block_opts[l].zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) block_opts[l].step() a0 = (e_T @ Bs[0].T).detach() rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 h0_emb = model.embed(x) embed_opt.zero_grad() (h0_emb * (a0 / rms0)).sum(-1).mean().backward() embed_opt.step() h_norms, g_norms = diagnose_norms(model, x_eval, y_eval, dev) log.append({"epoch": ep, "h_norms": h_norms, "g_norms": g_norms}) print(f" ep {ep}: h_norms={[f'{h:.2e}' for h in h_norms]}, g_norms={[f'{g:.2e}' for g in g_norms]}", flush=True) if ep in save_at: ckpt_path = os.path.join(output_dir, f"vanilla_dfa_s{seed}_ep{ep}.pt") torch.save({ "state_dict": model.state_dict(), "Bs": [b.cpu() for b in Bs], "epoch": ep, "h_norms": h_norms, "g_norms": g_norms, }, ckpt_path) print(f" saved {ckpt_path}", flush=True) return log def main(): p = argparse.ArgumentParser() p.add_argument("--seed", type=int, default=42) p.add_argument("--max_epoch", type=int, default=5) p.add_argument("--lr", type=float, default=1e-3) p.add_argument("--wd", type=float, default=0.01) p.add_argument("--save_at", type=int, nargs="+", default=[1, 2, 3, 4, 5]) p.add_argument("--output_dir", type=str, default="results/vanilla_dfa_early_ckpts") args = p.parse_args() os.makedirs(args.output_dir, exist_ok=True) dev = torch.device("cuda:0") print(f"Vanilla DFA early-epoch checkpoint sweep: seed={args.seed}, max_epoch={args.max_epoch}", flush=True) train_loader, test_loader = get_loaders(batch_size=128) # Eval batch xs, ys = [], [] for x, y in test_loader: xs.append(x.view(x.size(0), -1)); ys.append(y) if sum(xb.size(0) for xb in xs) >= 1024: break x_eval = torch.cat(xs)[:1024].to(dev) y_eval = torch.cat(ys)[:1024].to(dev) L, d, C = 4, 256, 10 torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) m = ResidualMLP(3072, d, C, L).to(dev) Bs = [torch.randn(d, C, device=dev) / np.sqrt(C) for _ in range(L)] log = train_vanilla_dfa(m, train_loader, dev, args.max_epoch, args.lr, args.wd, Bs, x_eval, y_eval, args.save_at, args.output_dir, args.seed) out = {"config": vars(args), "log": log} out_path = os.path.join(args.output_dir, f"vanilla_dfa_s{args.seed}_log.json") with open(out_path, "w") as f: json.dump(out, f, indent=2) print(f"Saved {out_path}") if __name__ == "__main__": main()