"""
Clean BP gradient check — run in independent Python process per method.

Loads ``results/confirmatory/checkpoints_A2/{method}_s{seed}.pt`` into a
freshly built ``ResidualMLP`` and measures gradients of the cross-entropy
loss on the first 256 CIFAR-10 test images in three independent ways:

  A. a manual forward through ``model.embed`` / ``model.blocks`` /
     ``model.out_ln`` / ``model.out_head`` + ``torch.autograd.grad``
     w.r.t. the per-block hidden states;
  B. ``model(x, return_hidden=True)`` + ``retain_grad()`` + ``backward()``;
  C. a plain full-model ``backward()`` (reports the norm of
     ``model.embed.weight.grad`` as a proxy).

Per-layer statistics are printed and written as JSON to
``{output_dir}/{method}_s{seed}.json``.

Usage: python clean_gradient_check.py --method bp --seed 42 --gpu 1
"""
import os
import sys
import json
import argparse

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Make the project root importable so ``models`` resolves when run as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP  # noqa: E402  (needs the sys.path tweak)
import torchvision  # noqa: E402
import torchvision.transforms as transforms  # noqa: E402


def _load_eval_batch(device, batch_size=256):
    """Return the first (unshuffled) CIFAR-10 test batch on *device*.

    Args:
        device: torch device the tensors are moved to.
        batch_size: number of samples in the batch (default 256).

    Returns:
        (x, y): ``x`` flattened to ``(batch, -1)``; ``y`` the label tensor.

    Raises:
        RuntimeError: if the dataset yields no batch at all.
    """
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    ds = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
    loader = DataLoader(ds, batch_size, False, num_workers=0)  # num_workers=0 for determinism
    first = next(iter(loader), None)
    if first is None:
        raise RuntimeError("CIFAR-10 test loader produced no batches")
    x, y = first
    return x.view(x.size(0), -1).to(device), y.to(device)


def _grad_stats(grad, thresh=1e-6):
    """Summarise the row-wise L2 norms of a hidden-state gradient.

    Norms are taken over the last dimension (one norm per row; presumably one
    per sample for a (batch, dim) hidden state — TODO confirm for ResidualMLP).

    Args:
        grad: gradient tensor, or ``None`` when no gradient was recorded.
        thresh: norm threshold used for the ``s_1e6`` fraction.

    Returns:
        dict with ``mean_norm``, ``median_norm``, ``max_norm`` and ``s_1e6``
        (fraction of rows whose norm exceeds *thresh*) as Python floats; every
        value is ``None`` when *grad* is ``None``.
    """
    if grad is None:
        return {'mean_norm': None, 'median_norm': None, 'max_norm': None, 's_1e6': None}
    n = grad.norm(dim=-1)
    return {
        'mean_norm': n.mean().item(),
        'median_norm': n.median().item(),
        'max_norm': n.max().item(),
        's_1e6': (n > thresh).float().mean().item(),
    }


def _print_layer_stats(l, stats):
    """Print one layer's ``_grad_stats`` dict (or a 'grad is None' marker)."""
    if stats['mean_norm'] is None:
        print(f" layer {l}: grad is None!", flush=True)
        return
    print(f" layer {l}: mean_norm={stats['mean_norm']:.2e} median={stats['median_norm']:.2e} "
          f"max={stats['max_norm']:.2e} s(1e-6)={stats['s_1e6']:.4f}", flush=True)


