summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 05:39:39 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 05:39:39 -0500
commit8dd65b2ec3df32749adabbf62c55101d5b00ae7b (patch)
tree3a329bfdf9867ae13889dfcecd65ef216734947b /experiments
parent68cfa13af2f026b7ff388aae4420eba0f0db804a (diff)
Round 32+33 H2 ablation: add no_residual_add flag; falsify residual-as-cause hypothesis
- models/residual_mlp.py: add residual_add and w2_std flags (default unchanged) - experiments/snapshot_evolution_residual_explosion.py: add --no_residual_add and --w2_std CLI flags - paper/main.tex §3 ¶3: add 1-sentence reference to no-residual control showing Mode 1 still fires - paper/main.tex Appendix I: full smoke-test table + interpretation - v2.2 main content stays at 8 pages (within 9-page E&D budget); 13 pages total Smoke test (3 ep, w2_std=0.5, seed 42): - DFA no-residual: ||h_L|| 4.69 -> 22050, ||g|| 1.6e-7 (Mode 1 (a) fires; (b) at floor) - BP no-residual: acc only 0.16 at ep 3 (architecture is partially degenerate) - Conclusion: residual skip is NOT necessary for Mode 1; the proximate trigger is more general - Codex round 33 verdict: WALK BACK H2; demote 100ep run to confirmatory Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/snapshot_evolution_residual_explosion.py267
1 files changed, 267 insertions, 0 deletions
diff --git a/experiments/snapshot_evolution_residual_explosion.py b/experiments/snapshot_evolution_residual_explosion.py
new file mode 100644
index 0000000..86de4a4
--- /dev/null
+++ b/experiments/snapshot_evolution_residual_explosion.py
@@ -0,0 +1,267 @@
+"""
+Snapshot evolution: per-epoch logging of residual-stream norms and BP-gradient norms
+during BP and DFA training of a 4-block d=256 ResMLP on CIFAR-10.
+
+Goal: confirm that ||h_l||_2 grows monotonically over epochs in DFA but stays
+bounded in BP, and that ||BP_grad||_2 collapses correspondingly. This generates
+the killer figure for the P4 (residual-stream pathology) finding in the
+NeurIPS 2026 FA Evaluation paper.
+
+Usage:
+ CUDA_VISIBLE_DEVICES=2 nohup python experiments/snapshot_evolution_residual_explosion.py \
+ --output_dir results/snapshot_evolution_v2 > results/snapshot_evolution_v2.log 2>&1 &
+"""
+import os, sys, json, argparse, time
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from metrics.credit_metrics import cosine_similarity_batch
+
+
+def get_cifar10(batch_size=128, num_workers=2):
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tr = torchvision.datasets.CIFAR10('./data', True, download=True, transform=tv_train)
+ te = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
+ return (DataLoader(tr, batch_size=batch_size, shuffle=True, num_workers=num_workers),
+ DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=num_workers))
+
+
+def fixed_eval_buffer(test_loader, device, n_samples=1024):
+ xs, ys = [], []
+ for x, y in test_loader:
+ xs.append(x.view(x.size(0), -1)); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= n_samples:
+ break
+ x = torch.cat(xs)[:n_samples].to(device)
+ y = torch.cat(ys)[:n_samples].to(device)
+ return x, y
+
+
+def diagnose(model, x_eval, y_eval, dfa_Bs=None):
+ """
+ Returns dict with:
+ - hidden_norms: list of L+1 floats, median per-sample ||h_l||_2 on eval buffer
+ - bp_grad_norms: list of L+1 floats, median per-sample ||g_l||_2 (BP grad)
+ - bp_grad_norms_F: list of L+1 floats, ||g_l||_F per layer (Frobenius)
+ - gamma_dfa: mean cosine over layers between DFA credit and BP grad (only if dfa_Bs given)
+ - acc: test accuracy on the eval buffer
+ - loss: mean CE on the eval buffer
+ Critically: ALL norms use .norm(dim=-1), never .norm(-1).
+ """
+ was_training = model.training
+ model.eval()
+ L = model.num_blocks
+ C = 10
+ bs = x_eval.size(0)
+
+ # Hidden states (no grad)
+ with torch.no_grad():
+ _, hiddens = model(x_eval, return_hidden=True)
+ hidden_norms = [h.norm(dim=-1).median().item() for h in hiddens]
+
+ # BP gradients via manual graph, with x_eval as the input
+ h0 = model.embed(x_eval.detach())
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks:
+ hs.append(hs[-1] + b(hs[-1]))
+ logits = model.out_head(model.out_ln(hs[-1]))
+ loss = F.cross_entropy(logits, y_eval)
+ grads = torch.autograd.grad(loss, hs)
+ bp_grad_per_sample_l2 = [g.norm(dim=-1).median().item() for g in grads]
+ bp_grad_F = [g.norm().item() for g in grads]
+ bp_grad_full = [g.detach() for g in grads]
+
+ acc = (logits.argmax(-1) == y_eval).float().mean().item()
+ loss_val = loss.item()
+
+ # DFA credit cosine to BP grad, if requested.
+ # Convention (matches confirmatory_paper_experiments.compute_diagnostics_generic):
+ # DFA's a_l represents the credit at the *input* to block l, which is h_l, so it
+ # is compared against bp_grad_full[l] (gradient at h_l = input to block l).
+ gamma_dfa = float('nan')
+ if dfa_Bs is not None:
+ with torch.no_grad():
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(bs), y_eval] -= 1.0
+ cos_per_layer = []
+ for l in range(L):
+ a_dfa = (e_T @ dfa_Bs[l].T).detach()
+ cos_per_layer.append(cosine_similarity_batch(a_dfa, bp_grad_full[l]))
+ gamma_dfa = float(np.mean(cos_per_layer))
+
+ if was_training:
+ model.train()
+
+ return {
+ 'hidden_norms': hidden_norms,
+ 'bp_grad_norms_per_sample_med': bp_grad_per_sample_l2,
+ 'bp_grad_norms_F': bp_grad_F,
+ 'gamma_dfa': gamma_dfa,
+ 'acc_eval': acc,
+ 'loss_eval': loss_val,
+ }
+
+
+def train_bp(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_every=1):
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+ log = []
+ # Epoch 0 (pre-training)
+ d0 = diagnose(model, x_eval, y_eval)
+ d0['epoch'] = 0
+ log.append(d0)
+ print(f" [BP] Ep 0: ||h||_med={d0['hidden_norms']} ||g||_med={d0['bp_grad_norms_per_sample_med']} acc={d0['acc_eval']:.4f}", flush=True)
+ for epoch in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ logits = model(x)
+ loss = F.cross_entropy(logits, y)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ scheduler.step()
+ if epoch % log_every == 0 or epoch == epochs:
+ d = diagnose(model, x_eval, y_eval)
+ d['epoch'] = epoch
+ log.append(d)
+ print(f" [BP] Ep {epoch}: ||h_L||={d['hidden_norms'][-1]:.3e} "
+ f"||g_2||={d['bp_grad_norms_per_sample_med'][2]:.3e} "
+ f"acc={d['acc_eval']:.4f}", flush=True)
+ return log
+
+
+def train_dfa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd, log_every=1):
+ d_hidden = model.d_hidden
+ L = model.num_blocks
+ C = 10
+ Bs = [torch.randn(d_hidden, C, device=device) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd)
+ all_sch = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)])
+ log = []
+ d0 = diagnose(model, x_eval, y_eval, dfa_Bs=Bs)
+ d0['epoch'] = 0
+ log.append(d0)
+ print(f" [DFA] Ep 0: ||h||_med={d0['hidden_norms']} ||g||_med={d0['bp_grad_norms_per_sample_med']} acc={d0['acc_eval']:.4f}", flush=True)
+ for epoch in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ hL_det = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL_det))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a_dfa = (e_T @ Bs[l].T).detach()
+ rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a_dfa / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad(); local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ a_0 = (e_T @ Bs[0].T).detach()
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+ for s in all_sch:
+ s.step()
+ if epoch % log_every == 0 or epoch == epochs:
+ d = diagnose(model, x_eval, y_eval, dfa_Bs=Bs)
+ d['epoch'] = epoch
+ log.append(d)
+ print(f" [DFA] Ep {epoch}: ||h_L||={d['hidden_norms'][-1]:.3e} "
+ f"||g_2||={d['bp_grad_norms_per_sample_med'][2]:.3e} "
+ f"acc={d['acc_eval']:.4f} gamma_dfa={d['gamma_dfa']:.4f}", flush=True)
+ return log
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--output_dir', type=str, default='results/snapshot_evolution_v2')
+ p.add_argument('--epochs', type=int, default=100)
+ p.add_argument('--lr', type=float, default=1e-3)
+ p.add_argument('--wd', type=float, default=0.01)
+ p.add_argument('--seed', type=int, default=42)
+ p.add_argument('--depth', type=int, default=4)
+ p.add_argument('--d_hidden', type=int, default=256)
+ p.add_argument('--log_every', type=int, default=1)
+ p.add_argument('--no_residual_add', action='store_true',
+ help='Replace h = h + f with h = f (non-residual stack of LN-W1-GELU-W2 blocks).')
+ p.add_argument('--w2_std', type=float, default=0.01,
+ help='Init std for w2 in each block. Bump to 0.05 for non-residual stack.')
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ device = torch.device('cuda:0') # CUDA_VISIBLE_DEVICES selects which physical GPU
+ print(f"device={device}, depth={args.depth}, d_hidden={args.d_hidden}, "
+ f"epochs={args.epochs}, seed={args.seed}", flush=True)
+
+ train_loader, test_loader = get_cifar10(batch_size=128)
+ x_eval, y_eval = fixed_eval_buffer(test_loader, device, n_samples=1024)
+ print(f"eval buffer: {x_eval.shape}", flush=True)
+
+ L, d, C = args.depth, args.d_hidden, 10
+
+ print("\n=== BP training ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ bp_model = ResidualMLP(3072, d, C, L,
+ residual_add=not args.no_residual_add,
+ w2_std=args.w2_std).to(device)
+ bp_log = train_bp(bp_model, train_loader, x_eval, y_eval, device,
+ args.epochs, args.lr, args.wd, log_every=args.log_every)
+
+ print("\n=== DFA training ===", flush=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ dfa_model = ResidualMLP(3072, d, C, L,
+ residual_add=not args.no_residual_add,
+ w2_std=args.w2_std).to(device)
+ dfa_log = train_dfa(dfa_model, train_loader, x_eval, y_eval, device,
+ args.epochs, args.lr, args.wd, log_every=args.log_every)
+
+ out = {
+ 'config': vars(args),
+ 'depth': L, 'd_hidden': d, 'num_classes': C,
+ 'bp_log': bp_log,
+ 'dfa_log': dfa_log,
+ }
+ out_path = os.path.join(args.output_dir, f'snapshot_evolution_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(out, f, indent=2)
+ print(f"\nSaved {out_path}", flush=True)
+
+
+if __name__ == '__main__':
+ main()