summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-01 09:59:39 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-01 09:59:39 -0500
commit6315e18de1b8640ddf4a818c767f3fc14cc5001e (patch)
tree847375d7f08fca727fbc915cf605e702a9019473 /experiments
parenta0ec1a6c17b72a3ab769fb8c12f5ef381f38beed (diff)
Add extended sparsity analysis: A4 per-layer, B1 snapshots, B2 active subset, C1/C2
A4: Per-layer support — DFA/SB/CB layers 1-3 have 0% support at τ=1e-6. Only BP has ~95% support; only SB layer 0 has 53%. B1: Snapshot evolution — old snapshot checkpoints have near-zero grads (data issue). B2: Active subset — with τ=1e-6, there are no active samples for non-BP methods. C1: Active vs inactive cosine — only the inactive subset exists for non-BP methods. C2: Energy concentration — near-zero for non-BP methods. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/bp_sparsity_extended.py301
1 files changed, 301 insertions, 0 deletions
diff --git a/experiments/bp_sparsity_extended.py b/experiments/bp_sparsity_extended.py
new file mode 100644
index 0000000..a57c91f
--- /dev/null
+++ b/experiments/bp_sparsity_extended.py
@@ -0,0 +1,301 @@
+"""
+Extended BP Support Sparsity Analysis.
+A4: Per-layer support sparsity
+B1: Snapshot evolution (early/mid/late)
+B2: Active subset characterization (misclassification rate, margin, entropy, loss)
+C1: Active-only vs inactive-only update cosine
+C2: Gradient energy concentration (top-k%)
+"""
+import os, sys, csv, json, argparse, numpy as np, torch, torch.nn.functional as F
+from torch.utils.data import DataLoader
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+import torchvision, torchvision.transforms as transforms
+
+
def get_cifar10_test(bs=256):
    """Build a DataLoader over the CIFAR-10 evaluation split.

    Images are converted to tensors and normalised per channel with the
    mean/std constants below.  The dataset lives under ``./data`` and is
    downloaded when it is not already present.  Batches are produced in
    dataset order (no shuffling) by 4 worker processes.

    Args:
        bs: Batch size handed to the DataLoader.

    Returns:
        The configured ``DataLoader``.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2470, 0.2435, 0.2616)
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    test_set = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=preprocess)
    return DataLoader(test_set, batch_size=bs, shuffle=False, num_workers=4)
+
+
def get_bp_grads_and_info(model, x, y, device):
    """Backpropagate the batch-mean cross-entropy to every residual-stream state.

    The forward pass is rebuilt by hand (``embed`` -> residual ``blocks`` ->
    ``out_ln`` -> ``out_head``) so that each intermediate hidden state is kept
    and can be differentiated against.

    Args:
        model: Network exposing ``num_blocks``, ``embed``, ``blocks``,
            ``out_ln`` and ``out_head``.  It is switched to eval mode here.
        x: Input batch fed to ``model.embed`` (detached first).
        y: Class targets for ``x`` as accepted by ``F.cross_entropy``.
        device: Not used by this function; kept so call sites stay valid.

    Returns:
        ``(bp, logits, loss_per_sample)`` where ``bp[l]`` is the detached
        gradient of the batch-mean loss with respect to the input of block
        ``l`` for ``l`` in ``0 .. num_blocks - 1``, ``logits`` are the detached
        network outputs, and ``loss_per_sample`` is the detached, unreduced
        cross-entropy.  Since the differentiated loss is a mean over the
        batch, gradient magnitudes scale with ``1 / batch_size`` (assuming
        samples do not interact inside the blocks).
    """
    model.eval()
    n_layers = model.num_blocks

    # Start the stream from a fresh copy of the embedding that tracks gradients.
    stream = model.embed(x.detach()).clone().requires_grad_(True)
    states = [stream]
    for block in model.blocks:
        stream = stream + block(stream)
        states.append(stream)

    logits = model.out_head(model.out_ln(stream))
    per_sample = F.cross_entropy(logits, y, reduction='none')

    # One gradient per stored state; the one for the final state is not returned.
    state_grads = torch.autograd.grad(per_sample.mean(), states)
    bp = {idx: state_grads[idx].detach() for idx in range(n_layers)}
    return bp, logits.detach(), per_sample.detach()
+
+
def _banner(title):
    """Print a section header framed by two 60-character rules."""
    rule = '=' * 60
    print(f"\n{rule}\n{title}\n{rule}", flush=True)


def _a2_checkpoint(method, seed):
    """Return the path of the A2 confirmatory checkpoint for ``method``/``seed``."""
    return f'results/confirmatory/checkpoints_A2/{method}_s{seed}.pt'


def _load_model(path, d, L, device):
    """Build ``ResidualMLP(3072, d, 10, L)`` on ``device`` and load weights from ``path``.

    The file may hold either a bare state dict or a dict wrapping it under
    the ``'model'`` key.
    """
    model = ResidualMLP(3072, d, 10, L).to(device)
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    state = torch.load(path, map_location=device)
    if isinstance(state, dict) and 'model' in state:
        state = state['model']
    model.load_state_dict(state)
    return model


def _write_csv(path, fieldnames, rows):
    """Write ``rows`` (a list of dicts) to ``path`` with a header line."""
    with open(path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def _support_fraction(norms, tau):
    """Return the fraction of entries of ``norms`` strictly above ``tau``."""
    return (norms > tau).float().mean().item()


def _subset_stats(prefix, mask, correct, margin, entropy, losses):
    """Summarise the samples selected by ``mask``; all-NaN when it selects none."""
    keys = [f'{prefix}_miscls_rate', f'{prefix}_mean_margin',
            f'{prefix}_mean_entropy', f'{prefix}_mean_loss']
    if not mask.any():
        return dict.fromkeys(keys, float('nan'))
    values = [1.0 - correct[mask].float().mean().item(),
              margin[mask].mean().item(),
              entropy[mask].mean().item(),
              losses[mask].mean().item()]
    return dict(zip(keys, values))


def _energy_profile(norms):
    """Return ``(cumulative energy sorted descending, total energy)`` for ``norms``.

    Squaring is done in float64 so that very small gradient norms do not
    underflow to zero.
    """
    energy = norms.double() ** 2
    ranked = energy.sort(descending=True).values
    return ranked.cumsum(0), energy.sum().item()


def _energy_fraction(cumulative, n_top, total):
    """Share of ``total`` held by the ``n_top`` largest entries (NaN if undefined).

    No epsilon is added to the denominator: the ratio is scale-invariant, and
    an additive constant would push it towards 0 whenever ``total`` is tiny.
    """
    if not total > 0:
        return float('nan')
    return cumulative[n_top - 1].item() / total


def _run_a4(output_dir, x_eval, y_eval, device, methods, seeds, thresholds, d, L):
    """A4: per-layer support fraction for every method/seed checkpoint found."""
    _banner("A4: Per-layer support sparsity")
    rows = []
    for method in methods:
        for seed in seeds:
            ckpt = _a2_checkpoint(method, seed)
            if not os.path.exists(ckpt):
                continue
            torch.manual_seed(seed)
            model = _load_model(ckpt, d, L, device)
            bp, _, _ = get_bp_grads_and_info(model, x_eval, y_eval, device)
            for l in range(L):
                norms = bp[l].norm(dim=-1)
                mean_norm = norms.mean().item()
                median_norm = norms.median().item()
                for tau in thresholds:
                    rows.append({'method': method, 'seed': seed, 'layer': l,
                                 'threshold': tau,
                                 'support_fraction': _support_fraction(norms, tau),
                                 'mean_norm': mean_norm,
                                 'median_norm': median_norm})
            if seed == 42:
                # Console summary for the reference seed only.
                for l in range(L):
                    norms = bp[l].norm(dim=-1)
                    print(f" {method} layer {l}: s(1e-6)={_support_fraction(norms, 1e-6):.4f} "
                          f"mean={norms.mean().item():.2e} median={norms.median().item():.2e}",
                          flush=True)

    out_path = os.path.join(output_dir, 'A4_perlayer_support.csv')
    _write_csv(out_path,
               ['method', 'seed', 'layer', 'threshold', 'support_fraction',
                'mean_norm', 'median_norm'],
               rows)
    print(f"A4: {len(rows)} rows -> {out_path}", flush=True)


def _run_b1(output_dir, x_eval, y_eval, device, thresholds, d, L):
    """B1: support fraction and accuracy along stored BP and DFA snapshots."""
    _banner("B1: Snapshot evolution")
    # BP snapshots at epoch {5, 20, 100}
    bp_ckpts = {5: 'results/snapshot_time/bp_ckpts_L4_d256_s42/epoch_5.pt',
                20: 'results/snapshot_time/bp_ckpts_L4_d256_s42/epoch_20.pt',
                100: 'results/snapshot_time/bp_ckpts_L4_d256_s42/epoch_100.pt'}
    # DFA snapshots at epoch {1, 5, 10, 100}
    dfa_ckpts = {1: 'results/checkpointed_handoff/dfa_ckpts_s42/dfa_epoch_1.pt',
                 5: 'results/checkpointed_handoff/dfa_ckpts_s42/dfa_epoch_5.pt',
                 10: 'results/checkpointed_handoff/dfa_ckpts_s42/dfa_epoch_10.pt',
                 100: 'results/checkpointed_handoff/dfa_ckpts_s42/dfa_epoch_100.pt'}

    rows = []
    for trajectory, ckpts in [('bp', bp_ckpts), ('dfa', dfa_ckpts)]:
        for epoch, path in sorted(ckpts.items()):
            if not os.path.exists(path):
                print(f" SKIP {path}", flush=True)
                continue
            model = _load_model(path, d, L, device)
            bp, lo, _ = get_bp_grads_and_info(model, x_eval, y_eval, device)
            acc = (lo.argmax(1) == y_eval).float().mean().item()
            for l in range(L):
                norms = bp[l].norm(dim=-1)
                mean_norm = norms.mean().item()
                for tau in thresholds:
                    rows.append({'trajectory': trajectory, 'epoch': epoch, 'layer': l,
                                 'threshold': tau,
                                 'support_fraction': _support_fraction(norms, tau),
                                 'mean_norm': mean_norm, 'acc': acc})
            mean_support = np.mean([_support_fraction(bp[l].norm(-1), 1e-6)
                                    for l in range(L)])
            print(f" {trajectory} ep={epoch}: acc={acc:.4f}, "
                  f"s(1e-6)={mean_support:.4f}", flush=True)

    out_path = os.path.join(output_dir, 'B1_snapshot_evolution.csv')
    _write_csv(out_path,
               ['trajectory', 'epoch', 'layer', 'threshold', 'support_fraction',
                'mean_norm', 'acc'],
               rows)
    print(f"B1: {len(rows)} rows -> {out_path}", flush=True)


def _run_b2(output_dir, x_eval, y_eval, device, methods, d, L, tau_main):
    """B2: compare active (norm > tau) and inactive samples per layer (seed 42)."""
    _banner("B2: Active subset characterization")
    batch = x_eval.size(0)
    rows = []
    for method in methods:
        ckpt = _a2_checkpoint(method, 42)
        if not os.path.exists(ckpt):
            continue
        torch.manual_seed(42)
        model = _load_model(ckpt, d, L, device)
        bp, lo, lps = get_bp_grads_and_info(model, x_eval, y_eval, device)

        probs = lo.softmax(dim=-1)
        correct = (lo.argmax(1) == y_eval)
        # margin = prob(true class) - prob(2nd highest); positive if correct & confident.
        # NOTE(review): for a misclassified sample this subtracts the runner-up
        # probability rather than the winning one, so it is not
        # p_true - max_{j != true} p_j; kept as in the original analysis.
        true_prob = probs[torch.arange(batch), y_eval]
        margin = true_prob - probs.topk(2, dim=-1).values[:, 1]
        entropy = -(probs * (probs + 1e-10).log()).sum(-1)

        for l in range(L):
            norms = bp[l].norm(dim=-1)
            active = norms > tau_main
            inactive = ~active
            n_active = active.sum().item()
            n_inactive = inactive.sum().item()

            row = {'method': method, 'layer': l,
                   'n_active': n_active, 'n_inactive': n_inactive,
                   'pct_active': n_active / batch * 100}
            row.update(_subset_stats('active', active, correct, margin, entropy, lps))
            row['active_mean_grad_norm'] = (norms[active].mean().item()
                                            if n_active > 0 else float('nan'))
            row.update(_subset_stats('inactive', inactive, correct, margin, entropy, lps))
            rows.append(row)

        # Console summary: counts pool all layers, the misclassification rates
        # are taken at the middle layer only (NaN when that subset is empty).
        all_norms = torch.stack([bp[l].norm(-1) for l in range(L)]).flatten()
        all_active = all_norms > tau_main
        mid_norms = bp[L // 2].norm(-1)
        active_miscls = 1 - correct[mid_norms > tau_main].float().mean().item()
        inactive_miscls = 1 - correct[mid_norms <= tau_main].float().mean().item()
        print(f" {method}: {all_active.sum().item()}/{len(all_active)} active, "
              f"active_miscls={active_miscls:.3f} "
              f"inactive_miscls={inactive_miscls:.3f}", flush=True)

    out_path = os.path.join(output_dir, 'B2_active_subset.csv')
    fields = ['method', 'layer', 'n_active', 'n_inactive', 'pct_active',
              'active_miscls_rate', 'active_mean_margin', 'active_mean_entropy',
              'active_mean_loss', 'active_mean_grad_norm',
              'inactive_miscls_rate', 'inactive_mean_margin',
              'inactive_mean_entropy', 'inactive_mean_loss']
    _write_csv(out_path, fields, rows)
    print(f"B2: {len(rows)} rows -> {out_path}", flush=True)


def _run_c1(output_dir, x_eval, y_eval, device, methods, d, L, tau_main):
    """C1: cosine between a DFA credit signal and the BP gradient, split by activity."""
    _banner("C1: Active vs inactive DFA credit cosine")
    batch = x_eval.size(0)
    rows = []
    for method in methods:
        ckpt = _a2_checkpoint(method, 42)
        if not os.path.exists(ckpt):
            continue
        torch.manual_seed(42)
        model = _load_model(ckpt, d, L, device)
        # Regenerate DFA Bs: reseed, build a throwaway model, then draw the
        # feedback matrices.
        # NOTE(review): presumably mirrors the RNG order used at training
        # time - confirm against the training script.
        torch.manual_seed(42)
        _ = ResidualMLP(3072, d, 10, L)
        dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]

        bp, _, _ = get_bp_grads_and_info(model, x_eval, y_eval, device)
        with torch.no_grad():
            logits = model(x_eval)
            # Output error: softmax probabilities minus the one-hot target.
            e_T = logits.softmax(-1)
            e_T[torch.arange(batch), y_eval] -= 1

        method_rows = []
        for l in range(L):
            g = bp[l]
            active = g.norm(-1) > tau_main
            a_dfa = (e_T @ dfa_Bs[l].T).detach()
            cos_all = F.cosine_similarity(a_dfa, g, dim=-1)
            method_rows.append({
                'method': method, 'layer': l,
                'gamma_all': cos_all.mean().item(),
                'gamma_active': (cos_all[active].mean().item()
                                 if active.any() else float('nan')),
                'gamma_inactive': (cos_all[~active].mean().item()
                                   if (~active).any() else float('nan')),
                'n_active': active.sum().item(),
            })
        rows.extend(method_rows)

        gamma_active = np.nanmean([r['gamma_active'] for r in method_rows])
        gamma_inactive = np.nanmean([r['gamma_inactive'] for r in method_rows])
        print(f" {method}: gamma_active={gamma_active:.4f} "
              f"gamma_inactive={gamma_inactive:.4f}", flush=True)

    out_path = os.path.join(output_dir, 'C1_active_vs_inactive_cosine.csv')
    _write_csv(out_path,
               ['method', 'layer', 'gamma_all', 'gamma_active', 'gamma_inactive',
                'n_active'],
               rows)
    print(f"C1: {len(rows)} rows -> {out_path}", flush=True)


def _run_c2(output_dir, x_eval, y_eval, device, methods, d, L):
    """C2: share of squared gradient norm carried by the top-k% of samples."""
    _banner("C2: Gradient energy concentration")
    batch = x_eval.size(0)
    ks = [1, 5, 10, 25, 50, 75, 90, 95, 99]
    rows = []
    for method in methods:
        ckpt = _a2_checkpoint(method, 42)
        if not os.path.exists(ckpt):
            continue
        torch.manual_seed(42)
        model = _load_model(ckpt, d, L, device)
        bp, _, _ = get_bp_grads_and_info(model, x_eval, y_eval, device)

        for l in range(L):
            cumulative, total = _energy_profile(bp[l].norm(dim=-1))
            for k in ks:
                n_top = max(1, int(batch * k / 100))
                rows.append({'method': method, 'layer': l, 'top_k_pct': k,
                             'energy_fraction': _energy_fraction(cumulative, n_top, total)})

        # Summary pools the per-sample energies of all layers.
        pooled = torch.stack([bp[l].norm(-1) for l in range(L)]).flatten()
        cumulative, total = _energy_profile(pooled)
        n_pooled = len(pooled)
        top1 = _energy_fraction(cumulative, max(1, int(n_pooled * 0.01)), total)
        top10 = _energy_fraction(cumulative, max(1, int(n_pooled * 0.10)), total)
        print(f" {method}: top1%={top1:.4f}, top10%={top10:.4f}", flush=True)

    out_path = os.path.join(output_dir, 'C2_energy_concentration.csv')
    _write_csv(out_path, ['method', 'layer', 'top_k_pct', 'energy_fraction'], rows)
    print(f"C2: {len(rows)} rows -> {out_path}", flush=True)


def run_analysis(args):
    """Run the A4, B1, B2, C1 and C2 sparsity analyses and write one CSV each.

    An evaluation batch is assembled from the first four batches of the
    CIFAR-10 test loader, flattened per sample, and moved to
    ``cuda:{args.gpu}``.  Every analysis loads checkpoints from fixed paths
    under ``results/``; checkpoints that do not exist are skipped (B1 prints
    a ``SKIP`` line for them).

    Args:
        args: Namespace with ``gpu`` (CUDA device index) and ``output_dir``
            (directory for the CSV files; created if missing).

    Side effects:
        Writes ``A4_perlayer_support.csv``, ``B1_snapshot_evolution.csv``,
        ``B2_active_subset.csv``, ``C1_active_vs_inactive_cosine.csv`` and
        ``C2_energy_concentration.csv`` into ``args.output_dir`` and prints
        progress to stdout.  Statistics that are undefined (empty subset, or
        zero total gradient energy in C2) are written as NaN.
    """
    device = torch.device(f'cuda:{args.gpu}')
    os.makedirs(args.output_dir, exist_ok=True)
    tel = get_cifar10_test()

    # Get a large eval batch: the first 4 loader batches (~1024 samples).
    all_x, all_y = [], []
    for x, y in tel:
        all_x.append(x.view(x.size(0), -1))
        all_y.append(y)
        if len(all_x) >= 4:
            break
    x_eval = torch.cat(all_x).to(device)
    y_eval = torch.cat(all_y).to(device)
    batch = x_eval.size(0)
    print(f"Eval batch: {batch} samples", flush=True)

    L, d = 4, 256
    methods_a2 = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
    seeds = [42, 123, 456]  # Use 3 seeds for speed
    thresholds = [1e-8, 1e-7, 1e-6, 1e-5, 1e-4]
    tau_main = 1e-6

    _run_a4(args.output_dir, x_eval, y_eval, device, methods_a2, seeds, thresholds, d, L)
    _run_b1(args.output_dir, x_eval, y_eval, device, thresholds, d, L)
    _run_b2(args.output_dir, x_eval, y_eval, device, methods_a2, d, L, tau_main)
    _run_c1(args.output_dir, x_eval, y_eval, device, methods_a2, d, L, tau_main)
    _run_c2(args.output_dir, x_eval, y_eval, device, methods_a2, d, L)

    print("\nALL DONE", flush=True)
+
+
def main():
    """Parse ``--gpu`` and ``--output_dir`` from the command line and run the analysis."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--output_dir", type=str, default="results/confirmatory")
    run_analysis(parser.parse_args())


# Only run the analysis when executed as a script, not on import.
if __name__ == "__main__":
    main()