| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-05-04 19:50:45 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-05-04 19:50:45 -0500 |
| commit | b480d0cdc21f944e4adccf6e81cc939b0450c5e9 (patch) | |
| tree | f0e6afb5b3d448d1d6c35d9622d22d63073ca9a7 /reproduce/penalty_sweep.py | |
Initial submission code: FA evaluation protocol + reproduction scripts
Reference implementation of the three-diagnostic FA evaluation protocol
(scale stability, reference validity, depth utility) from the NeurIPS 2026
E&D track paper. Includes models, metrics, and full reproduction pipeline.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
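For orientation, a minimal sketch of the reference-validity diagnostic as the script below instantiates it: the mean cosine between the fixed random feedback signal and the true backpropagated gradient at a hidden layer. The names here are illustrative stand-ins; in the diff, the repository's `cosine_similarity_batch` plays this role inside `fresh_b_null`:

    import torch
    import torch.nn.functional as F

    def reference_validity(feedback: torch.Tensor, true_grad: torch.Tensor) -> float:
        # feedback: DFA signal e_T @ B_l.T, shape (batch, d)
        # true_grad: backpropagated gradient at the same layer, shape (batch, d)
        # A near-zero mean cosine means the random reference carries no credit signal.
        return F.cosine_similarity(feedback, true_grad, dim=-1).mean().item()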
Diffstat (limited to 'reproduce/penalty_sweep.py')
| -rw-r--r-- | reproduce/penalty_sweep.py | 176 |
1 file changed, 176 insertions, 0 deletions
diff --git a/reproduce/penalty_sweep.py b/reproduce/penalty_sweep.py
new file mode 100644
index 0000000..b6b913d
--- /dev/null
+++ b/reproduce/penalty_sweep.py
@@ -0,0 +1,176 @@
+"""
+Penalty intervention sweep: DFA + lambda x {0, 1e-4, 1e-2} with per-epoch trajectory.
+Includes fresh-B null calibration on the lambda=1e-2 checkpoint.
+
+Usage:
+    python reproduce/penalty_sweep.py --seeds 42 123 456 --gpu 0
+"""
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from reproduce.train_methods import get_data, evaluate, make_model, _pool_hidden, _get_head_logits
+from metrics.credit_metrics import cosine_similarity_batch
+
+
+def train_dfa_trajectory(seed, train_loader, test_loader, device, epochs, lam, num_classes=10):
+    """DFA with per-epoch ||h_L||, ||g_L|| logging."""
+    torch.manual_seed(seed); np.random.seed(seed)
+    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
+    from models.residual_mlp import ResidualMLP
+    model = ResidualMLP(3072, 256, num_classes, 4).to(device)
+    d, L, C = 256, 4, num_classes
+    # Fixed random feedback matrices, one per block (the "B" in DFA).
+    Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+    block_opts = [optim.AdamW(b.parameters(), lr=1e-3, weight_decay=0.01) for b in model.blocks]
+    embed_opt = optim.AdamW(model.embed.parameters(), lr=1e-3, weight_decay=0.01)
+    head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+                           lr=1e-3, weight_decay=0.01)
+    all_sch = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+              [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+               optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+
+    # Fixed 128-sample eval buffer drawn from the test loader.
+    xs, ys = [], []
+    for x, y in test_loader:
+        xs.append(x.view(x.size(0), -1)); ys.append(y)
+        if sum(xb.size(0) for xb in xs) >= 128: break
+    x_eval = torch.cat(xs)[:128].to(device)
+    y_eval = torch.cat(ys)[:128].to(device)
+
+    def diagnose():
+        model.eval()
+        with torch.no_grad():
+            _, hi = model(x_eval, return_hidden=True)
+            h_L = hi[-1].norm(dim=-1).median().item()
+        # Manual forward pass so the true gradient at each residual state is readable.
+        h0 = model.embed(x_eval)
+        hs = [h0.clone().requires_grad_(True)]
+        for b in model.blocks:
+            hs.append(hs[-1] + b(hs[-1]))
+        logits = model.out_head(model.out_ln(hs[-1]))
+        loss = F.cross_entropy(logits, y_eval)
+        grads = torch.autograd.grad(loss, hs)
+        g_L = grads[-1].norm(dim=-1).median().item()
+        acc = (logits.argmax(-1) == y_eval).float().mean().item()
+        model.train()
+        return h_L, g_L, acc
+
+    log = []
+    h, g, a = diagnose()
+    log.append({'epoch': 0, 'h_L': h, 'g_L': g, 'acc': a})
+
+    for ep in range(1, epochs + 1):
+        model.train()
+        for x, y in train_loader:
+            x = x.view(x.size(0), -1).to(device); y = y.to(device)
+            batch = x.size(0)
+            with torch.no_grad():
+                logits, hiddens = model(x, return_hidden=True)
+                # Output error e_T = softmax(logits) - one_hot(y).
+                e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1
+            hL = hiddens[-1].detach()
+            head_opt.zero_grad()
+            F.cross_entropy(model.out_head(model.out_ln(hL)), y).backward()
+            head_opt.step()
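+            # DFA block updates: the surrogate below is linear in f_l, so its gradient
+            # w.r.t. block-l parameters is the block Jacobian applied to the fixed
+            # random feedback e_T @ Bs[l].T (RMS-normalised); lam > 0 additionally
+            # penalises the squared activity of f_l.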
+            for l in range(L):
+                a_dfa = (e_T @ Bs[l].T).detach()
+                rms = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+                f_l = model.blocks[l](hiddens[l].detach())
+                local_loss = (f_l * (a_dfa / rms)).sum(-1).mean()
+                if lam > 0:
+                    local_loss = local_loss + lam * (f_l ** 2).sum(-1).mean()
+                block_opts[l].zero_grad(); local_loss.backward(); block_opts[l].step()
+            # The embedding gets the same treatment, driven through B_0.
+            a0 = (e_T @ Bs[0].T).detach()
+            rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+            h0 = model.embed(x)
+            embed_opt.zero_grad(); (h0 * (a0 / rms0)).sum(-1).mean().backward(); embed_opt.step()
+        for s in all_sch: s.step()
+        h, g, a = diagnose()
+        log.append({'epoch': ep, 'h_L': h, 'g_L': g, 'acc': a})
+        if ep % 10 == 0 or ep == epochs:
+            print(f"  [lam={lam}] s={seed} ep {ep}: ||h_L||={h:.3e} ||g_L||={g:.3e} acc={a:.4f}", flush=True)
+
+    return log, model, Bs
+
+
+def fresh_b_null(model, x_eval, y_eval, training_Bs, n_draws=20):
+    """Fresh-B null calibration on a trained checkpoint."""
+    model.eval()
+    d, L, C = 256, 4, len(training_Bs[0][0]) if training_Bs[0].dim() == 2 else 10
+    device = x_eval.device
+
+    def deep_cos_with_Bs(Bs):
+        h0 = model.embed(x_eval)
+        hs = [h0.clone().requires_grad_(True)]
+        for b in model.blocks:
+            hs.append(hs[-1] + b(hs[-1]))
+        logits = model.out_head(model.out_ln(hs[-1]))
+        loss = F.cross_entropy(logits, y_eval)
+        grads = torch.autograd.grad(loss, hs)
+        with torch.no_grad():
+            e_T = logits.softmax(-1)
+            e_T[torch.arange(x_eval.size(0)), y_eval] -= 1
+            cos_layers = []
+            for l in range(L):
+                a = (e_T @ Bs[l].T).detach()
+                cos_layers.append(cosine_similarity_batch(a, grads[l].detach()))
+        return float(np.mean(cos_layers[1:]))  # deep = exclude layer 0
+
+    train_cos = deep_cos_with_Bs(training_Bs)
+    fresh_cos = []
+    for _ in range(n_draws):
+        fresh_Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+        fresh_cos.append(deep_cos_with_Bs(fresh_Bs))
+
+    return {
+        'training_Bs_deep_cos': train_cos,
+        'fresh_Bs_deep_mean': float(np.mean(fresh_cos)),
+        'fresh_Bs_deep_std_ddof1': float(np.std(fresh_cos, ddof=1)),
+        'n_draws': n_draws,
+    }
+
+
+def main():
+    p = argparse.ArgumentParser()
+    p.add_argument('--seeds', nargs='+', type=int, default=[42, 123, 456])
+    p.add_argument('--epochs', type=int, default=30)
+    p.add_argument('--lambdas', nargs='+', type=float, default=[0.0, 1e-4, 1e-2])
+    p.add_argument('--gpu', type=int, default=0)
+    p.add_argument('--output_dir', type=str, default='results/penalty_sweep')
+    args = p.parse_args()
+
+    os.makedirs(args.output_dir, exist_ok=True)
+    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+    train_loader, test_loader, _ = get_data('cifar10', 128)
+
+    results = {}
+    for lam in args.lambdas:
+        lam_key = f'lam_{lam}'
+        results[lam_key] = {}
+        for seed in args.seeds:
+            print(f"\n=== lambda={lam}, seed={seed} ===", flush=True)
+            log, model, Bs = train_dfa_trajectory(seed, train_loader, test_loader, device, args.epochs, lam)
+            results[lam_key][str(seed)] = log
+
+            # Fresh-B null on lambda=1e-2, seed=42 only
+            if lam == 1e-2 and seed == 42:
+                xs, ys = [], []
+                for x, y in test_loader:
+                    xs.append(x.view(x.size(0), -1)); ys.append(y)
+                    if sum(xb.size(0) for xb in xs) >= 128: break
+                x_eval = torch.cat(xs)[:128].to(device)
+                y_eval = torch.cat(ys)[:128].to(device)
+                null = fresh_b_null(model, x_eval, y_eval, Bs)
+                results['fresh_b_null'] = null
+                print(f"  Fresh-B: training={null['training_Bs_deep_cos']:+.4f}, "
+                      f"fresh={null['fresh_Bs_deep_mean']:+.4f} +/- {null['fresh_Bs_deep_std_ddof1']:.4f}")
+
+    with open(os.path.join(args.output_dir, 'penalty_sweep.json'), 'w') as f:
+        json.dump(results, f, indent=2)
+    print(f"\nSaved: {args.output_dir}/penalty_sweep.json")
+
+
+if __name__ == '__main__':
+    main()
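Downstream of this script, a minimal sketch of how the emitted JSON might be summarised. This assumes the default --output_dir and that the lambda=1e-2 / seed=42 run was included (so the 'fresh_b_null' key exists); the z-score summary is an illustrative convention, not part of the committed code:

    import json
    import numpy as np

    with open('results/penalty_sweep/penalty_sweep.json') as f:
        res = json.load(f)

    # Final-epoch test accuracy per penalty setting, aggregated over seeds.
    for lam_key, per_seed in res.items():
        if lam_key == 'fresh_b_null':
            continue
        finals = [traj[-1]['acc'] for traj in per_seed.values()]
        print(f"{lam_key}: acc {np.mean(finals):.4f} +/- {np.std(finals, ddof=1):.4f}")

    # Reference validity: training-B deep cosine as a z-score against the fresh-B null.
    null = res['fresh_b_null']
    z = (null['training_Bs_deep_cos'] - null['fresh_Bs_deep_mean']) / null['fresh_Bs_deep_std_ddof1']
    print(f"fresh-B z = {z:.2f} over {null['n_draws']} draws")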
