""" Snapshot evolution: per-epoch logging of residual-stream norms and BP-gradient norms during BP and DFA training of a 4-block d=256 ResMLP on CIFAR-10. Goal: confirm that ||h_l||_2 grows monotonically over epochs in DFA but stays bounded in BP, and that ||BP_grad||_2 collapses correspondingly. This generates the killer figure for the P4 (residual-stream pathology) finding in the NeurIPS 2026 FA Evaluation paper. Usage: CUDA_VISIBLE_DEVICES=2 nohup python experiments/snapshot_evolution_residual_explosion.py \ --output_dir results/snapshot_evolution_v2 > results/snapshot_evolution_v2.log 2>&1 & """ import os, sys, json, argparse, time import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader import torchvision import torchvision.transforms as transforms sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.residual_mlp import ResidualMLP from metrics.credit_metrics import cosine_similarity_batch def get_cifar10(batch_size=128, num_workers=2): tv = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) tv_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train) te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv) return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=num_workers), DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=num_workers)) def fixed_eval_buffer(test_loader, device, n_samples=1024): xs, ys = [], [] for x, y in test_loader: xs.append(x.view(x.size(0), -1)); ys.append(y) if sum(xb.size(0) for xb in xs) >= n_samples: break x = torch.cat(xs)[:n_samples].to(device) y = torch.cat(ys)[:n_samples].to(device) return x, y def diagnose(model, x_eval, y_eval, dfa_Bs=None): """ Returns dict with: - hidden_norms: list of L+1 floats, median per-sample ||h_l||_2 on eval buffer - bp_grad_norms: list of L+1 floats, median per-sample ||g_l||_2 (BP grad) - bp_grad_norms_F: list of L+1 floats, ||g_l||_F per layer (Frobenius) - gamma_dfa: mean cosine over layers between DFA credit and BP grad (only if dfa_Bs given) - acc: test accuracy on the eval buffer - loss: mean CE on the eval buffer Critically: ALL norms use .norm(dim=-1), never .norm(-1). """ was_training = model.training model.eval() L = model.num_blocks C = 10 bs = x_eval.size(0) # Hidden states (no grad) with torch.no_grad(): _, hiddens = model(x_eval, return_hidden=True) hidden_norms = [h.norm(dim=-1).median().item() for h in hiddens] # BP gradients via manual graph, with x_eval as the input h0 = model.embed(x_eval.detach()) hs = [h0.clone().requires_grad_(True)] for b in model.blocks: hs.append(hs[-1] + b(hs[-1])) logits = model.out_head(model.out_ln(hs[-1])) loss = F.cross_entropy(logits, y_eval) grads = torch.autograd.grad(loss, hs) bp_grad_per_sample_l2 = [g.norm(dim=-1).median().item() for g in grads] bp_grad_F = [g.norm().item() for g in grads] bp_grad_full = [g.detach() for g in grads] acc = (logits.argmax(-1) == y_eval).float().mean().item() loss_val = loss.item() # DFA credit cosine to BP grad, if requested. # Convention (matches confirmatory_paper_experiments.compute_diagnostics_generic): # DFA's a_l represents the credit at the *input* to block l, which is h_l, so it # is compared against bp_grad_full[l] (gradient at h_l = input to block l). gamma_dfa = float('nan') if dfa_Bs is not None: with torch.no_grad(): e_T = logits.softmax(dim=-1) e_T[torch.arange(bs), y_eval] -= 1.0 cos_per_layer = [] for l in range(L): a_dfa = (e_T @ dfa_Bs[l].T).detach() cos_per_layer.append(cosine_similarity_batch(a_dfa, bp_grad_full[l])) gamma_dfa = float(np.mean(cos_per_layer)) if was_training: model.train() return { 'hidden_norms': hidden_norms, 'bp_grad_norms_per_sample_med': bp_grad_per_sample_l2, 'bp_grad_norms_F': bp_grad_F, 'gamma_dfa': gamma_dfa, 'acc_eval': acc, 'loss_eval': loss_val, } def train_bp(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_every=1): optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) log = [] # Epoch 0 (pre-training) d0 = diagnose(model, x_eval, y_eval) d0['epoch'] = 0 log.append(d0) print(f" [BP] Ep 0: ||h||_med={d0['hidden_norms']} ||g||_med={d0['bp_grad_norms_per_sample_med']} acc={d0['acc_eval']:.4f}", flush=True) for epoch in range(1, epochs + 1): model.train() for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) logits = model(x) loss = F.cross_entropy(logits, y) optimizer.zero_grad() loss.backward() optimizer.step() scheduler.step() if epoch % log_every == 0 or epoch == epochs: d = diagnose(model, x_eval, y_eval) d['epoch'] = epoch log.append(d) print(f" [BP] Ep {epoch}: ||h_L||={d['hidden_norms'][-1]:.3e} " f"||g_2||={d['bp_grad_norms_per_sample_med'][2]:.3e} " f"acc={d['acc_eval']:.4f}", flush=True) return log def train_dfa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_every=1, random_targets: bool = False): d_hidden = model.d_hidden L = model.num_blocks C = 10 Bs = [torch.randn(d_hidden, C, device=device) / np.sqrt(C) for _ in range(L)] block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks] embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd) all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]) log = [] d0 = diagnose(model, x_eval, y_eval, dfa_Bs=Bs) d0['epoch'] = 0 log.append(d0) print(f" [DFA] Ep 0: ||h||_med={d0['hidden_norms']} ||g||_med={d0['bp_grad_norms_per_sample_med']} acc={d0['acc_eval']:.4f}", flush=True) for epoch in range(1, epochs + 1): model.train() for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) if random_targets: # iid random class targets refreshed every minibatch (codex round 34 sharper variant) y = torch.randint(0, 10, y.shape, device=device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 hL_det = hiddens[-1].detach() logits_out = model.out_head(model.out_ln(hL_det)) loss_out = F.cross_entropy(logits_out, y) head_opt.zero_grad(); loss_out.backward(); head_opt.step() for l in range(L): h_l = hiddens[l].detach() a_dfa = (e_T @ Bs[l].T).detach() rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a_norm = a_dfa / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() block_opts[l].zero_grad(); local_loss.backward() torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) block_opts[l].step() a_0 = (e_T @ Bs[0].T).detach() rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 h0 = model.embed(x) embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean() embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step() for s in all_sch: s.step() if epoch % log_every == 0 or epoch == epochs: d = diagnose(model, x_eval, y_eval, dfa_Bs=Bs) d['epoch'] = epoch log.append(d) print(f" [DFA] Ep {epoch}: ||h_L||={d['hidden_norms'][-1]:.3e} " f"||g_2||={d['bp_grad_norms_per_sample_med'][2]:.3e} " f"acc={d['acc_eval']:.4f} gamma_dfa={d['gamma_dfa']:.4f}", flush=True) return log def train_fa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_every=1): """FA (Lillicrap 2016): sequential backward credit with d×d random matrices. Canonical implementation matching cifar_resmlp.py train_fa(): - mean reduction (default) - gradient taken BEFORE head step (old head weights) - top-down block update, credit propagated after each block - NO grad clipping """ d_hidden = model.d_hidden L = model.num_blocks Bs_fa = [torch.randn(d_hidden, d_hidden, device=device) / np.sqrt(d_hidden) for _ in range(L)] block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks] embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd) all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]) log = [] d0 = diagnose(model, x_eval, y_eval) d0['epoch'] = 0 log.append(d0) print(f" [FA] Ep 0: ||h||_med={d0['hidden_norms']} acc={d0['acc_eval']:.4f}", flush=True) for epoch in range(1, epochs + 1): model.train() for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) # Forward with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) # Head update — get gradient BEFORE step (old head weights) hL_det = hiddens[-1].detach().requires_grad_(True) logits_out = model.out_head(model.out_ln(hL_det)) loss_out = F.cross_entropy(logits_out, y) # mean reduction head_opt.zero_grad() loss_out.backward() a_credit = hL_det.grad.detach() # gradient w.r.t. old head head_opt.step() # Top-down block updates with sequential FA credit propagation for l in range(L - 1, -1, -1): h_l = hiddens[l].detach() rms = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a_norm = a_credit / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() block_opts[l].step() # no grad clipping a_credit = (a_credit @ Bs_fa[l]).detach() # Embed update with final propagated credit rms_0 = (a_credit ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 h0 = model.embed(x) embed_loss = (h0 * (a_credit / rms_0)).sum(dim=-1).mean() embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step() for s in all_sch: s.step() if epoch % log_every == 0 or epoch == epochs: d = diagnose(model, x_eval, y_eval) d['epoch'] = epoch log.append(d) print(f" [FA] Ep {epoch}: ||h_L||={d['hidden_norms'][-1]:.3e} " f"||g_2||={d['bp_grad_norms_per_sample_med'][2]:.3e} " f"acc={d['acc_eval']:.4f}", flush=True) return log def main(): p = argparse.ArgumentParser() p.add_argument('--output_dir', type=str, default='results/snapshot_evolution_v2') p.add_argument('--epochs', type=int, default=100) p.add_argument('--lr', type=float, default=1e-3) p.add_argument('--wd', type=float, default=0.01) p.add_argument('--seed', type=int, default=42) p.add_argument('--depth', type=int, default=4) p.add_argument('--d_hidden', type=int, default=256) p.add_argument('--log_every', type=int, default=1) p.add_argument('--no_residual_add', action='store_true', help='Replace h = h + f with h = f (non-residual stack of LN-W1-GELU-W2 blocks).') p.add_argument('--w2_std', type=float, default=0.01, help='Init std for w2 in each block. Bump to 0.05 for non-residual stack.') p.add_argument('--random_targets', action='store_true', help='Replace each minibatch label with iid random class targets (codex round 34 OPTION A).') p.add_argument('--skip_bp', action='store_true', help='Only train DFA, skip BP. Useful for cheap DFA-only ablations.') args = p.parse_args() os.makedirs(args.output_dir, exist_ok=True) device = torch.device('cuda:0') # CUDA_VISIBLE_DEVICES selects which physical GPU print(f"device={device}, depth={args.depth}, d_hidden={args.d_hidden}, " f"epochs={args.epochs}, seed={args.seed}", flush=True) train_loader, test_loader = get_cifar10(batch_size=128) x_eval, y_eval = fixed_eval_buffer(test_loader, device, n_samples=1024) print(f"eval buffer: {x_eval.shape}", flush=True) L, d, C = args.depth, args.d_hidden, 10 bp_log = None if not args.skip_bp: print("\n=== BP training ===", flush=True) torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) bp_model = ResidualMLP(3072, d, C, L, residual_add=not args.no_residual_add, w2_std=args.w2_std).to(device) bp_log = train_bp(bp_model, train_loader, x_eval, y_eval, device, args.epochs, args.lr, args.wd, log_every=args.log_every) print("\n=== DFA training ===", flush=True) torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) dfa_model = ResidualMLP(3072, d, C, L, residual_add=not args.no_residual_add, w2_std=args.w2_std).to(device) dfa_log = train_dfa(dfa_model, train_loader, x_eval, y_eval, device, args.epochs, args.lr, args.wd, log_every=args.log_every, random_targets=args.random_targets) fa_log = None if not args.skip_bp and not args.random_targets: # FA only when doing full run print("\n=== FA training ===", flush=True) torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) fa_model = ResidualMLP(3072, d, C, L, residual_add=not args.no_residual_add, w2_std=args.w2_std).to(device) fa_log = train_fa(fa_model, train_loader, x_eval, y_eval, device, args.epochs, args.lr, args.wd, log_every=args.log_every) out = { 'config': vars(args), 'depth': L, 'd_hidden': d, 'num_classes': C, 'bp_log': bp_log, 'dfa_log': dfa_log, 'fa_log': fa_log, } out_path = os.path.join(args.output_dir, f'snapshot_evolution_s{args.seed}.json') with open(out_path, 'w') as f: json.dump(out, f, indent=2) print(f"\nSaved {out_path}", flush=True) if __name__ == '__main__': main()