diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-01 09:59:39 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-01 09:59:39 -0500 |
| commit | 6315e18de1b8640ddf4a818c767f3fc14cc5001e (patch) | |
| tree | 847375d7f08fca727fbc915cf605e702a9019473 /experiments | |
| parent | a0ec1a6c17b72a3ab769fb8c12f5ef381f38beed (diff) | |
Add extended sparsity analysis: A4 per-layer, B1 snapshots, B2 active subset, C1/C2
A4: Per-layer support — DFA/SB/CB layers 1-3 have 0% support at τ=1e-6
Only BP has ~95% support; only SB layer 0 has 53%
B1: Snapshot evolution — old snapshot checkpoints have near-zero grads (data issue)
B2: Active subset — with τ=1e-6, no active samples for non-BP methods
C1: Active vs inactive cosine — only inactive subset exists for non-BP
C2: Energy concentration — near-zero for non-BP methods
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/bp_sparsity_extended.py | 301 |
1 files changed, 301 insertions, 0 deletions
# experiments/bp_sparsity_extended.py
"""
Extended BP Support Sparsity Analysis.

A4: Per-layer support sparsity
B1: Snapshot evolution (early/mid/late)
B2: Active subset characterization (misclassification rate, margin, entropy, loss)
C1: Active-only vs inactive-only update cosine
C2: Gradient energy concentration (top-k%)

"Support" of a layer at threshold tau is the fraction of evaluation samples
whose per-sample BP gradient norm (w.r.t. that layer's input hidden state)
exceeds tau.
"""
import argparse
import csv
import json
import os
import sys

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Make the repository root importable so `models` resolves when this file is
# run directly from the experiments/ directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP  # noqa: E402


# Architecture of every checkpoint analysed here (ResidualMLP(3072, 256, 10, 4)).
INPUT_DIM = 3072
HIDDEN_DIM = 256
NUM_CLASSES = 10
NUM_BLOCKS = 4

METHODS_A2 = ('bp', 'dfa', 'state_bridge', 'credit_bridge')
SEEDS = (42, 123, 456)  # Use 3 seeds for speed
THRESHOLDS = (1e-8, 1e-7, 1e-6, 1e-5, 1e-4)
TAU_MAIN = 1e-6  # threshold used to split samples into active / inactive
ENERGY_TOP_K_PCTS = (1, 5, 10, 25, 50, 75, 90, 95, 99)

A2_CKPT_TEMPLATE = 'results/confirmatory/checkpoints_A2/{method}_s{seed}.pt'

# BP snapshots at epoch {5, 20, 100}
BP_SNAPSHOTS = {
    5: 'results/snapshot_time/bp_ckpts_L4_d256_s42/epoch_5.pt',
    20: 'results/snapshot_time/bp_ckpts_L4_d256_s42/epoch_20.pt',
    100: 'results/snapshot_time/bp_ckpts_L4_d256_s42/epoch_100.pt',
}
# DFA snapshots at epoch {1, 5, 10, 100}
DFA_SNAPSHOTS = {
    1: 'results/checkpointed_handoff/dfa_ckpts_s42/dfa_epoch_1.pt',
    5: 'results/checkpointed_handoff/dfa_ckpts_s42/dfa_epoch_5.pt',
    10: 'results/checkpointed_handoff/dfa_ckpts_s42/dfa_epoch_10.pt',
    100: 'results/checkpointed_handoff/dfa_ckpts_s42/dfa_epoch_100.pt',
}


def get_cifar10_test(bs=256):
    """Return an unshuffled DataLoader over the CIFAR-10 test split.

    Images are converted to tensors and normalized per channel. The dataset
    is stored under ``./data`` and downloaded if it is not already there.

    Args:
        bs: Batch size of the returned loader.
    """
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    ds = torchvision.datasets.CIFAR10('./data', train=False, download=True, transform=tv)
    return DataLoader(ds, batch_size=bs, shuffle=False, num_workers=4)


