"""
Penalty intervention sweep: DFA trained with an L2 penalty lambda * ||f_l||^2 on each
residual block's output, for lambda in {0, 1e-4, 1e-2}, logging a per-epoch trajectory
of ||h_L||, ||g_L|| and accuracy.

Also runs a fresh-B null calibration on one checkpoint (by default lambda=1e-2, seed=42;
see --null_lambda / --null_seed). It compares the deep-layer DFA/gradient alignment
obtained with the training feedback matrices against the alignment obtained with freshly
drawn ones.

Usage:
    python reproduce/penalty_sweep.py --seeds 42 123 456 --gpu 0
"""
import argparse
import json
import math
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Put the parent of this file's directory on sys.path so the project packages
# imported below resolve when the script is run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from reproduce.train_methods import get_data, evaluate, make_model, _pool_hidden, _get_head_logits
from metrics.credit_metrics import cosine_similarity_batch

INPUT_DIM = 3072   # flattened input size passed to ResidualMLP
HIDDEN_DIM = 256   # residual-stream width; rows of each feedback matrix B
DEPTH = 4          # number of residual blocks, and hence of feedback matrices
EVAL_SIZE = 128    # size of the fixed diagnostic batch taken from the test loader
LR = 1e-3
WEIGHT_DECAY = 0.01
RMS_EPS = 1e-6     # keeps the per-example RMS normalization finite for an all-zero signal


