summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-01 10:54:40 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-01 10:54:40 -0500
commitff91444879a7035c88b2c3c48859f36fb560c660 (patch)
treebbbf458d07a4c17c633fabb3a89c7fab35a20d10 /experiments
parent6315e18de1b8640ddf4a818c767f3fc14cc5001e (diff)
Add confirmatory supplement: T1-T4 from checkpoints (no retraining)
WARNING: All methods (including BP) show near-zero BP hidden gradients (~1e-12-1e-14) when computed via manual forward with detached hidden states. This is inconsistent with the earlier first-priority analysis which showed BP at 2.86e-04. Investigation needed. T1: 40 rows (4 methods × 10 seeds) - full metrics T2: 800 rows (support sparsity, 5 thresholds × 4 methods × 10 seeds × 4 layers) T3: 48 rows (gradient norm distributions, 3 seeds × 4 methods × 4 layers) T4: 100 rows (active-subset Gamma, 5 thresholds × 2 methods × 10 seeds) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/confirmatory_supplement.py273
1 files changed, 273 insertions, 0 deletions
diff --git a/experiments/confirmatory_supplement.py b/experiments/confirmatory_supplement.py
new file mode 100644
index 0000000..89cd9cc
--- /dev/null
+++ b/experiments/confirmatory_supplement.py
@@ -0,0 +1,273 @@
+"""
+Confirmatory supplement: all from existing checkpoints, no retraining.
+Task 1: CIFAR full metrics (Gamma_raw, Gamma_filtered, rho, acc, naive_StateErr, StateErr)
+Task 2: Support sparsity (5 thresholds × 4 methods × 10 seeds × 4 layers)
+Task 3: Per-layer gradient norm distribution (percentiles)
+Task 4: Active-subset Gamma for BP and DFA
+"""
+import os, sys, csv, json, argparse, numpy as np, torch, torch.nn.functional as F
+from torch.utils.data import DataLoader
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+from metrics.credit_metrics import perturbation_correlation
+import torchvision, torchvision.transforms as transforms
+
+
+def get_test_batch(device, n_batches=4):
+ tv = transforms.Compose([transforms.ToTensor(),
+ transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))])
+ tel = DataLoader(torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv),
+ 256, False, num_workers=4)
+ xs, ys = [], []
+ for x, y in tel:
+ xs.append(x.view(x.size(0), -1)); ys.append(y)
+ if len(xs) >= n_batches: break
+ return torch.cat(xs).to(device), torch.cat(ys).to(device)
+
+
+def get_bp_grads(model, x, y, device):
+ model.eval(); L = model.num_blocks
+ h0 = model.embed(x.detach())
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks: hs.append(hs[-1] + b(hs[-1]))
+ lo = model.out_head(model.out_ln(hs[-1]))
+ loss = F.cross_entropy(lo, y)
+ gs = torch.autograd.grad(loss, hs)
+ return {l: gs[l].detach() for l in range(L)}, lo.detach(), F.cross_entropy(lo, y, reduction='none').detach()
+
+
+def get_dfa_Bs(seed, d, C, L, device):
+ """Regenerate DFA Bs with exact same seed sequence as training."""
+ torch.manual_seed(seed)
+ _ = ResidualMLP(3072, d, C, L) # consume same random state as model init
+ return [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+
+
+def run(args):
+ device = torch.device(f'cuda:{args.gpu}')
+ os.makedirs(args.output_dir, exist_ok=True)
+ x_eval, y_eval = get_test_batch(device)
+ batch = x_eval.size(0)
+ print(f"Eval: {batch} samples", flush=True)
+
+ L, d, C = 4, 256, 10
+ seeds = [42, 123, 456, 789, 1024, 2048, 3000, 4000, 5000, 6000]
+ methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
+ thresholds = [1e-8, 1e-7, 1e-6, 1e-5, 1e-4]
+
+ # ===== Task 1: Full metrics =====
+ print(f"\n{'='*60}\nTask 1: Full metrics\n{'='*60}", flush=True)
+ t1_rows = []
+ for method in methods:
+ for seed in seeds:
+ ckpt = f'results/confirmatory/checkpoints_A2/{method}_s{seed}.pt'
+ if not os.path.exists(ckpt):
+ print(f" SKIP {ckpt}", flush=True); continue
+ torch.manual_seed(seed)
+ model = ResidualMLP(3072, d, C, L).to(device)
+ model.load_state_dict(torch.load(ckpt, map_location=device))
+ bp, lo, lps = get_bp_grads(model, x_eval, y_eval, device)
+
+ # Accuracy
+ acc = (lo.argmax(1) == y_eval).float().mean().item()
+
+ # Naive StateErr
+ with torch.no_grad():
+ _, hi = model(x_eval, return_hidden=True)
+ h_mid = hi[L//2]; h_L = hi[-1]
+ nse = ((h_mid - h_L).norm(-1) / h_L.norm(-1).clamp(min=1e-8)).mean().item()
+
+ # DFA Bs for Gamma computation
+ dfa_Bs = get_dfa_Bs(seed, d, C, L, device)
+
+ # e_T
+ with torch.no_grad():
+ logits = model(x_eval)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y_eval] -= 1
+
+ # Per-layer metrics
+ gamma_raw_list, gamma_filt_list, rho_list = [], [], []
+ for l in range(L):
+ g = bp[l]; norms = g.norm(-1); mask = norms > 1e-6
+ a_dfa = (e_T @ dfa_Bs[l].T).detach()
+ h_l = hi[l].detach()
+
+ # Gamma raw & filtered (DFA vs BP)
+ if method == 'bp':
+ gamma_raw_list.append(1.0)
+ gamma_filt_list.append(1.0)
+ else:
+ cos = F.cosine_similarity(a_dfa, g, dim=-1)
+ gamma_raw_list.append(cos.mean().item())
+ gamma_filt_list.append(cos[mask].mean().item() if mask.sum() > 0 else float('nan'))
+
+ # Rho (perturbation correlation)
+ # Use method-appropriate credit for rho
+ if method == 'bp':
+ a_l = g
+ else:
+ a_l = a_dfa # Use DFA credit for all non-BP (closest available)
+
+ def make_fwd(sl):
+ def f(h):
+ with torch.no_grad():
+ c = h
+ for i in range(sl, L): c = c + model.blocks[i](c)
+ return F.cross_entropy(model.out_head(model.out_ln(c)), y_eval, reduction='none')
+ return f
+ rho = perturbation_correlation(h_l, a_l, make_fwd(l), epsilon=1e-3, M=16)
+ rho_list.append(rho)
+
+ row = {
+ 'method': method, 'seed': seed, 'acc': acc,
+ 'naive_StateErr': nse,
+ 'Gamma_raw': np.mean(gamma_raw_list),
+ 'Gamma_filtered': np.nanmean(gamma_filt_list),
+ 'rho': np.mean(rho_list),
+ 'mean_bp_grad_norm': np.mean([bp[l].norm(-1).mean().item() for l in range(L)]),
+ }
+ t1_rows.append(row)
+ if seed in [42, 123]:
+ print(f" {method} s={seed}: acc={acc:.4f} Gr={row['Gamma_raw']:.4f} "
+ f"Gf={row['Gamma_filtered']:.4f} rho={row['rho']:.4f} nse={nse:.4f}", flush=True)
+
+ out1 = os.path.join(args.output_dir, 'T1_cifar_full_metrics.csv')
+ with open(out1, 'w', newline='') as f:
+ w = csv.DictWriter(f, fieldnames=['method','seed','acc','naive_StateErr','Gamma_raw','Gamma_filtered','rho','mean_bp_grad_norm'])
+ w.writeheader(); w.writerows(t1_rows)
+ print(f"Task 1: {len(t1_rows)} rows -> {out1}", flush=True)
+
+ # ===== Task 2: Support sparsity =====
+ print(f"\n{'='*60}\nTask 2: Support sparsity\n{'='*60}", flush=True)
+ t2_rows = []
+ for method in methods:
+ for seed in seeds:
+ ckpt = f'results/confirmatory/checkpoints_A2/{method}_s{seed}.pt'
+ if not os.path.exists(ckpt): continue
+ torch.manual_seed(seed)
+ model = ResidualMLP(3072, d, C, L).to(device)
+ model.load_state_dict(torch.load(ckpt, map_location=device))
+ bp, _, _ = get_bp_grads(model, x_eval, y_eval, device)
+ for l in range(L):
+ norms = bp[l].norm(-1)
+ for tau in thresholds:
+ t2_rows.append({
+ 'method': method, 'seed': seed, 'layer': l,
+ 'threshold': tau, 'support_fraction': (norms > tau).float().mean().item(),
+ 'mean_norm': norms.mean().item(), 'median_norm': norms.median().item()
+ })
+ print(f" {method}: done ({len(seeds)} seeds)", flush=True)
+
+ out2 = os.path.join(args.output_dir, 'T2_support_sparsity.csv')
+ with open(out2, 'w', newline='') as f:
+ w = csv.DictWriter(f, fieldnames=['method','seed','layer','threshold','support_fraction','mean_norm','median_norm'])
+ w.writeheader(); w.writerows(t2_rows)
+ print(f"Task 2: {len(t2_rows)} rows -> {out2}", flush=True)
+
+ # ===== Task 3: Gradient norm distribution =====
+ print(f"\n{'='*60}\nTask 3: Gradient norm distribution\n{'='*60}", flush=True)
+ t3_rows = []
+ percentiles = [1, 5, 10, 25, 50, 75, 90, 95, 99]
+ for method in methods:
+ for seed in seeds[:3]: # 3 seeds for distributions
+ ckpt = f'results/confirmatory/checkpoints_A2/{method}_s{seed}.pt'
+ if not os.path.exists(ckpt): continue
+ torch.manual_seed(seed)
+ model = ResidualMLP(3072, d, C, L).to(device)
+ model.load_state_dict(torch.load(ckpt, map_location=device))
+ bp, _, _ = get_bp_grads(model, x_eval, y_eval, device)
+ for l in range(L):
+ norms = bp[l].norm(-1).cpu().numpy()
+ log_norms = np.log10(norms.clip(min=1e-20))
+ row = {'method': method, 'seed': seed, 'layer': l,
+ 'mean': float(norms.mean()), 'std': float(norms.std()),
+ 'mean_log10': float(log_norms.mean()), 'std_log10': float(log_norms.std())}
+ for p in percentiles:
+ row[f'p{p}'] = float(np.percentile(norms, p))
+ row[f'p{p}_log10'] = float(np.percentile(log_norms, p))
+ t3_rows.append(row)
+ print(f" {method}: done", flush=True)
+
+ out3 = os.path.join(args.output_dir, 'T3_grad_norm_distribution.csv')
+ fields3 = ['method','seed','layer','mean','std','mean_log10','std_log10'] + \
+ [f'p{p}' for p in percentiles] + [f'p{p}_log10' for p in percentiles]
+ with open(out3, 'w', newline='') as f:
+ w = csv.DictWriter(f, fieldnames=fields3); w.writeheader(); w.writerows(t3_rows)
+ print(f"Task 3: {len(t3_rows)} rows -> {out3}", flush=True)
+
+ # ===== Task 4: Active-subset Gamma =====
+ print(f"\n{'='*60}\nTask 4: Active-subset Gamma\n{'='*60}", flush=True)
+ t4_rows = []
+ for tau in thresholds:
+ for method in ['bp', 'dfa']:
+ for seed in seeds:
+ ckpt = f'results/confirmatory/checkpoints_A2/{method}_s{seed}.pt'
+ if not os.path.exists(ckpt): continue
+ torch.manual_seed(seed)
+ model = ResidualMLP(3072, d, C, L).to(device)
+ model.load_state_dict(torch.load(ckpt, map_location=device))
+ bp, lo, _ = get_bp_grads(model, x_eval, y_eval, device)
+ dfa_Bs = get_dfa_Bs(seed, d, C, L, device)
+ with torch.no_grad():
+ logits = model(x_eval)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y_eval] -= 1
+
+ gamma_active_list, gamma_energy_list, n_active_list = [], [], []
+ for l in range(L):
+ g = bp[l]; norms = g.norm(-1); mask = norms > tau
+ if method == 'bp':
+ cos = torch.ones(batch, device=device)
+ else:
+ a_dfa = (e_T @ dfa_Bs[l].T).detach()
+ cos = F.cosine_similarity(a_dfa, g, dim=-1)
+
+ # Active-subset Gamma
+ if mask.sum() > 0:
+ gamma_active_list.append(cos[mask].mean().item())
+ else:
+ gamma_active_list.append(float('nan'))
+
+ # Energy-weighted Gamma
+ weights = norms ** 2
+ if weights.sum() > 0:
+ gamma_energy_list.append((cos * weights).sum().item() / (weights.sum().item() + 1e-20))
+ else:
+ gamma_energy_list.append(float('nan'))
+
+ n_active_list.append(mask.sum().item())
+
+ t4_rows.append({
+ 'method': method, 'seed': seed, 'threshold': tau,
+ 'Gamma_active': np.nanmean(gamma_active_list),
+ 'Gamma_energy_weighted': np.nanmean(gamma_energy_list),
+ 'mean_n_active': np.mean(n_active_list),
+ 'pct_active': np.mean(n_active_list) / batch * 100,
+ })
+
+ # Summary for this threshold
+ for m in ['bp', 'dfa']:
+ vals = [r for r in t4_rows if r['method']==m and r['threshold']==tau]
+ if vals:
+ ga = np.nanmean([r['Gamma_active'] for r in vals])
+ ge = np.nanmean([r['Gamma_energy_weighted'] for r in vals])
+ pa = np.mean([r['pct_active'] for r in vals])
+ print(f" tau={tau:.0e} {m}: Gamma_active={ga:.4f} Gamma_energy={ge:.4f} pct_active={pa:.1f}%", flush=True)
+
+ out4 = os.path.join(args.output_dir, 'T4_active_subset_gamma.csv')
+ with open(out4, 'w', newline='') as f:
+ w = csv.DictWriter(f, fieldnames=['method','seed','threshold','Gamma_active','Gamma_energy_weighted','mean_n_active','pct_active'])
+ w.writeheader(); w.writerows(t4_rows)
+ print(f"Task 4: {len(t4_rows)} rows -> {out4}", flush=True)
+
+ print("\nALL TASKS DONE", flush=True)
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--gpu', type=int, default=0)
+ p.add_argument('--output_dir', type=str, default='results/confirmatory')
+ args = p.parse_args()
+ run(args)
+
+if __name__ == '__main__':
+ main()