summaryrefslogtreecommitdiff
path: root/experiments/snapshot_synth_residual_explosion.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/snapshot_synth_residual_explosion.py')
-rw-r--r--experiments/snapshot_synth_residual_explosion.py195
1 files changed, 195 insertions, 0 deletions
diff --git a/experiments/snapshot_synth_residual_explosion.py b/experiments/snapshot_synth_residual_explosion.py
new file mode 100644
index 0000000..3470667
--- /dev/null
+++ b/experiments/snapshot_synth_residual_explosion.py
@@ -0,0 +1,195 @@
+"""
+Synthetic snapshot evolution: per-epoch logging of ||h_l||_2 and ||BP grad||_2
+on a teacher-student StudentNet (NO out_ln) trained with BP vs DFA.
+
+Goal: test whether the residual-stream explosion observed in CIFAR ResidualMLP
+(pre-LN with out_ln before head) also happens in the synthetic StudentNet
+architecture (no out_ln; head reads h_L directly). If synthetic does NOT show
+the explosion, then out_ln is causally responsible for the CIFAR pathology and
+the paper's P4 claim narrows to "pre-LN architectures with terminal LN".
+
+Usage:
+ CUDA_VISIBLE_DEVICES=2 nohup python experiments/snapshot_synth_residual_explosion.py \
+ --output_dir results/snapshot_synth_v1 --epochs 80 --alpha 1.0 --depth 4 --seed 42 \
+ > results/snapshot_synth_v1/run_a1.0_s42.log 2>&1 &
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from metrics.credit_metrics import cosine_similarity_batch
+# Import the StudentNet/TeacherNet/generate_synth_dataset directly from confirmatory script
+from experiments.confirmatory_paper_experiments import (
+ StudentNet, TeacherNet, generate_synth_dataset, set_seed
+)
+
+
+def diagnose_synth(model, x_eval, y_eval, dfa_Bs=None):
+ was_training = model.training
+ model.eval()
+ L = model.num_blocks
+
+ with torch.no_grad():
+ _, hi = model(x_eval, return_hidden=True)
+ hidden_norms = [h.norm(dim=-1).median().item() for h in hi]
+
+ # BP grads
+ h_list = [x_eval.detach().requires_grad_(True)]
+ for block in model.blocks:
+ h_list.append(h_list[-1] + block(h_list[-1]))
+ logits = model.out_head(h_list[-1])
+ loss = F.cross_entropy(logits, y_eval)
+ grads = torch.autograd.grad(loss, h_list)
+ bp_grad_l2 = [g.norm(dim=-1).median().item() for g in grads]
+ bp_grad_F = [g.norm().item() for g in grads]
+ bp_full = [g.detach() for g in grads]
+ acc = (logits.argmax(-1) == y_eval).float().mean().item()
+ loss_val = loss.item()
+
+ gamma_dfa = float('nan')
+ per_layer_gamma = []
+ if dfa_Bs is not None:
+ with torch.no_grad():
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(x_eval.size(0)), y_eval] -= 1.0
+ for l in range(L):
+ a_dfa = (e_T @ dfa_Bs[l].T).detach()
+ per_layer_gamma.append(cosine_similarity_batch(a_dfa, bp_full[l]))
+ gamma_dfa = float(np.mean(per_layer_gamma))
+
+ if was_training:
+ model.train()
+ return {
+ 'hidden_norms': hidden_norms,
+ 'bp_grad_per_sample_l2_med': bp_grad_l2,
+ 'bp_grad_F': bp_grad_F,
+ 'gamma_dfa': gamma_dfa,
+ 'gamma_dfa_per_layer': per_layer_gamma,
+ 'acc_eval': acc,
+ 'loss_eval': loss_val,
+ }
+
+
+def train_bp(model, train_loader, x_eval, y_eval, device, epochs, lr, wd):
+ opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ log = []
+ d0 = diagnose_synth(model, x_eval, y_eval); d0['epoch'] = 0; log.append(d0)
+ print(f" [BP] Ep 0: ||h_L||={d0['hidden_norms'][-1]:.3e} ||g||={d0['bp_grad_per_sample_l2_med'][2]:.3e} acc={d0['acc_eval']:.4f}", flush=True)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.to(device); y = y.to(device)
+ logits = model(x)
+ loss = F.cross_entropy(logits, y)
+ opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ d = diagnose_synth(model, x_eval, y_eval); d['epoch'] = ep; log.append(d)
+ if ep % 5 == 0 or ep in (1, epochs):
+ print(f" [BP] Ep {ep}: ||h_L||={d['hidden_norms'][-1]:.3e} ||g||={d['bp_grad_per_sample_l2_med'][2]:.3e} acc={d['acc_eval']:.4f}", flush=True)
+ return log
+
+
+def train_dfa(model, train_loader, x_eval, y_eval, device, epochs, lr, wd):
+ d_hidden = model.d_hidden
+ L = model.num_blocks
+ C = 10
+ Bs = [torch.randn(d_hidden, C, device=device) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+ log = []
+ d0 = diagnose_synth(model, x_eval, y_eval, dfa_Bs=Bs); d0['epoch'] = 0; log.append(d0)
+ print(f" [DFA] Ep 0: ||h_L||={d0['hidden_norms'][-1]:.3e} ||g||={d0['bp_grad_per_sample_l2_med'][2]:.3e} acc={d0['acc_eval']:.4f}", flush=True)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.to(device); y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ hL_det = hiddens[-1].detach()
+ # head update via direct CE on head(hL)
+ logits_out = model.out_head(hL_det)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+ # block updates via DFA local credit
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a_dfa = (e_T @ Bs[l].T).detach()
+ rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a_dfa / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad(); local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ for s in all_sch:
+ s.step()
+ d = diagnose_synth(model, x_eval, y_eval, dfa_Bs=Bs); d['epoch'] = ep; log.append(d)
+ if ep % 5 == 0 or ep in (1, epochs):
+ print(f" [DFA] Ep {ep}: ||h_L||={d['hidden_norms'][-1]:.3e} ||g||={d['bp_grad_per_sample_l2_med'][2]:.3e} acc={d['acc_eval']:.4f} γ_dfa={d['gamma_dfa']:.4f}", flush=True)
+ return log
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--output_dir', type=str, default='results/snapshot_synth_v1')
+ p.add_argument('--epochs', type=int, default=80)
+ p.add_argument('--alpha', type=float, default=1.0)
+ p.add_argument('--depth', type=int, default=4)
+ p.add_argument('--seed', type=int, default=42)
+ p.add_argument('--d_hidden', type=int, default=128)
+ p.add_argument('--lr', type=float, default=1e-3)
+ p.add_argument('--wd', type=float, default=0.01)
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ device = torch.device('cuda:0')
+ print(f"device={device}, alpha={args.alpha}, depth={args.depth}, "
+ f"d_hidden={args.d_hidden}, epochs={args.epochs}, seed={args.seed}", flush=True)
+
+ set_seed(args.seed)
+ L, d, C = args.depth, args.d_hidden, 10
+ teacher = TeacherNet(d, L, C, args.alpha, seed=0).to(device)
+
+ n_train = 50 * 256
+ n_test = 2000
+ X_tr, Y_tr = generate_synth_dataset(teacher, n_train, d, device, seed=args.seed)
+ X_te, Y_te = generate_synth_dataset(teacher, n_test, d, device, seed=args.seed + 10000)
+ train_loader = DataLoader(TensorDataset(X_tr, Y_tr), batch_size=256, shuffle=True)
+ x_eval, y_eval = X_te.to(device), Y_te.to(device)
+ print(f"train: {X_tr.shape}, test eval buffer: {x_eval.shape}", flush=True)
+
+ print("\n=== BP training ===", flush=True)
+ set_seed(args.seed)
+ bp_model = StudentNet(d, C, L, args.alpha).to(device)
+ bp_log = train_bp(bp_model, train_loader, x_eval, y_eval, device, args.epochs, args.lr, args.wd)
+
+ print("\n=== DFA training ===", flush=True)
+ set_seed(args.seed)
+ dfa_model = StudentNet(d, C, L, args.alpha).to(device)
+ dfa_log = train_dfa(dfa_model, train_loader, x_eval, y_eval, device, args.epochs, args.lr, args.wd)
+
+ out = {
+ 'config': vars(args),
+ 'depth': L, 'd_hidden': d, 'num_classes': C,
+ 'bp_log': bp_log,
+ 'dfa_log': dfa_log,
+ }
+ out_path = os.path.join(args.output_dir, f'snapshot_synth_a{args.alpha}_L{L}_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(out, f, indent=2)
+ print(f"\nSaved {out_path}", flush=True)
+
+
+if __name__ == '__main__':
+ main()