diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-01 10:54:40 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-01 10:54:40 -0500 |
| commit | ff91444879a7035c88b2c3c48859f36fb560c660 (patch) | |
| tree | bbbf458d07a4c17c633fabb3a89c7fab35a20d10 /experiments | |
| parent | 6315e18de1b8640ddf4a818c767f3fc14cc5001e (diff) | |
Add confirmatory supplement: T1-T4 from checkpoints (no retraining)
WARNING: All methods (including BP) show near-zero BP hidden gradients (~1e-12-1e-14)
when computed via manual forward with detached hidden states. This is inconsistent with
the earlier first-priority analysis which showed BP at 2.86e-04. Investigation needed.
T1: 40 rows (4 methods × 10 seeds) - full metrics
T2: 800 rows (support sparsity, 5 thresholds × 4 methods × 10 seeds × 4 layers)
T3: 48 rows (gradient norm distributions, 3 seeds × 4 methods × 4 layers)
T4: 100 rows (active-subset Gamma, 5 thresholds × 2 methods × 10 seeds)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/confirmatory_supplement.py | 273 |
1 files changed, 273 insertions, 0 deletions
diff --git a/experiments/confirmatory_supplement.py b/experiments/confirmatory_supplement.py new file mode 100644 index 0000000..89cd9cc --- /dev/null +++ b/experiments/confirmatory_supplement.py @@ -0,0 +1,273 @@ +""" +Confirmatory supplement: all from existing checkpoints, no retraining. +Task 1: CIFAR full metrics (Gamma_raw, Gamma_filtered, rho, acc, naive_StateErr, StateErr) +Task 2: Support sparsity (5 thresholds × 4 methods × 10 seeds × 4 layers) +Task 3: Per-layer gradient norm distribution (percentiles) +Task 4: Active-subset Gamma for BP and DFA +""" +import os, sys, csv, json, argparse, numpy as np, torch, torch.nn.functional as F +from torch.utils.data import DataLoader +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.residual_mlp import ResidualMLP +from metrics.credit_metrics import perturbation_correlation +import torchvision, torchvision.transforms as transforms + + +def get_test_batch(device, n_batches=4): + tv = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))]) + tel = DataLoader(torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv), + 256, False, num_workers=4) + xs, ys = [], [] + for x, y in tel: + xs.append(x.view(x.size(0), -1)); ys.append(y) + if len(xs) >= n_batches: break + return torch.cat(xs).to(device), torch.cat(ys).to(device) + + +def get_bp_grads(model, x, y, device): + model.eval(); L = model.num_blocks + h0 = model.embed(x.detach()) + hs = [h0.clone().requires_grad_(True)] + for b in model.blocks: hs.append(hs[-1] + b(hs[-1])) + lo = model.out_head(model.out_ln(hs[-1])) + loss = F.cross_entropy(lo, y) + gs = torch.autograd.grad(loss, hs) + return {l: gs[l].detach() for l in range(L)}, lo.detach(), F.cross_entropy(lo, y, reduction='none').detach() + + +def get_dfa_Bs(seed, d, C, L, device): + """Regenerate DFA Bs with exact same seed sequence as training.""" + torch.manual_seed(seed) + _ = ResidualMLP(3072, d, C, L) # consume same random state as model init + return [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] + + +def run(args): + device = torch.device(f'cuda:{args.gpu}') + os.makedirs(args.output_dir, exist_ok=True) + x_eval, y_eval = get_test_batch(device) + batch = x_eval.size(0) + print(f"Eval: {batch} samples", flush=True) + + L, d, C = 4, 256, 10 + seeds = [42, 123, 456, 789, 1024, 2048, 3000, 4000, 5000, 6000] + methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge'] + thresholds = [1e-8, 1e-7, 1e-6, 1e-5, 1e-4] + + # ===== Task 1: Full metrics ===== + print(f"\n{'='*60}\nTask 1: Full metrics\n{'='*60}", flush=True) + t1_rows = [] + for method in methods: + for seed in seeds: + ckpt = f'results/confirmatory/checkpoints_A2/{method}_s{seed}.pt' + if not os.path.exists(ckpt): + print(f" SKIP {ckpt}", flush=True); continue + torch.manual_seed(seed) + model = ResidualMLP(3072, d, C, L).to(device) + model.load_state_dict(torch.load(ckpt, map_location=device)) + bp, lo, lps = get_bp_grads(model, x_eval, y_eval, device) + + # Accuracy + acc = (lo.argmax(1) == y_eval).float().mean().item() + + # Naive StateErr + with torch.no_grad(): + _, hi = model(x_eval, return_hidden=True) + h_mid = hi[L//2]; h_L = hi[-1] + nse = ((h_mid - h_L).norm(-1) / h_L.norm(-1).clamp(min=1e-8)).mean().item() + + # DFA Bs for Gamma computation + dfa_Bs = get_dfa_Bs(seed, d, C, L, device) + + # e_T + with torch.no_grad(): + logits = model(x_eval) + e_T = logits.softmax(-1); e_T[torch.arange(batch), y_eval] -= 1 + + # Per-layer metrics + gamma_raw_list, gamma_filt_list, rho_list = [], [], [] + for l in range(L): + g = bp[l]; norms = g.norm(-1); mask = norms > 1e-6 + a_dfa = (e_T @ dfa_Bs[l].T).detach() + h_l = hi[l].detach() + + # Gamma raw & filtered (DFA vs BP) + if method == 'bp': + gamma_raw_list.append(1.0) + gamma_filt_list.append(1.0) + else: + cos = F.cosine_similarity(a_dfa, g, dim=-1) + gamma_raw_list.append(cos.mean().item()) + gamma_filt_list.append(cos[mask].mean().item() if mask.sum() > 0 else float('nan')) + + # Rho (perturbation correlation) + # Use method-appropriate credit for rho + if method == 'bp': + a_l = g + else: + a_l = a_dfa # Use DFA credit for all non-BP (closest available) + + def make_fwd(sl): + def f(h): + with torch.no_grad(): + c = h + for i in range(sl, L): c = c + model.blocks[i](c) + return F.cross_entropy(model.out_head(model.out_ln(c)), y_eval, reduction='none') + return f + rho = perturbation_correlation(h_l, a_l, make_fwd(l), epsilon=1e-3, M=16) + rho_list.append(rho) + + row = { + 'method': method, 'seed': seed, 'acc': acc, + 'naive_StateErr': nse, + 'Gamma_raw': np.mean(gamma_raw_list), + 'Gamma_filtered': np.nanmean(gamma_filt_list), + 'rho': np.mean(rho_list), + 'mean_bp_grad_norm': np.mean([bp[l].norm(-1).mean().item() for l in range(L)]), + } + t1_rows.append(row) + if seed in [42, 123]: + print(f" {method} s={seed}: acc={acc:.4f} Gr={row['Gamma_raw']:.4f} " + f"Gf={row['Gamma_filtered']:.4f} rho={row['rho']:.4f} nse={nse:.4f}", flush=True) + + out1 = os.path.join(args.output_dir, 'T1_cifar_full_metrics.csv') + with open(out1, 'w', newline='') as f: + w = csv.DictWriter(f, fieldnames=['method','seed','acc','naive_StateErr','Gamma_raw','Gamma_filtered','rho','mean_bp_grad_norm']) + w.writeheader(); w.writerows(t1_rows) + print(f"Task 1: {len(t1_rows)} rows -> {out1}", flush=True) + + # ===== Task 2: Support sparsity ===== + print(f"\n{'='*60}\nTask 2: Support sparsity\n{'='*60}", flush=True) + t2_rows = [] + for method in methods: + for seed in seeds: + ckpt = f'results/confirmatory/checkpoints_A2/{method}_s{seed}.pt' + if not os.path.exists(ckpt): continue + torch.manual_seed(seed) + model = ResidualMLP(3072, d, C, L).to(device) + model.load_state_dict(torch.load(ckpt, map_location=device)) + bp, _, _ = get_bp_grads(model, x_eval, y_eval, device) + for l in range(L): + norms = bp[l].norm(-1) + for tau in thresholds: + t2_rows.append({ + 'method': method, 'seed': seed, 'layer': l, + 'threshold': tau, 'support_fraction': (norms > tau).float().mean().item(), + 'mean_norm': norms.mean().item(), 'median_norm': norms.median().item() + }) + print(f" {method}: done ({len(seeds)} seeds)", flush=True) + + out2 = os.path.join(args.output_dir, 'T2_support_sparsity.csv') + with open(out2, 'w', newline='') as f: + w = csv.DictWriter(f, fieldnames=['method','seed','layer','threshold','support_fraction','mean_norm','median_norm']) + w.writeheader(); w.writerows(t2_rows) + print(f"Task 2: {len(t2_rows)} rows -> {out2}", flush=True) + + # ===== Task 3: Gradient norm distribution ===== + print(f"\n{'='*60}\nTask 3: Gradient norm distribution\n{'='*60}", flush=True) + t3_rows = [] + percentiles = [1, 5, 10, 25, 50, 75, 90, 95, 99] + for method in methods: + for seed in seeds[:3]: # 3 seeds for distributions + ckpt = f'results/confirmatory/checkpoints_A2/{method}_s{seed}.pt' + if not os.path.exists(ckpt): continue + torch.manual_seed(seed) + model = ResidualMLP(3072, d, C, L).to(device) + model.load_state_dict(torch.load(ckpt, map_location=device)) + bp, _, _ = get_bp_grads(model, x_eval, y_eval, device) + for l in range(L): + norms = bp[l].norm(-1).cpu().numpy() + log_norms = np.log10(norms.clip(min=1e-20)) + row = {'method': method, 'seed': seed, 'layer': l, + 'mean': float(norms.mean()), 'std': float(norms.std()), + 'mean_log10': float(log_norms.mean()), 'std_log10': float(log_norms.std())} + for p in percentiles: + row[f'p{p}'] = float(np.percentile(norms, p)) + row[f'p{p}_log10'] = float(np.percentile(log_norms, p)) + t3_rows.append(row) + print(f" {method}: done", flush=True) + + out3 = os.path.join(args.output_dir, 'T3_grad_norm_distribution.csv') + fields3 = ['method','seed','layer','mean','std','mean_log10','std_log10'] + \ + [f'p{p}' for p in percentiles] + [f'p{p}_log10' for p in percentiles] + with open(out3, 'w', newline='') as f: + w = csv.DictWriter(f, fieldnames=fields3); w.writeheader(); w.writerows(t3_rows) + print(f"Task 3: {len(t3_rows)} rows -> {out3}", flush=True) + + # ===== Task 4: Active-subset Gamma ===== + print(f"\n{'='*60}\nTask 4: Active-subset Gamma\n{'='*60}", flush=True) + t4_rows = [] + for tau in thresholds: + for method in ['bp', 'dfa']: + for seed in seeds: + ckpt = f'results/confirmatory/checkpoints_A2/{method}_s{seed}.pt' + if not os.path.exists(ckpt): continue + torch.manual_seed(seed) + model = ResidualMLP(3072, d, C, L).to(device) + model.load_state_dict(torch.load(ckpt, map_location=device)) + bp, lo, _ = get_bp_grads(model, x_eval, y_eval, device) + dfa_Bs = get_dfa_Bs(seed, d, C, L, device) + with torch.no_grad(): + logits = model(x_eval) + e_T = logits.softmax(-1); e_T[torch.arange(batch), y_eval] -= 1 + + gamma_active_list, gamma_energy_list, n_active_list = [], [], [] + for l in range(L): + g = bp[l]; norms = g.norm(-1); mask = norms > tau + if method == 'bp': + cos = torch.ones(batch, device=device) + else: + a_dfa = (e_T @ dfa_Bs[l].T).detach() + cos = F.cosine_similarity(a_dfa, g, dim=-1) + + # Active-subset Gamma + if mask.sum() > 0: + gamma_active_list.append(cos[mask].mean().item()) + else: + gamma_active_list.append(float('nan')) + + # Energy-weighted Gamma + weights = norms ** 2 + if weights.sum() > 0: + gamma_energy_list.append((cos * weights).sum().item() / (weights.sum().item() + 1e-20)) + else: + gamma_energy_list.append(float('nan')) + + n_active_list.append(mask.sum().item()) + + t4_rows.append({ + 'method': method, 'seed': seed, 'threshold': tau, + 'Gamma_active': np.nanmean(gamma_active_list), + 'Gamma_energy_weighted': np.nanmean(gamma_energy_list), + 'mean_n_active': np.mean(n_active_list), + 'pct_active': np.mean(n_active_list) / batch * 100, + }) + + # Summary for this threshold + for m in ['bp', 'dfa']: + vals = [r for r in t4_rows if r['method']==m and r['threshold']==tau] + if vals: + ga = np.nanmean([r['Gamma_active'] for r in vals]) + ge = np.nanmean([r['Gamma_energy_weighted'] for r in vals]) + pa = np.mean([r['pct_active'] for r in vals]) + print(f" tau={tau:.0e} {m}: Gamma_active={ga:.4f} Gamma_energy={ge:.4f} pct_active={pa:.1f}%", flush=True) + + out4 = os.path.join(args.output_dir, 'T4_active_subset_gamma.csv') + with open(out4, 'w', newline='') as f: + w = csv.DictWriter(f, fieldnames=['method','seed','threshold','Gamma_active','Gamma_energy_weighted','mean_n_active','pct_active']) + w.writeheader(); w.writerows(t4_rows) + print(f"Task 4: {len(t4_rows)} rows -> {out4}", flush=True) + + print("\nALL TASKS DONE", flush=True) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--gpu', type=int, default=0) + p.add_argument('--output_dir', type=str, default='results/confirmatory') + args = p.parse_args() + run(args) + +if __name__ == '__main__': + main() |
