"""Extended BP Support Sparsity Analysis.

A4: Per-layer support sparsity
B1: Snapshot evolution (early/mid/late)
B2: Active subset characterization (misclassification rate, margin, entropy, loss)
C1: Active-only vs inactive-only update cosine
C2: Gradient energy concentration (top-k%)
"""
import os
import sys
import csv
import json
import argparse

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# The project root must be on sys.path before the project-local import below.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP
import torchvision
import torchvision.transforms as transforms

# Checkpoint locations (relative to the current working directory).
_A2_CKPT = 'results/confirmatory/checkpoints_A2/{method}_s{seed}.pt'
_BP_SNAPSHOTS = {
    5: 'results/snapshot_time/bp_ckpts_L4_d256_s42/epoch_5.pt',
    20: 'results/snapshot_time/bp_ckpts_L4_d256_s42/epoch_20.pt',
    100: 'results/snapshot_time/bp_ckpts_L4_d256_s42/epoch_100.pt',
}
_DFA_SNAPSHOTS = {
    1: 'results/checkpointed_handoff/dfa_ckpts_s42/dfa_epoch_1.pt',
    5: 'results/checkpointed_handoff/dfa_ckpts_s42/dfa_epoch_5.pt',
    10: 'results/checkpointed_handoff/dfa_ckpts_s42/dfa_epoch_10.pt',
    100: 'results/checkpointed_handoff/dfa_ckpts_s42/dfa_epoch_100.pt',
}


def get_cifar10_test(bs=256):
    """Return an unshuffled DataLoader over the CIFAR-10 test split.

    The dataset is downloaded into ``./data`` when missing and every image
    is converted to a tensor and normalized per channel.

    Args:
        bs: Batch size of the returned loader.
    """
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    ds = torchvision.datasets.CIFAR10('./data', train=False, download=True, transform=tv)
    return DataLoader(ds, batch_size=bs, shuffle=False, num_workers=4)


def get_bp_grads_and_info(model, x, y, device):
    """Get per-layer BP gradients + logits/loss info.

    The forward pass is replayed block by block so that the hidden state
    entering every residual block is available for ``torch.autograd.grad``.
    The model is switched to eval mode as a side effect.

    Args:
        model: Network exposing ``embed``, ``blocks``, ``out_ln``,
            ``out_head`` and ``num_blocks``.
        x: Input batch fed to ``model.embed``.
        y: Integer class targets for ``x``.
        device: Unused; kept for interface compatibility.

    Returns:
        A tuple ``(bp, logits, loss_per_sample)``. ``bp[l]`` is the gradient
        of the batch-mean cross-entropy with respect to the input of block
        ``l`` (so it carries the 1/batch factor of the mean); ``logits`` and
        ``loss_per_sample`` are detached.
    """
    model.eval()
    L = model.num_blocks
    h0 = model.embed(x.detach())
    h_list = [h0.clone().requires_grad_(True)]
    for b in model.blocks:
        h_list.append(h_list[-1] + b(h_list[-1]))
    lo = model.out_head(model.out_ln(h_list[-1]))
    loss_per_sample = F.cross_entropy(lo, y, reduction='none')
    loss = loss_per_sample.mean()
    grads = torch.autograd.grad(loss, h_list)
    # h_list has L + 1 entries; the final hidden state is deliberately dropped.
    bp = {l: grads[l].detach() for l in range(L)}
    return bp, lo.detach(), loss_per_sample.detach()


def _banner(title):
    """Print the section header used between analyses."""
    print(f"\n{'='*60}\n{title}\n{'='*60}", flush=True)


def _write_csv(path, fieldnames, rows):
    """Write ``rows`` (a list of dicts) to ``path`` with a header line."""
    with open(path, 'w', newline='') as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(rows)


def _load_model(path, d, L, device):
    """Build a ResidualMLP on ``device`` and load the weights stored at ``path``.

    Accepts both a bare state dict and a ``{'model': state_dict}`` wrapper so
    that every analysis section tolerates the same checkpoint layouts.
    """
    model = ResidualMLP(3072, d, 10, L).to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True where the installed torch
    # version supports it).
    state = torch.load(path, map_location=device)
    if isinstance(state, dict) and 'model' in state:
        state = state['model']
    model.load_state_dict(state)
    return model


def _layer_norms(bp, L):
    """Return the per-sample L2 gradient norm of each layer as a list.

    ``dim`` must be passed by keyword: the first positional argument of
    ``Tensor.norm`` is the norm order ``p``, so ``t.norm(-1)`` would compute
    a single p=-1 norm over the whole tensor rather than one norm per sample.
    """
    return [bp[l].norm(dim=-1) for l in range(L)]


