summaryrefslogtreecommitdiff
path: root/experiments/clean_sparsity_persample.py
blob: e0f7e4d1bbbdf78b4f2c27ad6efdbc1f2d897361 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
Per-sample gradient stats for CIFAR. One method+seed per invocation.
Outputs CSV: each row = one sample × one layer.
Usage: python clean_sparsity_persample.py --method bp --seed 42 --gpu 0
"""
import os, sys, csv, argparse, numpy as np, torch, torch.nn.functional as F
from torch.utils.data import DataLoader
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP
import torchvision, torchvision.transforms as transforms

# Column order of the output CSV (one row per sample x layer).
_CSV_FIELDS = ['method', 'seed', 'layer', 'sample_id', 'grad_norm', 'log10_grad_norm',
               'r_inf', 'pr', 'hoyer', 'topk1', 'topk5']


def _parse_args():
    """Parse the command-line options for a single (method, seed) run."""
    p = argparse.ArgumentParser()
    p.add_argument('--method', type=str, required=True)
    p.add_argument('--seed', type=int, required=True)
    p.add_argument('--gpu', type=int, default=0,
                   help='CUDA device index; a negative value runs on CPU')
    p.add_argument('--output_dir', type=str, default='results/confirmatory/persample')
    p.add_argument('--ckpt_dir', type=str, default='results/confirmatory/checkpoints_A2',
                   help='directory holding <method>_s<seed>.pt state dicts')
    p.add_argument('--data_dir', type=str, default='./data',
                   help='torchvision CIFAR-10 root (downloaded if missing)')
    return p.parse_args()


def _load_test_batch(data_dir, batch_size, device):
    """Return the first `batch_size` CIFAR-10 test images, flattened, and their labels.

    The loader is not shuffled, so every run sees the same samples in the same order.
    """
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    ds = torchvision.datasets.CIFAR10(data_dir, train=False, download=True, transform=tv)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)
    x, y = next(iter(loader))
    return x.view(x.size(0), -1).to(device), y.to(device)


def _residual_stream_grads(model, x, y):
    """Gradient of the cross-entropy loss w.r.t. every residual-stream state.

    Re-runs the model as embed -> h_{l+1} = h_l + block_l(h_l) -> out_ln -> out_head.
    NOTE(review): assumes this mirrors ResidualMLP.forward; confirm against models/.

    Returns a tuple of len(model.blocks) + 1 tensors shaped like the stream: the
    input of each block, followed by the final stream fed to the output head.

    The loss uses F.cross_entropy's default 'mean' reduction. Provided the model
    mixes no information across samples in eval mode (presumably true here; verify),
    row i of each gradient is sample i's own loss gradient divided by the batch
    size. Scale-invariant ratios are unaffected; absolute norms carry that factor.
    """
    # Detaching makes h0 a fresh leaf: gradients w.r.t. it are unchanged, and the
    # backward pass no longer needs to reach into the embedding.
    h0 = model.embed(x).detach().requires_grad_(True)
    hs = [h0]
    for block in model.blocks:
        hs.append(hs[-1] + block(hs[-1]))
    logits = model.out_head(model.out_ln(hs[-1]))
    loss = F.cross_entropy(logits, y)
    return torch.autograd.grad(loss, hs)


def _sparsity_stats(g):
    """Per-row magnitude / sparsity statistics of a 2-D gradient (rows = samples).

    Columns, for a row g of width d:
      grad_norm        ||g||_2
      log10_grad_norm  log10(max(||g||_2, 1e-30))
      r_inf            ||g||_inf / ||g||_2
      pr               participation ratio (sum g^2)^2 / (d * sum g^4), in [1/d, 1]
      hoyer            (||g||_1 / (sqrt(d) * ||g||_2))^2, in [1/d, 1]; note this is an
                       L1/L2 density, not the normalized Hoyer sparsity measure
      topk1, topk5     share of sum g^2 held by the largest max(1, int(d*0.01)) and
                       max(1, int(d*0.05)) coordinates

    Everything is computed in float64 on CPU. In float32, g**4 underflows to zero
    for |g| below about 1e-11, and a clamp floor like 1e-60 rounds to 0, so tiny or
    all-zero gradients produced inf/NaN. Denominators are clamped at the smallest
    positive float64. That floor is never reached by a nonzero row of
    float32-derived values, so such rows are undistorted. An exactly-zero row
    reports 0 for every ratio column and -30 for log10_grad_norm.

    Returns a dict mapping column name -> list of Python floats, one per row.
    """
    g = g.detach().to(device='cpu', dtype=torch.float64)
    d = g.size(-1)
    tiny = torch.finfo(torch.float64).tiny

    mag = g.abs()
    sq = g ** 2
    s2 = sq.sum(-1)                    # ||g||_2^2
    s4 = (sq ** 2).sum(-1)             # ||g||_4^4, computed directly (no root round-trip)
    n2 = s2.sqrt()
    n1 = mag.sum(-1)
    ninf = mag.max(dim=-1).values

    stats = {
        'grad_norm': n2,
        'log10_grad_norm': n2.clamp(min=1e-30).log10(),
        'r_inf': ninf / n2.clamp(min=tiny),
        'pr': s2 ** 2 / s4.clamp(min=tiny) / d,
        'hoyer': (n1 / (n2 * d ** 0.5).clamp(min=tiny)) ** 2,
    }

    # Cumulative energy of coordinates sorted from largest to smallest.
    energy = sq.sort(dim=-1, descending=True).values.cumsum(-1)
    total = s2.clamp(min=tiny)
    for name, frac in (('topk1', 0.01), ('topk5', 0.05)):
        k = max(1, int(d * frac))
        stats[name] = energy[:, k - 1] / total

    # One host transfer per column instead of one .item() per element.
    return {name: t.tolist() for name, t in stats.items()}


def main():
    """Write per-sample residual-stream gradient statistics for one checkpoint.

    Loads <ckpt_dir>/<method>_s<seed>.pt into a 4-block, width-256 ResidualMLP. It
    then takes the first 256 CIFAR-10 test images and writes
    <output_dir>/<method>_s<seed>.csv, with one row per (sample, layer) for the
    inputs of the L blocks. The gradient of the final stream is computed but not
    reported.
    """
    args = _parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device(f'cuda:{args.gpu}' if args.gpu >= 0 else 'cpu')

    x, y = _load_test_batch(args.data_dir, 256, device)

    L, d = 4, 256
    model = ResidualMLP(3072, d, 10, L).to(device)
    ckpt = os.path.join(args.ckpt_dir, f'{args.method}_s{args.seed}.pt')
    model.load_state_dict(torch.load(ckpt, map_location=device), strict=True)
    model.eval()

    gs = _residual_stream_grads(model, x, y)

    rows = []
    for layer in range(L):
        stats = _sparsity_stats(gs[layer])
        for i in range(x.size(0)):
            row = {'method': args.method, 'seed': args.seed, 'layer': layer, 'sample_id': i}
            row.update((name, column[i]) for name, column in stats.items())
            rows.append(row)

    out = os.path.join(args.output_dir, f'{args.method}_s{args.seed}.csv')
    with open(out, 'w', newline='', encoding='utf-8') as f:
        w = csv.DictWriter(f, fieldnames=_CSV_FIELDS)
        w.writeheader()
        w.writerows(rows)
    print(f"[{args.method} s={args.seed}] {len(rows)} rows -> {out}", flush=True)

if __name__ == '__main__':
    # Script entry point; importing this module does not start a run.
    main()