def get_bp_grads_and_info(model, x, y, device):
    """Get per-layer BP gradients + logits/loss info.

    Runs the model block by block (embed -> residual blocks -> out_ln ->
    out_head), keeping every intermediate hidden state so that the gradient
    of the loss with respect to each of them can be taken.

    Args:
        model: Network exposing ``embed``, ``blocks``, ``out_ln``, ``out_head``
            and ``num_blocks``. It is switched to eval mode as a side effect.
        x: Input batch, passed to ``model.embed``.
        y: Integer class targets for ``x``.
        device: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(bp, logits, loss_per_sample)`` where ``bp[l]`` is the detached
        gradient of the batch-mean cross-entropy w.r.t. the input hidden state
        of block ``l`` (``l`` in ``0..num_blocks-1``), ``logits`` are the
        detached model outputs and ``loss_per_sample`` the detached
        unreduced cross-entropy.

    Note:
        The differentiated loss is the *mean* over the batch, so every
        per-sample gradient row carries a 1/batch factor. Norm thresholds
        applied to ``bp`` therefore depend on the evaluation batch size.
    """
    model.eval()
    num_layers = model.num_blocks
    h0 = model.embed(x.detach())
    hidden = [h0.clone().requires_grad_(True)]
    for block in model.blocks:
        hidden.append(hidden[-1] + block(hidden[-1]))
    logits = model.out_head(model.out_ln(hidden[-1]))
    loss_per_sample = F.cross_entropy(logits, y, reduction='none')
    loss = loss_per_sample.mean()
    grads = torch.autograd.grad(loss, hidden)
    # Only the block inputs are kept; the gradient at the final hidden state
    # (index num_layers) is discarded.
    bp = {l: grads[l].detach() for l in range(num_layers)}
    return bp, logits.detach(), loss_per_sample.detach()


def _banner(title):
    """Print a section header framed by two rules of 60 '=' characters."""
    print(f"\n{'='*60}\n{title}\n{'='*60}", flush=True)