def _eval_batch(loader, device, n=EVAL_SIZE):
    """Return the first ``n`` examples of ``loader`` as (flattened x, y) on ``device``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    xs, ys = [], []
    seen = 0
    for x, y in loader:
        xs.append(x.view(x.size(0), -1))
        ys.append(y)
        seen += x.size(0)
        if seen >= n:
            break
    if not xs:
        raise ValueError("loader yielded no batches; cannot build an eval batch")
    return torch.cat(xs)[:n].to(device), torch.cat(ys)[:n].to(device)


def _output_error(logits, y):
    """Return e = softmax(logits) - onehot(y), computed without autograd tracking."""
    with torch.no_grad():
        err = logits.softmax(-1)
        err[torch.arange(y.size(0), device=err.device), y] -= 1
    return err


def _normalized_dfa_signal(err, B):
    """Project the output error through feedback matrix B and RMS-normalize each row."""
    a = (err @ B.T).detach()
    rms = (a ** 2).mean(-1, keepdim=True).sqrt() + RMS_EPS
    return a / rms


def _hidden_grads(model, x, y):
    """Manually re-run the forward pass; return (logits, dCE/dh_l for every stream state h_l).

    NOTE(review): assumes the model's forward is h_{l+1} = h_l + block_l(h_l), followed by
    out_ln and out_head, i.e. that it mirrors ``model(x, return_hidden=True)``.
    TODO: confirm against models/residual_mlp.py.
    """
    # Detaching the embedding output yields the same grads w.r.t. hs without
    # building a graph into the embedding parameters.
    hs = [model.embed(x).detach().requires_grad_(True)]
    for block in model.blocks:
        hs.append(hs[-1] + block(hs[-1]))
    logits = model.out_head(model.out_ln(hs[-1]))
    grads = torch.autograd.grad(F.cross_entropy(logits, y), hs)
    return logits, grads


def train_dfa_trajectory(seed, train_loader, test_loader, device, epochs, lam, num_classes=10):
    """Train a ResidualMLP with (optionally penalized) DFA and log diagnostics every epoch.

    Per minibatch (inputs flattened to (batch, -1)):
      * head (out_ln + out_head): ordinary cross-entropy backprop on the detached
        final hidden state;
      * block l: minimizes <f_l, a_l / rms(a_l)> (+ lam * ||f_l||^2 when lam > 0), where
        f_l = block_l(h_l.detach()), a_l = e @ B_l^T and e = softmax(logits) - onehot(y);
      * embedding: the same DFA objective, reusing B_0.

    Args:
        seed: seeds torch (and CUDA when available) and numpy before model/B creation.
        train_loader: iterable of (x, y) training batches.
        test_loader: the first EVAL_SIZE examples form the fixed diagnostic batch.
        device: torch device for the model and data.
        epochs: number of epochs; also T_max of every cosine LR schedule.
        lam: weight of the squared-norm penalty on block outputs (only added if > 0).
        num_classes: number of output classes.

    Returns:
        (log, model, Bs). ``log`` holds one dict {'epoch', 'h_L', 'g_L', 'acc'} per
        epoch 0..epochs, where epoch 0 is before training. ``h_L`` and ``g_L`` are the
        medians over the diagnostic batch of ||h_L|| and ||dCE/dh_L||. ``Bs`` is the
        list of DEPTH fixed feedback matrices of shape (HIDDEN_DIM, num_classes).
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    from models.residual_mlp import ResidualMLP

    # RNG order (model init, then Bs, then eval-batch draw) matches earlier runs.
    model = ResidualMLP(INPUT_DIM, HIDDEN_DIM, num_classes, DEPTH).to(device)
    d, L, C = HIDDEN_DIM, DEPTH, num_classes
    Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]

    opt_kw = dict(lr=LR, weight_decay=WEIGHT_DECAY)
    block_opts = [optim.AdamW(b.parameters(), **opt_kw) for b in model.blocks]
    embed_opt = optim.AdamW(model.embed.parameters(), **opt_kw)
    head_opt = optim.AdamW(
        list(model.out_head.parameters()) + list(model.out_ln.parameters()), **opt_kw
    )
    schedulers = [
        optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
        for opt in (*block_opts, embed_opt, head_opt)
    ]

    x_eval, y_eval = _eval_batch(test_loader, device)

    def diagnose():
        """Return (median ||h_L||, median ||dCE/dh_L||, accuracy) on the diagnostic batch."""
        model.eval()
        with torch.no_grad():
            _, hiddens = model(x_eval, return_hidden=True)
            h_norm = hiddens[-1].norm(dim=-1).median().item()
        logits, grads = _hidden_grads(model, x_eval, y_eval)
        g_norm = grads[-1].norm(dim=-1).median().item()
        acc = (logits.argmax(-1) == y_eval).float().mean().item()
        model.train()
        return h_norm, g_norm, acc

    log = []

    def record(ep):
        h, g, a = diagnose()
        log.append({'epoch': ep, 'h_L': h, 'g_L': g, 'acc': a})
        return h, g, a

    record(0)
    for ep in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
            e_T = _output_error(logits, y)

            # Head: standard backprop, isolated from the rest via the detached input.
            head_opt.zero_grad()
            F.cross_entropy(model.out_head(model.out_ln(hiddens[-1].detach())), y).backward()
            head_opt.step()

            # Blocks: local DFA objective on detached inputs, so only block l is updated.
            for layer in range(L):
                signal = _normalized_dfa_signal(e_T, Bs[layer])
                f_l = model.blocks[layer](hiddens[layer].detach())
                local_loss = (f_l * signal).sum(-1).mean()
                if lam > 0:
                    local_loss = local_loss + lam * (f_l ** 2).sum(-1).mean()
                block_opts[layer].zero_grad()
                local_loss.backward()
                block_opts[layer].step()

            # Embedding: same DFA objective, driven by B_0.
            signal0 = _normalized_dfa_signal(e_T, Bs[0])
            embed_opt.zero_grad()
            (model.embed(x) * signal0).sum(-1).mean().backward()
            embed_opt.step()

        for sched in schedulers:
            sched.step()
        h, g, a = record(ep)
        if ep % 10 == 0 or ep == epochs:
            print(f" [lam={lam}] s={seed} ep {ep}: ||h_L||={h:.3e} ||g_L||={g:.3e} acc={a:.4f}", flush=True)
    return log, model, Bs


