summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-01 12:56:24 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-01 12:56:24 -0500
commitd5326053a2e9ce37dd61606aa37fa8f563481f44 (patch)
tree653f8bf3098d382a1162c09ce4983d9d1c50713e /experiments
parentcd80da41c620d7c8b17e36d3ed7ab7e6b582f191 (diff)
Add clean gradient check: independent Python process per method, GPU 1
Clean results (each method in fresh Python process): BP: mean_norm=2.58e-04, s(1e-6)=98% — CONFIRMED DFA: layer 0 = 2.86e-07 (1.2%), layers 1-3 ≈ 2.4e-09 (0%) SB: layer 0 = 6.13e-06 (86%), layers 1-3 ≈ 1e-09 (0%) CB: layer 0 = 6.33e-07 (18%), layers 1-3 ≈ 5e-10 (0%) Method A (autograd.grad) and Method B (retain_grad) give identical results. Previous 1e-12 results were caused by Python process state pollution in combined scripts. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/clean_gradient_check.py126
1 files changed, 126 insertions, 0 deletions
diff --git a/experiments/clean_gradient_check.py b/experiments/clean_gradient_check.py
new file mode 100644
index 0000000..4e96642
--- /dev/null
+++ b/experiments/clean_gradient_check.py
@@ -0,0 +1,126 @@
+"""
+Clean BP gradient check — run in independent Python process per method.
+Usage: python clean_gradient_check.py --method bp --seed 42 --gpu 1
+"""
+import os, sys, json, argparse, numpy as np, torch, torch.nn.functional as F
+from torch.utils.data import DataLoader
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+import torchvision, torchvision.transforms as transforms
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--method', type=str, required=True)
+ p.add_argument('--seed', type=int, default=42)
+ p.add_argument('--gpu', type=int, default=1)
+ p.add_argument('--output_dir', type=str, default='results/confirmatory/clean_grads')
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ device = torch.device(f'cuda:{args.gpu}')
+
+ # 1. Load eval data (256 samples, first batch, no shuffle)
+ tv = transforms.Compose([transforms.ToTensor(),
+ transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))])
+ tel = DataLoader(torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv),
+ 256, False, num_workers=0) # num_workers=0 for determinism
+ for x, y in tel:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ break
+ batch = x.size(0)
+ print(f"[{args.method} s={args.seed}] Batch: {batch}, y[:5]={y[:5].tolist()}", flush=True)
+
+ # 2. Create model from scratch, load checkpoint (strict=True)
+ L, d, C = 4, 256, 10
+ ckpt_path = f'results/confirmatory/checkpoints_A2/{args.method}_s{args.seed}.pt'
+ assert os.path.exists(ckpt_path), f"Checkpoint not found: {ckpt_path}"
+
+ model = ResidualMLP(3072, d, C, L).to(device)
+ sd = torch.load(ckpt_path, map_location=device)
+ model.load_state_dict(sd, strict=True)
+ model.eval()
+
+ # Verify: print first param norm and checkpoint hash
+ first_param = list(model.parameters())[0]
+ print(f" First param norm: {first_param.norm().item():.6f}", flush=True)
+ print(f" Checkpoint: {ckpt_path}", flush=True)
+
+ # 3. Method A: manual forward + autograd.grad
+ h0 = model.embed(x.detach())
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks:
+ hs.append(hs[-1] + b(hs[-1]))
+ lo_a = model.out_head(model.out_ln(hs[-1]))
+ loss_a = F.cross_entropy(lo_a, y)
+ acc_a = (lo_a.argmax(1) == y).float().mean().item()
+ gs_a = torch.autograd.grad(loss_a, hs)
+
+ print(f" Method A (manual+autograd.grad): loss={loss_a.item():.6f} acc={acc_a:.4f}", flush=True)
+ for l in range(L):
+ n = gs_a[l].norm(dim=-1)
+ print(f" layer {l}: mean_norm={n.mean():.2e} median={n.median():.2e} "
+ f"max={n.max():.2e} s(1e-6)={(n>1e-6).float().mean():.4f}", flush=True)
+
+ # 4. Method B: retain_grad + backward
+ model.zero_grad()
+ for param in model.parameters():
+ param.requires_grad_(True)
+ lo_b, hi_b = model(x, return_hidden=True)
+ for l in range(L + 1):
+ hi_b[l].retain_grad()
+ loss_b = F.cross_entropy(lo_b, y)
+ acc_b = (lo_b.argmax(1) == y).float().mean().item()
+ loss_b.backward()
+
+ print(f" Method B (retain_grad+backward): loss={loss_b.item():.6f} acc={acc_b:.4f}", flush=True)
+ for l in range(L):
+ if hi_b[l].grad is not None:
+ n = hi_b[l].grad.norm(dim=-1)
+ print(f" layer {l}: mean_norm={n.mean():.2e} median={n.median():.2e} "
+ f"max={n.max():.2e} s(1e-6)={(n>1e-6).float().mean():.4f}", flush=True)
+ else:
+ print(f" layer {l}: grad is None!", flush=True)
+
+ # 5. Method C: full model backward (no detach)
+ model.zero_grad()
+ lo_c = model(x)
+ loss_c = F.cross_entropy(lo_c, y)
+ loss_c.backward()
+ # Get embedding gradient as proxy
+ embed_grad_norm = model.embed.weight.grad.norm().item() if model.embed.weight.grad is not None else 0
+ print(f" Method C (full backward): loss={loss_c.item():.6f} embed_grad_norm={embed_grad_norm:.2e}", flush=True)
+
+ # 6. Save results
+ result = {
+ 'method': args.method, 'seed': args.seed, 'batch_size': batch,
+ 'y_first5': y[:5].tolist(),
+ 'first_param_norm': first_param.norm().item(),
+ 'method_A': {
+ 'loss': loss_a.item(), 'acc': acc_a,
+ 'per_layer': [{
+ 'mean_norm': gs_a[l].norm(-1).mean().item(),
+ 'median_norm': gs_a[l].norm(-1).median().item(),
+ 'max_norm': gs_a[l].norm(-1).max().item(),
+ 's_1e6': (gs_a[l].norm(-1) > 1e-6).float().mean().item(),
+ } for l in range(L)]
+ },
+ 'method_B': {
+ 'loss': loss_b.item(), 'acc': acc_b,
+ 'per_layer': [{
+ 'mean_norm': hi_b[l].grad.norm(-1).mean().item() if hi_b[l].grad is not None else None,
+ 'median_norm': hi_b[l].grad.norm(-1).median().item() if hi_b[l].grad is not None else None,
+ 'max_norm': hi_b[l].grad.norm(-1).max().item() if hi_b[l].grad is not None else None,
+ 's_1e6': (hi_b[l].grad.norm(-1) > 1e-6).float().mean().item() if hi_b[l].grad is not None else None,
+ } for l in range(L)]
+ },
+ 'method_C_embed_grad_norm': embed_grad_norm,
+ }
+
+ out = os.path.join(args.output_dir, f'{args.method}_s{args.seed}.json')
+ with open(out, 'w') as f:
+ json.dump(result, f, indent=2)
+ print(f" Saved to {out}", flush=True)
+
+if __name__ == '__main__':
+ main()