def _support_fraction(norms, tau):
    """Fraction of samples whose gradient norm exceeds ``tau``."""
    return (norms > tau).float().mean().item()


def _collect_eval_batch(loader, device, max_batches=4):
    """Concatenate the first ``max_batches`` loader batches, flattening inputs."""
    xs, ys = [], []
    for x, y in loader:
        xs.append(x.view(x.size(0), -1))
        ys.append(y)
        if len(xs) >= max_batches:
            break
    return torch.cat(xs).to(device), torch.cat(ys).to(device)


def _subset_stats(prefix, mask, correct, metrics):
    """Misclassification rate and metric means over the samples in ``mask``.

    Keys are ``{prefix}_miscls_rate`` and ``{prefix}_mean_{name}`` for every
    entry of ``metrics``; all values are NaN when the mask selects nothing.
    """
    keys = [f'{prefix}_miscls_rate'] + [f'{prefix}_mean_{name}' for name in metrics]
    if mask.sum().item() == 0:
        return dict.fromkeys(keys, float('nan'))
    stats = {f'{prefix}_miscls_rate': 1.0 - correct[mask].float().mean().item()}
    for name, values in metrics.items():
        stats[f'{prefix}_mean_{name}'] = values[mask].mean().item()
    return stats


def _run_a4(x_eval, y_eval, device, output_dir, methods, seeds, thresholds, L, d):
    """A4: support fraction per layer, method, seed and threshold."""
    _banner("A4: Per-layer support sparsity")
    rows = []
    for method in methods:
        for seed in seeds:
            ckpt = _A2_CKPT.format(method=method, seed=seed)
            if not os.path.exists(ckpt):
                continue
            torch.manual_seed(seed)
            model = _load_model(ckpt, d, L, device)
            bp, _, _ = get_bp_grads_and_info(model, x_eval, y_eval, device)
            layer_norms = _layer_norms(bp, L)
            for l, norms in enumerate(layer_norms):
                mean_norm = norms.mean().item()
                median_norm = norms.median().item()
                for tau in thresholds:
                    rows.append({'method': method, 'seed': seed, 'layer': l, 'threshold': tau,
                                 'support_fraction': _support_fraction(norms, tau),
                                 'mean_norm': mean_norm, 'median_norm': median_norm})
            if seed == 42:
                for l, norms in enumerate(layer_norms):
                    print(f" {method} layer {l}: s(1e-6)={(norms>1e-6).float().mean():.4f} "
                          f"mean={norms.mean():.2e} median={norms.median():.2e}", flush=True)
    out_a4 = os.path.join(output_dir, 'A4_perlayer_support.csv')
    _write_csv(out_a4, ['method', 'seed', 'layer', 'threshold', 'support_fraction',
                        'mean_norm', 'median_norm'], rows)
    print(f"A4: {len(rows)} rows -> {out_a4}", flush=True)


def _run_b1(x_eval, y_eval, device, output_dir, thresholds, L, d):
    """B1: support fraction along the BP and DFA training trajectories."""
    _banner("B1: Snapshot evolution")
    rows = []
    for trajectory, ckpts in [('bp', _BP_SNAPSHOTS), ('dfa', _DFA_SNAPSHOTS)]:
        for epoch, path in sorted(ckpts.items()):
            if not os.path.exists(path):
                print(f" SKIP {path}", flush=True)
                continue
            model = _load_model(path, d, L, device)
            bp, lo, _ = get_bp_grads_and_info(model, x_eval, y_eval, device)
            acc = (lo.argmax(1) == y_eval).float().mean().item()
            layer_norms = _layer_norms(bp, L)
            for l, norms in enumerate(layer_norms):
                mean_norm = norms.mean().item()
                for tau in thresholds:
                    rows.append({'trajectory': trajectory, 'epoch': epoch, 'layer': l,
                                 'threshold': tau,
                                 'support_fraction': _support_fraction(norms, tau),
                                 'mean_norm': mean_norm, 'acc': acc})
            # Layer-averaged support at the main threshold, from per-sample norms.
            mean_support = np.mean([_support_fraction(norms, 1e-6) for norms in layer_norms])
            print(f" {trajectory} ep={epoch}: acc={acc:.4f}, "
                  f"s(1e-6)={mean_support:.4f}", flush=True)
    out_b1 = os.path.join(output_dir, 'B1_snapshot_evolution.csv')
    _write_csv(out_b1, ['trajectory', 'epoch', 'layer', 'threshold', 'support_fraction',
                        'mean_norm', 'acc'], rows)
    print(f"B1: {len(rows)} rows -> {out_b1}", flush=True)


