"""Confirmatory supplement: all metrics from existing checkpoints, no retraining.

Task 1: CIFAR full metrics (acc, naive_StateErr, Gamma_raw, Gamma_filtered,
        rho, mean_bp_grad_norm)
Task 2: Support sparsity (5 thresholds × 4 methods × 10 seeds × 4 layers)
Task 3: Per-layer gradient norm distribution (percentiles)
Task 4: Active-subset Gamma for BP and DFA

Each task writes one CSV file into ``--output_dir``.
"""
import os
import sys
import csv
import json
import argparse
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP
from metrics.credit_metrics import perturbation_correlation
import torchvision
import torchvision.transforms as transforms

# Directory the original script hard-coded for every checkpoint lookup.
DEFAULT_CKPT_DIR = 'results/confirmatory/checkpoints_A2'


def get_test_batch(device, n_batches=4):
    """Return the first ``n_batches`` CIFAR-10 test batches as flat tensors.

    The test split is read unshuffled in batches of 256, each image is
    flattened to one row, and the concatenated inputs and labels are moved
    to ``device``.  CIFAR-10 is downloaded to ``./data`` if missing.

    Returns:
        ``(x, y)`` where ``x`` has one flattened image per row and ``y``
        holds the matching labels.
    """
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    dataset = torchvision.datasets.CIFAR10(
        './data', train=False, download=True, transform=tv)
    loader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=4)
    xs, ys = [], []
    for x, y in loader:
        xs.append(x.view(x.size(0), -1))
        ys.append(y)
        if len(xs) >= n_batches:
            break
    return torch.cat(xs).to(device), torch.cat(ys).to(device)


def get_bp_grads(model, x, y, device):
    """Compute true backprop gradients w.r.t. the hidden state entering each block.

    The forward pass is unrolled by hand (``h <- h + block(h)``) so that every
    intermediate hidden state can be differentiated against the mean
    cross-entropy loss.  ``device`` is unused and kept only so existing call
    sites keep working.

    Returns:
        A 3-tuple of
        * ``{l: grad}`` for ``l in range(model.num_blocks)``, the gradient of
          the mean loss w.r.t. the l-th hidden state (index 0 is the embedding
          output),
        * the detached logits,
        * the detached per-sample cross-entropy losses.
    """
    model.eval()
    L = model.num_blocks
    h0 = model.embed(x.detach())
    hs = [h0.clone().requires_grad_(True)]
    for b in model.blocks:
        hs.append(hs[-1] + b(hs[-1]))
    lo = model.out_head(model.out_ln(hs[-1]))
    # Mean-reduced loss: each per-sample gradient carries a 1/batch factor.
    loss = F.cross_entropy(lo, y)
    gs = torch.autograd.grad(loss, hs)
    per_sample_loss = F.cross_entropy(lo, y, reduction='none').detach()
    return {l: gs[l].detach() for l in range(L)}, lo.detach(), per_sample_loss


def get_dfa_Bs(seed, d, C, L, device):
    """Regenerate DFA Bs with exact same seed sequence as training.

    Reseeds the global torch RNG, instantiates a throw-away ``ResidualMLP``
    to consume the same random state as model init, then draws ``L`` feedback
    matrices of shape ``(d, C)`` scaled by ``1/sqrt(C)``.

    NOTE(review): whether this reproduces the training-time matrices depends
    on the training script drawing them the same way on the same device type
    -- confirm against the trainer.
    """
    torch.manual_seed(seed)
    _ = ResidualMLP(3072, d, C, L)  # consume same random state as model init
    return [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]


@dataclass(frozen=True)
class _EvalContext:
    """Shared, read-only configuration for the four evaluation tasks."""
    device: torch.device
    x_eval: torch.Tensor
    y_eval: torch.Tensor
    output_dir: str
    ckpt_dir: str = DEFAULT_CKPT_DIR
    L: int = 4      # number of residual blocks
    d: int = 256    # hidden width
    C: int = 10     # number of classes
    seeds: tuple = (42, 123, 456, 789, 1024, 2048, 3000, 4000, 5000, 6000)
    methods: tuple = ('bp', 'dfa', 'state_bridge', 'credit_bridge')
    thresholds: tuple = (1e-8, 1e-7, 1e-6, 1e-5, 1e-4)


def _per_sample_norm(t):
    """Return the L2 norm of ``t`` along its last dimension (one value per row)."""
    # ``t.norm(-1)`` would pass -1 as the norm ORDER ``p`` and reduce over the
    # whole tensor to a single scalar; the dimension must be given by keyword.
    return t.norm(p=2, dim=-1)


