path: root/reproduce/penalty_sweep.py
author    YurenHao0426 <Blackhao0426@gmail.com>    2026-05-04 19:50:45 -0500
committer YurenHao0426 <Blackhao0426@gmail.com>    2026-05-04 19:50:45 -0500
commit    b480d0cdc21f944e4adccf6e81cc939b0450c5e9 (patch)
tree      f0e6afb5b3d448d1d6c35d9622d22d63073ca9a7 /reproduce/penalty_sweep.py
Initial submission code: FA evaluation protocol + reproduction scripts
Reference implementation of the three-diagnostic FA evaluation protocol (scale stability, reference validity, depth utility) from the NeurIPS 2026 E&D track paper. Includes models, metrics, and full reproduction pipeline.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'reproduce/penalty_sweep.py')
-rw-r--r--  reproduce/penalty_sweep.py  176
1 files changed, 176 insertions, 0 deletions
diff --git a/reproduce/penalty_sweep.py b/reproduce/penalty_sweep.py
new file mode 100644
index 0000000..b6b913d
--- /dev/null
+++ b/reproduce/penalty_sweep.py
@@ -0,0 +1,176 @@
+"""
+Penalty intervention sweep: DFA + lambda x {0, 1e-4, 1e-2} with per-epoch trajectory.
+Includes fresh-B null calibration on the lambda=1e-2 checkpoint.
+
+Usage:
+ python reproduce/penalty_sweep.py --seeds 42 123 456 --gpu 0
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from reproduce.train_methods import get_data, evaluate, make_model, _pool_hidden, _get_head_logits
+from metrics.credit_metrics import cosine_similarity_batch
+
+
+def train_dfa_trajectory(seed, train_loader, test_loader, device, epochs, lam, num_classes=10):
+ """DFA with per-epoch ||h_L||, ||g_L|| logging."""
+ torch.manual_seed(seed); np.random.seed(seed)
+ if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
+ from models.residual_mlp import ResidualMLP
+ model = ResidualMLP(3072, 256, num_classes, 4).to(device)
+ d, L, C = 256, 4, num_classes
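+ # Fixed random feedback matrices B_l (shape d x C, one per block), drawn once
+ # per seed and never trained; they replace the transposed forward weights that
+ # backprop would use to route the output error to each block.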
+ Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=1e-3, weight_decay=0.01)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=1e-3, weight_decay=0.01)
+ all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+
+ # Fixed eval buffer: first 128 test images (flattened), reused for every diagnostic probe
+ xs, ys = [], []
+ for x, y in test_loader:
+ xs.append(x.view(x.size(0), -1)); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= 128: break
+ x_eval = torch.cat(xs)[:128].to(device)
+ y_eval = torch.cat(ys)[:128].to(device)
+
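+ # Diagnostic probe on the eval buffer: median forward scale ||h_L||, median
+ # true-gradient scale ||g_L|| at the last hidden state (via a manual forward
+ # pass with autograd enabled), and eval accuracy.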
+ def diagnose():
+ model.eval()
+ with torch.no_grad():
+ _, hi = model(x_eval, return_hidden=True)
+ h_L = hi[-1].norm(dim=-1).median().item()
+ h0 = model.embed(x_eval)
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks:
+ hs.append(hs[-1] + b(hs[-1]))
+ logits = model.out_head(model.out_ln(hs[-1]))
+ loss = F.cross_entropy(logits, y_eval)
+ grads = torch.autograd.grad(loss, hs)
+ g_L = grads[-1].norm(dim=-1).median().item()
+ acc = (logits.argmax(-1) == y_eval).float().mean().item()
+ model.train()
+ return h_L, g_L, acc
+
+ log = []
+ h, g, a = diagnose()
+ log.append({'epoch': 0, 'h_L': h, 'g_L': g, 'acc': a})
+
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ batch = x.size(0)
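+ # Output error e = softmax(logits) - onehot(y), computed without autograd;
+ # this single error vector is all the feedback DFA sends to the blocks.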
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+ hL = hiddens[-1].detach()
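+ # Head + output LayerNorm train on their true local gradient, with h_L
+ # detached so no signal flows back into the blocks.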
+ head_opt.zero_grad()
+ F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward()
+ head_opt.step()
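+ # Per-block DFA update: feedback a_l = e @ B_l^T is RMS-normalized, and the
+ # block descends on <f_l, a_l/rms>; lam > 0 adds an L2 activity penalty on f_l.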
+ for l in range(L):
+ a_dfa = (e_T @ Bs[l].T).detach()
+ rms = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ f_l = model.blocks[l](hiddens[l].detach())
+ local_loss = (f_l * (a_dfa / rms)).sum(-1).mean()
+ if lam > 0:
+ local_loss = local_loss + lam * (f_l ** 2).sum(-1).mean()
+ block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step()
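+ # The embedding is updated the same way, reusing B_0 as its error projection.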
+ a0 = (e_T @ Bs[0].T).detach()
+ rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_opt.zero_grad(); (h0 * (a0 / rms0)).sum(-1).mean().backward(); embed_opt.step()
+ for s in all_sch: s.step()
+ h, g, a = diagnose()
+ log.append({'epoch': ep, 'h_L': h, 'g_L': g, 'acc': a})
+ if ep % 10 == 0 or ep == epochs:
+ print(f" [lam={lam}] s={seed} ep {ep}: ||h_L||={h:.3e} ||g_L||={g:.3e} acc={a:.4f}", flush=True)
+
+ return log, model, Bs
+
+
+def fresh_b_null(model, x_eval, y_eval, training_Bs, n_draws=20):
+ """Fresh-B null calibration on a trained checkpoint."""
+ model.eval()
+ d, L = 256, 4
+ C = training_Bs[0].shape[1] if training_Bs[0].dim() == 2 else 10
+ device = x_eval.device
+
+ def deep_cos_with_Bs(Bs):
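+ # True backprop gradients w.r.t. every hidden state, then per-layer cosine
+ # against the feedback directions e @ B_l^T; 'deep' averages layers 1..L-1.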
+ h0 = model.embed(x_eval)
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks:
+ hs.append(hs[-1] + b(hs[-1]))
+ logits = model.out_head(model.out_ln(hs[-1]))
+ loss = F.cross_entropy(logits, y_eval)
+ grads = torch.autograd.grad(loss, hs)
+ with torch.no_grad():
+ e_T = logits.softmax(-1)
+ e_T[torch.arange(x_eval.size(0)), y_eval] -= 1
+ cos_layers = []
+ for l in range(L):
+ a = (e_T @ Bs[l].T).detach()
+ cos_layers.append(cosine_similarity_batch(a, grads[l].detach()))
+ return float(np.mean(cos_layers[1:])) # deep = exclude layer 0
+
+ train_cos = deep_cos_with_Bs(training_Bs)
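+ # Null distribution: the same cosine under n_draws freshly sampled feedback
+ # matrices that played no role in training.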
+ fresh_cos = []
+ for _ in range(n_draws):
+ fresh_Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+ fresh_cos.append(deep_cos_with_Bs(fresh_Bs))
+
+ return {
+ 'training_Bs_deep_cos': train_cos,
+ 'fresh_Bs_deep_mean': float(np.mean(fresh_cos)),
+ 'fresh_Bs_deep_std_ddof1': float(np.std(fresh_cos, ddof=1)),
+ 'n_draws': n_draws,
+ }
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--seeds', nargs='+', type=int, default=[42, 123, 456])
+ p.add_argument('--epochs', type=int, default=30)
+ p.add_argument('--lambdas', nargs='+', type=float, default=[0.0, 1e-4, 1e-2])
+ p.add_argument('--gpu', type=int, default=0)
+ p.add_argument('--output_dir', type=str, default='results/penalty_sweep')
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ train_loader, test_loader, _ = get_data('cifar10', 128)
+
+ results = {}
+ for lam in args.lambdas:
+ lam_key = f'lam_{lam}'
+ results[lam_key] = {}
+ for seed in args.seeds:
+ print(f"\n=== lambda={lam}, seed={seed} ===", flush=True)
+ log, model, Bs = train_dfa_trajectory(seed, train_loader, test_loader, device, args.epochs, lam)
+ results[lam_key][str(seed)] = log
+
+ # Fresh-B null on lambda=1e-2, seed=42 only
+ if lam == 1e-2 and seed == 42:
+ xs, ys = [], []
+ for x, y in test_loader:
+ xs.append(x.view(x.size(0), -1)); ys.append(y)
+ if sum(xb.size(0) for xb in xs) >= 128: break
+ x_eval = torch.cat(xs)[:128].to(device)
+ y_eval = torch.cat(ys)[:128].to(device)
+ null = fresh_b_null(model, x_eval, y_eval, Bs)
+ results['fresh_b_null'] = null
+ print(f" Fresh-B: training={null['training_Bs_deep_cos']:+.4f}, "
+ f"fresh={null['fresh_Bs_deep_mean']:+.4f} +/- {null['fresh_Bs_deep_std_ddof1']:.4f}")
+
+ with open(os.path.join(args.output_dir, 'penalty_sweep.json'), 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nSaved: {args.output_dir}/penalty_sweep.json")
+
+
+if __name__ == '__main__':
+ main()