def _run_b2(x_eval, y_eval, device, output_dir, methods, tau_main, L, d):
    """B2: compare active (norm > tau) and inactive samples per layer."""
    _banner("B2: Active subset characterization")
    batch = x_eval.size(0)
    rows = []
    for method in methods:
        ckpt = _A2_CKPT.format(method=method, seed=42)
        if not os.path.exists(ckpt):
            continue
        torch.manual_seed(42)
        model = _load_model(ckpt, d, L, device)
        bp, lo, lps = get_bp_grads_and_info(model, x_eval, y_eval, device)
        probs = lo.softmax(dim=-1)
        correct = lo.argmax(1) == y_eval
        # margin = prob(true class) - prob(2nd highest): positive when the
        # prediction is correct and confident, never positive otherwise.
        true_prob = probs[torch.arange(batch, device=y_eval.device), y_eval]
        margin = true_prob - probs.topk(2, dim=-1).values[:, 1]
        entropy = -(probs * (probs + 1e-10).log()).sum(-1)
        layer_norms = _layer_norms(bp, L)
        for l, norms in enumerate(layer_norms):
            active = norms > tau_main
            inactive = ~active
            n_active = active.sum().item()
            row = {'method': method, 'layer': l, 'n_active': n_active,
                   'n_inactive': inactive.sum().item(),
                   'pct_active': n_active / batch * 100}
            row.update(_subset_stats('active', active, correct,
                                     {'margin': margin, 'entropy': entropy,
                                      'loss': lps, 'grad_norm': norms}))
            row.update(_subset_stats('inactive', inactive, correct,
                                     {'margin': margin, 'entropy': entropy, 'loss': lps}))
            rows.append(row)
        # Print summary for this method: counts span all layers, while the
        # misclassification rates are taken at the middle layer.
        all_active = torch.stack(layer_norms).flatten() > tau_main
        mid_norms = layer_norms[L // 2]
        print(f" {method}: {all_active.sum()}/{len(all_active)} active, "
              f"active_miscls={1-correct[mid_norms>tau_main].float().mean():.3f} "
              f"inactive_miscls={1-correct[mid_norms<=tau_main].float().mean():.3f}", flush=True)
    out_b2 = os.path.join(output_dir, 'B2_active_subset.csv')
    fields = ['method', 'layer', 'n_active', 'n_inactive', 'pct_active',
              'active_miscls_rate', 'active_mean_margin', 'active_mean_entropy',
              'active_mean_loss', 'active_mean_grad_norm',
              'inactive_miscls_rate', 'inactive_mean_margin', 'inactive_mean_entropy',
              'inactive_mean_loss']
    _write_csv(out_b2, fields, rows)
    print(f"B2: {len(rows)} rows -> {out_b2}", flush=True)


def _run_c1(x_eval, y_eval, device, output_dir, methods, tau_main, L, d):
    """C1: cosine between the DFA credit and the BP gradient, by subset."""
    _banner("C1: Active vs inactive DFA credit cosine")
    batch = x_eval.size(0)
    rows = []
    for method in methods:
        ckpt = _A2_CKPT.format(method=method, seed=42)
        if not os.path.exists(ckpt):
            continue
        torch.manual_seed(42)
        model = _load_model(ckpt, d, L, device)
        # Regenerate DFA Bs: reseed and build a throwaway model first so the
        # RNG is advanced exactly as before the feedback matrices are drawn.
        torch.manual_seed(42)
        _ = ResidualMLP(3072, d, 10, L)
        dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
        bp, _, _ = get_bp_grads_and_info(model, x_eval, y_eval, device)
        with torch.no_grad():
            logits = model(x_eval)
        # Output error signal: softmax(logits) minus the one-hot target.
        e_T = logits.softmax(-1)
        e_T[torch.arange(batch, device=y_eval.device), y_eval] -= 1
        method_rows = []
        for l, norms in enumerate(_layer_norms(bp, L)):
            a_dfa = (e_T @ dfa_Bs[l].T).detach()
            active = norms > tau_main  # per-sample mask
            inactive = ~active
            cos_all = F.cosine_similarity(a_dfa, bp[l], dim=-1)
            n_active = active.sum().item()
            method_rows.append({
                'method': method, 'layer': l,
                'gamma_all': cos_all.mean().item(),
                'gamma_active': cos_all[active].mean().item() if n_active > 0 else float('nan'),
                'gamma_inactive': (cos_all[inactive].mean().item()
                                   if inactive.sum().item() > 0 else float('nan')),
                'n_active': n_active,
            })
        rows.extend(method_rows)
        print(f" {method}: gamma_active={np.nanmean([r['gamma_active'] for r in method_rows]):.4f} "
              f"gamma_inactive={np.nanmean([r['gamma_inactive'] for r in method_rows]):.4f}",
              flush=True)
    out_c1 = os.path.join(output_dir, 'C1_active_vs_inactive_cosine.csv')
    _write_csv(out_c1, ['method', 'layer', 'gamma_all', 'gamma_active', 'gamma_inactive',
                        'n_active'], rows)
    print(f"C1: {len(rows)} rows -> {out_c1}", flush=True)


def _run_c2(x_eval, y_eval, device, output_dir, methods, L, d):
    """C2: share of squared gradient norm held by the top-k% of samples."""
    _banner("C2: Gradient energy concentration")
    batch = x_eval.size(0)
    ks = [1, 5, 10, 25, 50, 75, 90, 95, 99]
    rows = []
    for method in methods:
        ckpt = _A2_CKPT.format(method=method, seed=42)
        if not os.path.exists(ckpt):
            continue
        torch.manual_seed(42)
        model = _load_model(ckpt, d, L, device)
        bp, _, _ = get_bp_grads_and_info(model, x_eval, y_eval, device)
        layer_energy = [norms ** 2 for norms in _layer_norms(bp, L)]
        for l, energy in enumerate(layer_energy):
            total_energy = energy.sum().item()
            sorted_energy, _ = energy.sort(descending=True)
            cumsum = sorted_energy.cumsum(0)
            for k in ks:
                n_top = max(1, int(batch * k / 100))
                # The tiny constant keeps an all-zero layer from dividing by zero.
                frac = cumsum[n_top - 1].item() / (total_energy + 1e-20)
                rows.append({'method': method, 'layer': l, 'top_k_pct': k,
                             'energy_fraction': frac})
        # Summary over every (layer, sample) pair.
        all_e = torch.stack(layer_energy).flatten()
        se, _ = all_e.sort(descending=True)
        cs = se.cumsum(0)
        te = all_e.sum()
        top1 = cs[max(1, int(len(all_e) * 0.01)) - 1].item() / (te.item() + 1e-20)
        top10 = cs[max(1, int(len(all_e) * 0.10)) - 1].item() / (te.item() + 1e-20)
        print(f" {method}: top1%={top1:.4f}, top10%={top10:.4f}", flush=True)
    out_c2 = os.path.join(output_dir, 'C2_energy_concentration.csv')
    _write_csv(out_c2, ['method', 'layer', 'top_k_pct', 'energy_fraction'], rows)
    print(f"C2: {len(rows)} rows -> {out_c2}", flush=True)


def run_analysis(args):
    """Run analyses A4, B1, B2, C1 and C2 and write one CSV per analysis.

    Args:
        args: Namespace with ``gpu`` (CUDA device index) and ``output_dir``
            (created if missing; receives the CSV files).

    Checkpoints that do not exist on disk are skipped.
    """
    device = torch.device(f'cuda:{args.gpu}')
    os.makedirs(args.output_dir, exist_ok=True)
    tel = get_cifar10_test()
    # Get a large eval batch: 4 loader batches, ~1024 samples.
    x_eval, y_eval = _collect_eval_batch(tel, device)
    batch = x_eval.size(0)
    print(f"Eval batch: {batch} samples", flush=True)

    L, d = 4, 256
    methods_a2 = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
    seeds = [42, 123, 456]  # Use 3 seeds for speed
    thresholds = [1e-8, 1e-7, 1e-6, 1e-5, 1e-4]
    tau_main = 1e-6

    _run_a4(x_eval, y_eval, device, args.output_dir, methods_a2, seeds, thresholds, L, d)
    _run_b1(x_eval, y_eval, device, args.output_dir, thresholds, L, d)
    _run_b2(x_eval, y_eval, device, args.output_dir, methods_a2, tau_main, L, d)
    _run_c1(x_eval, y_eval, device, args.output_dir, methods_a2, tau_main, L, d)
    _run_c2(x_eval, y_eval, device, args.output_dir, methods_a2, L, d)
    print("\nALL DONE", flush=True)


def main():
    """Parse command-line options and run the analysis."""
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--output_dir', type=str, default='results/confirmatory')
    args = p.parse_args()
    run_analysis(args)


if __name__ == '__main__':
    main()