def main():
    """Parse CLI args, run gradient checks A/B/C, and save a JSON summary."""
    p = argparse.ArgumentParser()
    p.add_argument('--method', type=str, required=True)
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--gpu', type=int, default=1)
    p.add_argument('--output_dir', type=str, default='results/confirmatory/clean_grads')
    args = p.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device(f'cuda:{args.gpu}')

    # 1. Load eval data (256 samples, first batch, no shuffle)
    x, y = _load_eval_batch(device)
    batch = x.size(0)
    print(f"[{args.method} s={args.seed}] Batch: {batch}, y[:5]={y[:5].tolist()}", flush=True)

    # 2. Create model from scratch, load checkpoint (strict=True)
    L, d, C = 4, 256, 10
    ckpt_path = f'results/confirmatory/checkpoints_A2/{args.method}_s{args.seed}.pt'
    # Explicit raise rather than assert: asserts are stripped under `python -O`.
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    model = ResidualMLP(3072, d, C, L).to(device)
    # NOTE: torch.load may unpickle arbitrary objects; only load trusted checkpoints.
    sd = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(sd, strict=True)
    model.eval()
    # Verify: print first param norm and checkpoint path
    first_param_norm = next(model.parameters()).norm().item()
    print(f" First param norm: {first_param_norm:.6f}", flush=True)
    print(f" Checkpoint: {ckpt_path}", flush=True)

    # 3. Method A: manual forward + autograd.grad w.r.t. every hidden state.
    # h0 is cloned and flagged requires_grad so it is a valid gradient target.
    h0 = model.embed(x.detach())
    hs = [h0.clone().requires_grad_(True)]
    for b in model.blocks:
        # Residual update; assumed to mirror ResidualMLP.forward — TODO confirm.
        hs.append(hs[-1] + b(hs[-1]))
    lo_a = model.out_head(model.out_ln(hs[-1]))
    loss_a = F.cross_entropy(lo_a, y)
    acc_a = (lo_a.argmax(1) == y).float().mean().item()
    # autograd.grad returns gradients without accumulating into param .grad.
    gs_a = torch.autograd.grad(loss_a, hs)
    stats_a = [_grad_stats(gs_a[l]) for l in range(L)]
    print(f" Method A (manual+autograd.grad): loss={loss_a.item():.6f} acc={acc_a:.4f}", flush=True)
    for l, s in enumerate(stats_a):
        _print_layer_stats(l, s)

    # 4. Method B: retain_grad + backward
    model.zero_grad()
    for param in model.parameters():
        param.requires_grad_(True)
    lo_b, hi_b = model(x, return_hidden=True)
    # Non-leaf tensors drop .grad by default; retain_grad keeps it after backward.
    for l in range(L + 1):
        hi_b[l].retain_grad()
    loss_b = F.cross_entropy(lo_b, y)
    acc_b = (lo_b.argmax(1) == y).float().mean().item()
    loss_b.backward()
    stats_b = [_grad_stats(hi_b[l].grad) for l in range(L)]
    print(f" Method B (retain_grad+backward): loss={loss_b.item():.6f} acc={acc_b:.4f}", flush=True)
    for l, s in enumerate(stats_b):
        _print_layer_stats(l, s)

    # 5. Method C: full model backward (no detach)
    model.zero_grad()
    lo_c = model(x)
    loss_c = F.cross_entropy(lo_c, y)
    loss_c.backward()
    # Get embedding gradient as proxy
    embed_grad = model.embed.weight.grad
    embed_grad_norm = embed_grad.norm().item() if embed_grad is not None else 0.0
    print(f" Method C (full backward): loss={loss_c.item():.6f} embed_grad_norm={embed_grad_norm:.2e}", flush=True)

    # 6. Save results
    result = {
        'method': args.method,
        'seed': args.seed,
        'batch_size': batch,
        'y_first5': y[:5].tolist(),
        'first_param_norm': first_param_norm,
        'method_A': {'loss': loss_a.item(), 'acc': acc_a, 'per_layer': stats_a},
        'method_B': {'loss': loss_b.item(), 'acc': acc_b, 'per_layer': stats_b},
        'method_C_embed_grad_norm': embed_grad_norm,
    }
    out = os.path.join(args.output_dir, f'{args.method}_s{args.seed}.json')
    with open(out, 'w') as f:
        json.dump(result, f, indent=2)
    print(f" Saved to {out}", flush=True)


if __name__ == '__main__':
    main()