def _checkpoint_path(ckpt_dir, method, seed):
    """Build the checkpoint file path for one (method, seed) pair."""
    return f'{ckpt_dir}/{method}_s{seed}.pt'


def _load_model(ckpt, seed, ctx):
    """Seed the RNG, build a ``ResidualMLP`` and load weights from ``ckpt``."""
    torch.manual_seed(seed)
    model = ResidualMLP(3072, ctx.d, ctx.C, ctx.L).to(ctx.device)
    # NOTE: torch.load unpickles the file; only use with trusted checkpoints.
    model.load_state_dict(torch.load(ckpt, map_location=ctx.device))
    return model


def _output_error(model, x, y):
    """Return ``softmax(logits) - onehot(y)``, the output-layer error signal."""
    with torch.no_grad():
        e_T = model(x).softmax(-1)
        e_T[torch.arange(e_T.size(0), device=e_T.device), y] -= 1
    return e_T


def _make_tail_loss_fn(model, y, start, num_blocks):
    """Return ``f(h)``: per-sample loss after running blocks ``start..`` from ``h``."""
    def tail_loss(h):
        with torch.no_grad():
            c = h
            for i in range(start, num_blocks):
                c = c + model.blocks[i](c)
            return F.cross_entropy(
                model.out_head(model.out_ln(c)), y, reduction='none')
    return tail_loss


def _write_csv(path, fieldnames, rows):
    """Write ``rows`` (a list of dicts) to ``path`` with a header line."""
    with open(path, 'w', newline='') as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(rows)