def _write_csv(path, fieldnames, rows):
    """Write ``rows`` (a list of dicts) to ``path`` with a header line."""
    with open(path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def _ckpt_available(path):
    """Return True if ``path`` exists; otherwise report the skip and return False."""
    if os.path.exists(path):
        return True
    print(f" SKIP {path}", flush=True)
    return False


def _load_model(path, device, seed=None):
    """Build a ResidualMLP on ``device`` and load the weights stored at ``path``.

    Accepts both a bare state dict and a dict wrapping it under the ``'model'``
    key. When ``seed`` is given, the global torch RNG is seeded before the
    model is constructed.
    """
    if seed is not None:
        torch.manual_seed(seed)
    model = ResidualMLP(INPUT_DIM, HIDDEN_DIM, NUM_CLASSES, NUM_BLOCKS).to(device)
    state = torch.load(path, map_location=device)
    if isinstance(state, dict) and 'model' in state:
        state = state['model']
    model.load_state_dict(state)
    return model


def _collect_eval_batch(device, num_batches=4):
    """Concatenate the first ``num_batches`` test batches, flattened per sample."""
    xs, ys = [], []
    for x, y in get_cifar10_test():
        xs.append(x.view(x.size(0), -1))
        ys.append(y)
        if len(xs) >= num_batches:  # ~1024 samples with the default batch size
            break
    return torch.cat(xs).to(device), torch.cat(ys).to(device)


def _run_a4(x_eval, y_eval, device, output_dir):
    """A4: per-layer support fraction for every method / seed / threshold."""
    _banner('A4: Per-layer support sparsity')
    rows = []
    for method in METHODS_A2:
        for seed in SEEDS:
            ckpt = A2_CKPT_TEMPLATE.format(method=method, seed=seed)
            if not _ckpt_available(ckpt):
                continue
            model = _load_model(ckpt, device, seed=seed)
            bp, _, _ = get_bp_grads_and_info(model, x_eval, y_eval, device)
            for layer in range(NUM_BLOCKS):
                norms = bp[layer].norm(dim=-1)
                # Threshold-independent statistics, computed once per layer.
                mean_norm = norms.mean().item()
                median_norm = norms.median().item()
                for tau in THRESHOLDS:
                    rows.append({
                        'method': method, 'seed': seed, 'layer': layer,
                        'threshold': tau,
                        'support_fraction': (norms > tau).float().mean().item(),
                        'mean_norm': mean_norm,
                        'median_norm': median_norm,
                    })
                if seed == 42:
                    print(f" {method} layer {layer}: "
                          f"s(1e-6)={(norms > 1e-6).float().mean():.4f} "
                          f"mean={mean_norm:.2e} median={median_norm:.2e}", flush=True)

    out_path = os.path.join(output_dir, 'A4_perlayer_support.csv')
    _write_csv(out_path,
               ['method', 'seed', 'layer', 'threshold', 'support_fraction',
                'mean_norm', 'median_norm'],
               rows)
    print(f"A4: {len(rows)} rows -> {out_path}", flush=True)


def _run_b1(x_eval, y_eval, device, output_dir):
    """B1: support fraction and accuracy along the BP and DFA training snapshots."""
    _banner('B1: Snapshot evolution')
    rows = []
    for trajectory, ckpts in [('bp', BP_SNAPSHOTS), ('dfa', DFA_SNAPSHOTS)]:
        for epoch, path in sorted(ckpts.items()):
            if not _ckpt_available(path):
                continue
            model = _load_model(path, device)
            bp, logits, _ = get_bp_grads_and_info(model, x_eval, y_eval, device)
            acc = (logits.argmax(1) == y_eval).float().mean().item()
            support_1e6 = []
            for layer in range(NUM_BLOCKS):
                norms = bp[layer].norm(dim=-1)
                mean_norm = norms.mean().item()
                for tau in THRESHOLDS:
                    rows.append({
                        'trajectory': trajectory, 'epoch': epoch, 'layer': layer,
                        'threshold': tau,
                        'support_fraction': (norms > tau).float().mean().item(),
                        'mean_norm': mean_norm, 'acc': acc,
                    })
                support_1e6.append((norms > 1e-6).float().mean().item())
            print(f" {trajectory} ep={epoch}: acc={acc:.4f}, "
                  f"s(1e-6)={np.mean(support_1e6):.4f}", flush=True)

    out_path = os.path.join(output_dir, 'B1_snapshot_evolution.csv')
    _write_csv(out_path,
               ['trajectory', 'epoch', 'layer', 'threshold', 'support_fraction',
                'mean_norm', 'acc'],
               rows)
    print(f"B1: {len(rows)} rows -> {out_path}", flush=True)


def _subset_stats(prefix, mask, correct, metrics):
    """Summarise one sample subset for a B2 row.

    Returns ``{prefix}_miscls_rate`` plus ``{prefix}_mean_{name}`` for every
    tensor in ``metrics``, each averaged over the samples selected by ``mask``.
    All values are NaN when the subset is empty.
    """
    if mask.sum().item() > 0:
        stats = {f'{prefix}_miscls_rate': 1.0 - correct[mask].float().mean().item()}
        for name, values in metrics.items():
            stats[f'{prefix}_mean_{name}'] = values[mask].mean().item()
        return stats
    stats = {f'{prefix}_miscls_rate': float('nan')}
    for name in metrics:
        stats[f'{prefix}_mean_{name}'] = float('nan')
    return stats


def _run_b2(x_eval, y_eval, device, output_dir):
    """B2: compare active (norm > TAU_MAIN) and inactive samples per layer."""
    _banner('B2: Active subset characterization')
    batch = x_eval.size(0)
    sample_idx = torch.arange(batch, device=y_eval.device)
    rows = []
    for method in METHODS_A2:
        ckpt = A2_CKPT_TEMPLATE.format(method=method, seed=42)
        if not _ckpt_available(ckpt):
            continue
        model = _load_model(ckpt, device, seed=42)
        bp, logits, loss_per_sample = get_bp_grads_and_info(model, x_eval, y_eval, device)

        probs = logits.softmax(dim=-1)
        correct = logits.argmax(1) == y_eval
        # margin = prob(true class) - prob(2nd highest): positive if correct & confident.
        # NOTE(review): for misclassified samples this still subtracts the
        # 2nd-highest probability rather than the highest (predicted) one, so
        # it is not the conventional p_true - max_{j != true} p_j margin there.
        # Left as defined so the CSV stays comparable with earlier runs.
        margin = probs[sample_idx, y_eval] - probs.topk(2, dim=-1).values[:, 1]
        entropy = -(probs * (probs + 1e-10).log()).sum(-1)

        layer_norms = [bp[layer].norm(dim=-1) for layer in range(NUM_BLOCKS)]
        for layer, norms in enumerate(layer_norms):
            active = norms > TAU_MAIN
            inactive = ~active
            n_active = active.sum().item()
            row = {'method': method, 'layer': layer,
                   'n_active': n_active,
                   'n_inactive': inactive.sum().item(),
                   'pct_active': n_active / batch * 100}
            row.update(_subset_stats('active', active, correct, {
                'margin': margin, 'entropy': entropy,
                'loss': loss_per_sample, 'grad_norm': norms,
            }))
            # Gradient norm is only reported for the active subset.
            row.update(_subset_stats('inactive', inactive, correct, {
                'margin': margin, 'entropy': entropy, 'loss': loss_per_sample,
            }))
            rows.append(row)

        # Summary for this method. The active count pools all layers, while the
        # misclassification rates are taken on the middle layer only; an empty
        # subset prints as nan.
        all_active = torch.stack(layer_norms).flatten() > TAU_MAIN
        mid_norms = layer_norms[NUM_BLOCKS // 2]
        print(f" {method}: {int(all_active.sum().item())}/{len(all_active)} active, "
              f"active_miscls={1 - correct[mid_norms > TAU_MAIN].float().mean():.3f} "
              f"inactive_miscls={1 - correct[mid_norms <= TAU_MAIN].float().mean():.3f}",
              flush=True)

    out_path = os.path.join(output_dir, 'B2_active_subset.csv')
    _write_csv(out_path,
               ['method', 'layer', 'n_active', 'n_inactive', 'pct_active',
                'active_miscls_rate', 'active_mean_margin', 'active_mean_entropy',
                'active_mean_loss', 'active_mean_grad_norm',
                'inactive_miscls_rate', 'inactive_mean_margin',
                'inactive_mean_entropy', 'inactive_mean_loss'],
               rows)
    print(f"B2: {len(rows)} rows -> {out_path}", flush=True)


def _run_c1(x_eval, y_eval, device, output_dir):
    """C1: cosine between the DFA credit signal and the BP gradient, split by activity."""
    _banner('C1: Active vs inactive DFA credit cosine')
    batch = x_eval.size(0)
    sample_idx = torch.arange(batch, device=y_eval.device)
    rows = []
    for method in METHODS_A2:
        ckpt = A2_CKPT_TEMPLATE.format(method=method, seed=42)
        if not _ckpt_available(ckpt):
            continue
        model = _load_model(ckpt, device, seed=42)

        # Regenerate DFA Bs: reseed, build a throwaway model, then draw one
        # random feedback matrix per layer.
        # NOTE(review): the matrices are drawn on `device`; on CUDA that uses
        # the CUDA generator, which the CPU-side throwaway model presumably
        # does not advance. Confirm these match the Bs used during training.
        torch.manual_seed(42)
        _ = ResidualMLP(INPUT_DIM, HIDDEN_DIM, NUM_CLASSES, NUM_BLOCKS)
        dfa_Bs = [torch.randn(HIDDEN_DIM, NUM_CLASSES, device=device) / np.sqrt(10)
                  for _ in range(NUM_BLOCKS)]

        bp, _, _ = get_bp_grads_and_info(model, x_eval, y_eval, device)
        with torch.no_grad():
            # Output error: softmax(logits) - one_hot(y).
            e_T = model(x_eval).softmax(-1)
            e_T[sample_idx, y_eval] -= 1

        method_rows = []
        for layer in range(NUM_BLOCKS):
            grad = bp[layer]
            active = grad.norm(-1) > TAU_MAIN
            a_dfa = (e_T @ dfa_Bs[layer].T).detach()
            cos_all = F.cosine_similarity(a_dfa, grad, dim=-1)
            n_active = active.sum().item()
            n_inactive = (~active).sum().item()
            method_rows.append({
                'method': method, 'layer': layer,
                'gamma_all': cos_all.mean().item(),
                'gamma_active': cos_all[active].mean().item() if n_active > 0 else float('nan'),
                'gamma_inactive': cos_all[~active].mean().item() if n_inactive > 0 else float('nan'),
                'n_active': n_active,
            })
        rows.extend(method_rows)

        print(f" {method}: gamma_active={np.nanmean([r['gamma_active'] for r in method_rows]):.4f} "
              f"gamma_inactive={np.nanmean([r['gamma_inactive'] for r in method_rows]):.4f}",
              flush=True)

    out_path = os.path.join(output_dir, 'C1_active_vs_inactive_cosine.csv')
    _write_csv(out_path,
               ['method', 'layer', 'gamma_all', 'gamma_active', 'gamma_inactive', 'n_active'],
               rows)
    print(f"C1: {len(rows)} rows -> {out_path}", flush=True)


def _run_c2(x_eval, y_eval, device, output_dir):
    """C2: share of squared gradient norm held by the top-k% of samples."""
    _banner('C2: Gradient energy concentration')
    batch = x_eval.size(0)
    rows = []
    for method in METHODS_A2:
        ckpt = A2_CKPT_TEMPLATE.format(method=method, seed=42)
        if not _ckpt_available(ckpt):
            continue
        model = _load_model(ckpt, device, seed=42)
        bp, _, _ = get_bp_grads_and_info(model, x_eval, y_eval, device)

        layer_energy = [bp[layer].norm(dim=-1) ** 2 for layer in range(NUM_BLOCKS)]
        for layer, energy in enumerate(layer_energy):
            total_energy = energy.sum().item()
            cumsum = energy.sort(descending=True).values.cumsum(0)
            for k in ENERGY_TOP_K_PCTS:
                n_top = max(1, int(batch * k / 100))
                # 1e-20 guards against division by zero when all gradients vanish.
                frac = cumsum[n_top - 1].item() / (total_energy + 1e-20)
                rows.append({'method': method, 'layer': layer,
                             'top_k_pct': k, 'energy_fraction': frac})

        # Summary: pool all layers and report the top-1% / top-10% energy share.
        pooled = torch.stack(layer_energy).flatten()
        pooled_cumsum = pooled.sort(descending=True).values.cumsum(0)
        pooled_total = pooled.sum().item() + 1e-20
        top1 = pooled_cumsum[max(1, int(len(pooled) * 0.01)) - 1].item() / pooled_total
        top10 = pooled_cumsum[max(1, int(len(pooled) * 0.10)) - 1].item() / pooled_total
        print(f" {method}: top1%={top1:.4f}, top10%={top10:.4f}", flush=True)

    out_path = os.path.join(output_dir, 'C2_energy_concentration.csv')
    _write_csv(out_path, ['method', 'layer', 'top_k_pct', 'energy_fraction'], rows)
    print(f"C2: {len(rows)} rows -> {out_path}", flush=True)


def run_analysis(args):
    """Run analyses A4, B1, B2, C1 and C2 and write one CSV per analysis.

    Args:
        args: Namespace with ``gpu`` (CUDA device index) and ``output_dir``
            (directory for the CSV files; created if missing).

    Checkpoints that do not exist on disk are reported with a ``SKIP`` line
    and left out of the results.
    """
    device = torch.device(f'cuda:{args.gpu}')
    os.makedirs(args.output_dir, exist_ok=True)

    # Get a large eval batch shared by every analysis.
    x_eval, y_eval = _collect_eval_batch(device)
    print(f"Eval batch: {x_eval.size(0)} samples", flush=True)

    _run_a4(x_eval, y_eval, device, args.output_dir)
    _run_b1(x_eval, y_eval, device, args.output_dir)
    _run_b2(x_eval, y_eval, device, args.output_dir)
    _run_c1(x_eval, y_eval, device, args.output_dir)
    _run_c2(x_eval, y_eval, device, args.output_dir)

    print("\nALL DONE", flush=True)


def main():
    """Parse ``--gpu`` and ``--output_dir`` from the command line and run the analysis."""
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--output_dir', type=str, default='results/confirmatory')
    args = p.parse_args()
    run_analysis(args)


if __name__ == '__main__':
    main()