def fresh_b_null(model, x_eval, y_eval, training_Bs, n_draws=20):
    """Fresh-B null calibration on a trained checkpoint.

    Computes the mean "deep" cosine alignment between the DFA signals e @ B_l^T
    and the true gradients dCE/dh_l on (x_eval, y_eval). "Deep" means layers
    1..L-1, excluding layer 0. The alignment is computed first for ``training_Bs``
    and then for each of ``n_draws`` freshly drawn sets of feedback matrices with
    the same shapes and 1/sqrt(C) scaling. This gives a null baseline for the
    training-B alignment.

    Args:
        model: trained model exposing ``embed``, ``blocks``, ``out_ln``, ``out_head``.
        x_eval: flattened evaluation inputs on the model's device.
        y_eval: matching labels.
        training_Bs: non-empty list of 2-D (d, C) feedback matrices used in training.
        n_draws: number of fresh feedback sets; at least 2, for the ddof=1 std.

    Returns:
        dict with 'training_Bs_deep_cos', 'fresh_Bs_deep_mean',
        'fresh_Bs_deep_std_ddof1' and 'n_draws'.

    Raises:
        ValueError: if ``n_draws < 2`` or ``training_Bs`` is empty or not 2-D.
    """
    if n_draws < 2:
        raise ValueError(f"n_draws must be >= 2 for a ddof=1 std, got {n_draws}")
    if not training_Bs or training_Bs[0].dim() != 2:
        raise ValueError("training_Bs must be a non-empty list of 2-D (d, C) tensors")
    L = len(training_Bs)
    d, C = training_Bs[0].shape
    device = x_eval.device

    was_training = model.training
    model.eval()
    try:
        # Neither the gradients nor the output error depend on B: compute them once.
        logits, grads = _hidden_grads(model, x_eval, y_eval)
        e_T = _output_error(logits, y_eval)

        def deep_cos(Bs):
            with torch.no_grad():
                cos_layers = [
                    cosine_similarity_batch((e_T @ B.T).detach(), g.detach())
                    for B, g in zip(Bs, grads)
                ]
            return float(np.mean(cos_layers[1:]))  # deep = exclude layer 0

        train_cos = deep_cos(training_Bs)
        fresh_cos = [
            deep_cos([torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)])
            for _ in range(n_draws)
        ]
    finally:
        model.train(was_training)

    return {
        'training_Bs_deep_cos': train_cos,
        'fresh_Bs_deep_mean': float(np.mean(fresh_cos)),
        'fresh_Bs_deep_std_ddof1': float(np.std(fresh_cos, ddof=1)),
        'n_draws': n_draws,
    }


def main():
    """Run the lambda x seed sweep, the fresh-B null, and write penalty_sweep.json."""
    p = argparse.ArgumentParser()
    p.add_argument('--seeds', nargs='+', type=int, default=[42, 123, 456])
    p.add_argument('--epochs', type=int, default=30)
    p.add_argument('--lambdas', nargs='+', type=float, default=[0.0, 1e-4, 1e-2])
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--output_dir', type=str, default='results/penalty_sweep')
    p.add_argument('--null_lambda', type=float, default=1e-2,
                   help='lambda of the checkpoint used for the fresh-B null')
    p.add_argument('--null_seed', type=int, default=42,
                   help='seed of the checkpoint used for the fresh-B null')
    p.add_argument('--null_draws', type=int, default=20,
                   help='number of fresh feedback sets drawn for the null')
    args = p.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    out_path = os.path.join(args.output_dir, 'penalty_sweep.json')
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader, _ = get_data('cifar10', 128)

    def save(results):
        with open(out_path, 'w') as f:
            json.dump(results, f, indent=2)

    results = {}
    for lam in args.lambdas:
        lam_key = f'lam_{lam}'
        results[lam_key] = {}
        for seed in args.seeds:
            print(f"\n=== lambda={lam}, seed={seed} ===", flush=True)
            log, model, Bs = train_dfa_trajectory(seed, train_loader, test_loader, device, args.epochs, lam)
            results[lam_key][str(seed)] = log

            # Fresh-B null on a single checkpoint; isclose avoids exact float equality.
            if math.isclose(lam, args.null_lambda) and seed == args.null_seed:
                x_eval, y_eval = _eval_batch(test_loader, device)
                null = fresh_b_null(model, x_eval, y_eval, Bs, n_draws=args.null_draws)
                results['fresh_b_null'] = null
                print(f" Fresh-B: training={null['training_Bs_deep_cos']:+.4f}, "
                      f"fresh={null['fresh_Bs_deep_mean']:+.4f} +/- {null['fresh_Bs_deep_std_ddof1']:.4f}")

            # Checkpoint after every run so a crash late in the sweep keeps earlier results.
            save(results)

    if 'fresh_b_null' not in results:
        print(f"NOTE: fresh-B null skipped: lambda={args.null_lambda}, seed={args.null_seed} "
              f"was not part of this sweep", flush=True)

    save(results)
    print(f"\nSaved: {args.output_dir}/penalty_sweep.json")


if __name__ == '__main__':
    main()