def _task1_full_metrics(ctx):
    """Task 1: accuracy, naive state error, Gamma (raw/filtered) and rho per checkpoint."""
    print(f"\n{'='*60}\nTask 1: Full metrics\n{'='*60}", flush=True)
    L = ctx.L
    rows = []
    for method in ctx.methods:
        for seed in ctx.seeds:
            ckpt = _checkpoint_path(ctx.ckpt_dir, method, seed)
            if not os.path.exists(ckpt):
                print(f" SKIP {ckpt}", flush=True)
                continue
            model = _load_model(ckpt, seed, ctx)
            bp, lo, _ = get_bp_grads(model, ctx.x_eval, ctx.y_eval, ctx.device)

            # Accuracy
            acc = (lo.argmax(1) == ctx.y_eval).float().mean().item()

            # Naive StateErr: relative distance between a middle hidden state
            # and the last one, averaged over samples.
            # NOTE(review): assumes hi[l] is the hidden state at depth l --
            # confirm against ResidualMLP.forward(return_hidden=True).
            with torch.no_grad():
                _, hi = model(ctx.x_eval, return_hidden=True)
            h_mid, h_L = hi[L // 2], hi[-1]
            nse = (_per_sample_norm(h_mid - h_L)
                   / _per_sample_norm(h_L).clamp(min=1e-8)).mean().item()

            # DFA Bs for Gamma computation.  This reseeds the global RNG, so
            # it must stay ahead of perturbation_correlation below.
            dfa_Bs = get_dfa_Bs(seed, ctx.d, ctx.C, L, ctx.device)
            e_T = _output_error(model, ctx.x_eval, ctx.y_eval)

            # Per-layer metrics
            gamma_raw_list, gamma_filt_list, rho_list = [], [], []
            for l in range(L):
                g = bp[l]
                mask = _per_sample_norm(g) > 1e-6
                a_dfa = (e_T @ dfa_Bs[l].T).detach()
                h_l = hi[l].detach()

                if method == 'bp':
                    # BP credit is the true gradient: alignment is 1 by definition.
                    gamma_raw_list.append(1.0)
                    gamma_filt_list.append(1.0)
                    a_l = g
                else:
                    cos = F.cosine_similarity(a_dfa, g, dim=-1)
                    gamma_raw_list.append(cos.mean().item())
                    gamma_filt_list.append(
                        cos[mask].mean().item() if mask.any() else float('nan'))
                    # Use DFA credit for all non-BP (closest available)
                    a_l = a_dfa

                # Rho (perturbation correlation)
                rho = perturbation_correlation(
                    h_l, a_l, _make_tail_loss_fn(model, ctx.y_eval, l, L),
                    epsilon=1e-3, M=16)
                rho_list.append(rho)

            row = {
                'method': method, 'seed': seed, 'acc': acc,
                'naive_StateErr': nse,
                'Gamma_raw': np.mean(gamma_raw_list),
                'Gamma_filtered': np.nanmean(gamma_filt_list),
                'rho': np.mean(rho_list),
                'mean_bp_grad_norm': np.mean(
                    [_per_sample_norm(bp[l]).mean().item() for l in range(L)]),
            }
            rows.append(row)
            if seed in [42, 123]:
                print(f" {method} s={seed}: acc={acc:.4f} Gr={row['Gamma_raw']:.4f} "
                      f"Gf={row['Gamma_filtered']:.4f} rho={row['rho']:.4f} nse={nse:.4f}",
                      flush=True)

    out1 = os.path.join(ctx.output_dir, 'T1_cifar_full_metrics.csv')
    _write_csv(out1, ['method', 'seed', 'acc', 'naive_StateErr', 'Gamma_raw',
                      'Gamma_filtered', 'rho', 'mean_bp_grad_norm'], rows)
    print(f"Task 1: {len(rows)} rows -> {out1}", flush=True)


def _task2_support_sparsity(ctx):
    """Task 2: fraction of samples whose BP gradient norm exceeds each threshold."""
    print(f"\n{'='*60}\nTask 2: Support sparsity\n{'='*60}", flush=True)
    rows = []
    for method in ctx.methods:
        n_done = 0
        for seed in ctx.seeds:
            ckpt = _checkpoint_path(ctx.ckpt_dir, method, seed)
            if not os.path.exists(ckpt):
                continue
            model = _load_model(ckpt, seed, ctx)
            bp, _, _ = get_bp_grads(model, ctx.x_eval, ctx.y_eval, ctx.device)
            for l in range(ctx.L):
                norms = _per_sample_norm(bp[l])
                mean_norm = norms.mean().item()
                median_norm = norms.median().item()
                for tau in ctx.thresholds:
                    rows.append({
                        'method': method, 'seed': seed, 'layer': l, 'threshold': tau,
                        'support_fraction': (norms > tau).float().mean().item(),
                        'mean_norm': mean_norm,
                        'median_norm': median_norm,
                    })
            n_done += 1
        # Report the checkpoints actually processed, not the number requested.
        print(f" {method}: done ({n_done} seeds)", flush=True)

    out2 = os.path.join(ctx.output_dir, 'T2_support_sparsity.csv')
    _write_csv(out2, ['method', 'seed', 'layer', 'threshold',
                      'support_fraction', 'mean_norm', 'median_norm'], rows)
    print(f"Task 2: {len(rows)} rows -> {out2}", flush=True)


def _task3_grad_norm_distribution(ctx):
    """Task 3: moments and percentiles of per-sample BP gradient norms (first 3 seeds)."""
    print(f"\n{'='*60}\nTask 3: Gradient norm distribution\n{'='*60}", flush=True)
    rows = []
    percentiles = [1, 5, 10, 25, 50, 75, 90, 95, 99]
    for method in ctx.methods:
        for seed in ctx.seeds[:3]:  # 3 seeds for distributions
            ckpt = _checkpoint_path(ctx.ckpt_dir, method, seed)
            if not os.path.exists(ckpt):
                continue
            model = _load_model(ckpt, seed, ctx)
            bp, _, _ = get_bp_grads(model, ctx.x_eval, ctx.y_eval, ctx.device)
            for l in range(ctx.L):
                norms = _per_sample_norm(bp[l]).cpu().numpy()
                # Floor before the log so exact zeros do not produce -inf.
                log_norms = np.log10(norms.clip(min=1e-20))
                row = {
                    'method': method, 'seed': seed, 'layer': l,
                    'mean': float(norms.mean()), 'std': float(norms.std()),
                    'mean_log10': float(log_norms.mean()),
                    'std_log10': float(log_norms.std()),
                }
                for p in percentiles:
                    row[f'p{p}'] = float(np.percentile(norms, p))
                    row[f'p{p}_log10'] = float(np.percentile(log_norms, p))
                rows.append(row)
        print(f" {method}: done", flush=True)

    out3 = os.path.join(ctx.output_dir, 'T3_grad_norm_distribution.csv')
    fields3 = (['method', 'seed', 'layer', 'mean', 'std', 'mean_log10', 'std_log10']
               + [f'p{p}' for p in percentiles]
               + [f'p{p}_log10' for p in percentiles])
    _write_csv(out3, fields3, rows)
    print(f"Task 3: {len(rows)} rows -> {out3}", flush=True)


def _task4_layer_stats(ctx, method, seed, ckpt):
    """Return per-layer ``(cos, norms, energy_gamma)`` for one checkpoint.

    None of these depend on the activity threshold, so they are computed once
    per checkpoint instead of once per (threshold, checkpoint).
    """
    batch = ctx.x_eval.size(0)
    model = _load_model(ckpt, seed, ctx)
    bp, _, _ = get_bp_grads(model, ctx.x_eval, ctx.y_eval, ctx.device)
    if method != 'bp':
        dfa_Bs = get_dfa_Bs(seed, ctx.d, ctx.C, ctx.L, ctx.device)
        e_T = _output_error(model, ctx.x_eval, ctx.y_eval)

    stats = []
    for l in range(ctx.L):
        g = bp[l]
        norms = _per_sample_norm(g)
        if method == 'bp':
            cos = torch.ones(batch, device=ctx.device)
        else:
            a_dfa = (e_T @ dfa_Bs[l].T).detach()
            cos = F.cosine_similarity(a_dfa, g, dim=-1)
        # Energy-weighted Gamma: cosine averaged with squared-norm weights.
        weights = norms ** 2
        w_sum = weights.sum().item()
        if w_sum > 0:
            energy = (cos * weights).sum().item() / (w_sum + 1e-20)
        else:
            energy = float('nan')
        stats.append((cos, norms, energy))
    return stats


def _task4_active_subset_gamma(ctx):
    """Task 4: Gamma restricted to samples whose gradient norm exceeds each threshold."""
    print(f"\n{'='*60}\nTask 4: Active-subset Gamma\n{'='*60}", flush=True)
    batch = ctx.x_eval.size(0)
    t4_methods = ['bp', 'dfa']

    # Pass 1: threshold-independent statistics, one model load per checkpoint.
    per_ckpt = {}
    for method in t4_methods:
        for seed in ctx.seeds:
            ckpt = _checkpoint_path(ctx.ckpt_dir, method, seed)
            if not os.path.exists(ckpt):
                continue
            per_ckpt[(method, seed)] = _task4_layer_stats(ctx, method, seed, ckpt)

    # Pass 2: apply each threshold; loop order fixes the CSV row order.
    rows = []
    for tau in ctx.thresholds:
        for method in t4_methods:
            for seed in ctx.seeds:
                stats = per_ckpt.get((method, seed))
                if stats is None:
                    continue
                gamma_active_list, n_active_list = [], []
                for cos, norms, _ in stats:
                    mask = norms > tau
                    n_active = mask.sum().item()
                    # Active-subset Gamma
                    gamma_active_list.append(
                        cos[mask].mean().item() if n_active > 0 else float('nan'))
                    n_active_list.append(n_active)
                rows.append({
                    'method': method, 'seed': seed, 'threshold': tau,
                    'Gamma_active': np.nanmean(gamma_active_list),
                    'Gamma_energy_weighted': np.nanmean([e for _, _, e in stats]),
                    'mean_n_active': np.mean(n_active_list),
                    'pct_active': np.mean(n_active_list) / batch * 100,
                })

        # Summary for this threshold
        for m in t4_methods:
            vals = [r for r in rows if r['method'] == m and r['threshold'] == tau]
            if vals:
                ga = np.nanmean([r['Gamma_active'] for r in vals])
                ge = np.nanmean([r['Gamma_energy_weighted'] for r in vals])
                pa = np.mean([r['pct_active'] for r in vals])
                print(f" tau={tau:.0e} {m}: Gamma_active={ga:.4f} Gamma_energy={ge:.4f} pct_active={pa:.1f}%",
                      flush=True)

    out4 = os.path.join(ctx.output_dir, 'T4_active_subset_gamma.csv')
    _write_csv(out4, ['method', 'seed', 'threshold', 'Gamma_active',
                      'Gamma_energy_weighted', 'mean_n_active', 'pct_active'], rows)
    print(f"Task 4: {len(rows)} rows -> {out4}", flush=True)


def run(args):
    """Run all four confirmatory tasks and write their CSVs.

    Args:
        args: namespace with ``gpu`` (CUDA device index) and ``output_dir``.
            An optional ``ckpt_dir`` overrides the checkpoint directory; when
            absent the original hard-coded location is used.
    """
    device = torch.device(f'cuda:{args.gpu}')
    os.makedirs(args.output_dir, exist_ok=True)

    x_eval, y_eval = get_test_batch(device)
    print(f"Eval: {x_eval.size(0)} samples", flush=True)

    ctx = _EvalContext(
        device=device,
        x_eval=x_eval,
        y_eval=y_eval,
        output_dir=args.output_dir,
        ckpt_dir=getattr(args, 'ckpt_dir', DEFAULT_CKPT_DIR),
    )

    _task1_full_metrics(ctx)
    _task2_support_sparsity(ctx)
    _task3_grad_norm_distribution(ctx)
    _task4_active_subset_gamma(ctx)

    print("\nALL TASKS DONE", flush=True)


def main():
    """Parse command-line options and run the evaluation."""
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--output_dir', type=str, default='results/confirmatory')
    p.add_argument('--ckpt_dir', type=str, default=DEFAULT_CKPT_DIR)
    args = p.parse_args()
    run(args)


if __name__ == '__main__':
    main()