summaryrefslogtreecommitdiff
path: root/experiments/clean_gradient_check.py
blob: 4e966423916f9a6957c03276e65a63f25a7b41d7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
Clean BP gradient check — run in independent Python process per method.
Usage: python clean_gradient_check.py --method bp --seed 42 --gpu 1
"""
import os, sys, json, argparse, numpy as np, torch, torch.nn.functional as F
from torch.utils.data import DataLoader
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP
import torchvision, torchvision.transforms as transforms

def _grad_stats(grad):
    """Summarise per-row gradient norms for one hidden state.

    Args:
        grad: Gradient tensor for a hidden state, or ``None`` when autograd
            produced no gradient. Norms are taken over the last dimension,
            giving one norm per row of the batch.

    Returns:
        dict with keys ``mean_norm``, ``median_norm``, ``max_norm`` and
        ``s_1e6`` (fraction of rows whose norm exceeds 1e-6) as Python
        floats. When ``grad`` is ``None``, every key maps to ``None``.
    """
    if grad is None:
        return {'mean_norm': None, 'median_norm': None, 'max_norm': None, 's_1e6': None}
    norms = grad.norm(dim=-1)  # computed once and reused for every statistic
    return {
        'mean_norm': norms.mean().item(),
        # torch.median returns the lower of the two middle values for even counts.
        'median_norm': norms.median().item(),
        'max_norm': norms.max().item(),
        's_1e6': (norms > 1e-6).float().mean().item(),
    }


def _print_layer_stats(layer, stats):
    """Print one per-layer gradient summary line (or a 'grad is None' marker)."""
    if stats['mean_norm'] is None:
        print(f"    layer {layer}: grad is None!", flush=True)
        return
    print(f"    layer {layer}: mean_norm={stats['mean_norm']:.2e} median={stats['median_norm']:.2e} "
          f"max={stats['max_norm']:.2e} s(1e-6)={stats['s_1e6']:.4f}", flush=True)


def main():
    """Compare hidden-state gradients of one checkpoint, computed three ways.

    Loads ``results/confirmatory/checkpoints_A2/{method}_s{seed}.pt`` into a
    fresh ``ResidualMLP(3072, 256, 10, 4)`` with ``strict=True``. On the first
    (unshuffled) batch of 256 CIFAR-10 test images it then reports:

    * Method A: a manual forward (embed -> residual blocks -> out_ln ->
      out_head) with ``torch.autograd.grad`` w.r.t. every hidden state.
    * Method B: ``model(x, return_hidden=True)`` with ``retain_grad`` on the
      returned hidden states, followed by ``backward``.
    * Method C: a plain full-model ``backward``, reporting the norm of the
      embedding weight gradient.

    Results are printed and written to ``{output_dir}/{method}_s{seed}.json``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--method', type=str, required=True)
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--gpu', type=int, default=1)
    p.add_argument('--output_dir', type=str, default='results/confirmatory/clean_grads')
    args = p.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    # Fall back to CPU instead of crashing on machines without CUDA.
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')

    # 1. Load eval data (256 samples, first batch, no shuffle)
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    test_set = torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv)
    tel = DataLoader(test_set, 256, False, num_workers=0)  # num_workers=0 for determinism
    x, y = next(iter(tel))
    # Flatten each image to a vector (3*32*32 = 3072, the model's input dim).
    x = x.view(x.size(0), -1).to(device)
    y = y.to(device)
    batch = x.size(0)
    print(f"[{args.method} s={args.seed}] Batch: {batch}, y[:5]={y[:5].tolist()}", flush=True)
    print(f"  Device: {device}", flush=True)

    # 2. Create model from scratch, load checkpoint (strict=True)
    L, d, C = 4, 256, 10
    ckpt_path = f'results/confirmatory/checkpoints_A2/{args.method}_s{args.seed}.pt'
    # Explicit raise rather than assert: asserts are stripped under `python -O`.
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    model = ResidualMLP(3072, d, C, L).to(device)
    # NOTE: on older torch versions torch.load unpickles arbitrary objects;
    # only point this at trusted, locally produced checkpoints.
    sd = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(sd, strict=True)
    model.eval()

    # Sanity check: print first parameter norm and the checkpoint path
    first_param_norm = next(model.parameters()).norm().item()
    print(f"  First param norm: {first_param_norm:.6f}", flush=True)
    print(f"  Checkpoint: {ckpt_path}", flush=True)

    # 3. Method A: manual forward + autograd.grad
    # NOTE(review): assumes this manual forward mirrors ResidualMLP.forward —
    # confirm against models/residual_mlp.py.
    h0 = model.embed(x.detach())
    # Make hs[0] a separate node that requires grad so it can be differentiated against.
    hs = [h0.clone().requires_grad_(True)]
    for block in model.blocks:
        hs.append(hs[-1] + block(hs[-1]))
    lo_a = model.out_head(model.out_ln(hs[-1]))
    loss_a = F.cross_entropy(lo_a, y)
    acc_a = (lo_a.argmax(1) == y).float().mean().item()
    # Gradients for all L+1 hidden states; only the first L are reported.
    gs_a = torch.autograd.grad(loss_a, hs)
    stats_a = [_grad_stats(gs_a[layer]) for layer in range(L)]

    print(f"  Method A (manual+autograd.grad): loss={loss_a.item():.6f} acc={acc_a:.4f}", flush=True)
    for layer, stats in enumerate(stats_a):
        _print_layer_stats(layer, stats)

    # 4. Method B: retain_grad + backward
    model.zero_grad()
    for param in model.parameters():
        param.requires_grad_(True)  # defensive: ensure parameters take part in the graph
    lo_b, hi_b = model(x, return_hidden=True)
    for layer in range(L + 1):
        hi_b[layer].retain_grad()  # non-leaf tensors only keep .grad when asked to
    loss_b = F.cross_entropy(lo_b, y)
    acc_b = (lo_b.argmax(1) == y).float().mean().item()
    loss_b.backward()
    stats_b = [_grad_stats(hi_b[layer].grad) for layer in range(L)]

    print(f"  Method B (retain_grad+backward): loss={loss_b.item():.6f} acc={acc_b:.4f}", flush=True)
    for layer, stats in enumerate(stats_b):
        _print_layer_stats(layer, stats)

    # 5. Method C: full model backward (no detach)
    model.zero_grad()
    lo_c = model(x)
    loss_c = F.cross_entropy(lo_c, y)
    loss_c.backward()
    # Embedding weight-gradient norm serves as a coarse proxy for gradient flow.
    embed_grad = model.embed.weight.grad
    embed_grad_norm = embed_grad.norm().item() if embed_grad is not None else 0
    print(f"  Method C (full backward): loss={loss_c.item():.6f} embed_grad_norm={embed_grad_norm:.2e}", flush=True)

    # 6. Save results
    result = {
        'method': args.method, 'seed': args.seed, 'batch_size': batch,
        'y_first5': y[:5].tolist(),
        'first_param_norm': first_param_norm,
        'method_A': {'loss': loss_a.item(), 'acc': acc_a, 'per_layer': stats_a},
        'method_B': {'loss': loss_b.item(), 'acc': acc_b, 'per_layer': stats_b},
        'method_C_embed_grad_norm': embed_grad_norm,
    }

    out = os.path.join(args.output_dir, f'{args.method}_s{args.seed}.json')
    with open(out, 'w') as f:
        json.dump(result, f, indent=2)
    print(f"  Saved to {out}", flush=True)


if __name__ == '__main__':
    main()