diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-01 13:22:33 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-01 13:22:33 -0500 |
| commit | 994be4d80271358e56c2125a55545fc567b0ab1d (patch) | |
| tree | d5f32329dbdfadfb0a0837dafe91c1c5db753f61 /experiments | |
| parent | 9f82fbbba3004a88f0c4bc6080f801fb65f0dd93 (diff) | |
Add clean_sparsity_persample.py: per-sample gradient stats
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/clean_sparsity_persample.py | 78 |
1 files changed, 78 insertions, 0 deletions
# File: experiments/clean_sparsity_persample.py
"""
Per-sample gradient stats for CIFAR. One method+seed per invocation.
Outputs CSV: each row = one sample × one layer.
Usage: python clean_sparsity_persample.py --method bp --seed 42 --gpu 0
"""
import argparse
import csv
import os
import sys

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Put the parent of this script's directory on sys.path so that the
# project-local `models` package can be imported when run as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP
import torchvision
import torchvision.transforms as transforms

# Column order of the output CSV.
CSV_FIELDS = ['method', 'seed', 'layer', 'sample_id', 'grad_norm',
              'log10_grad_norm', 'r_inf', 'pr', 'hoyer', 'topk1', 'topk5']


def _layer_grad_stats(g, d):
    """Compute per-sample norm/sparsity statistics of one layer's gradient.

    Args:
        g: gradient tensor; the last dimension is reduced over, the leading
            dimension indexes samples (the caller passes a (batch, d) tensor).
        d: width used to normalise `pr`, `hoyer` and to pick the top-k
            cut-offs.

    Returns:
        Dict mapping statistic name to a float64 tensor with one value per
        sample:
          grad_norm  L2 norm
          r_inf      L-inf norm / L2 norm
          pr         (sum g^2)^2 / (d * sum g^4); 1 when all magnitudes are
                     equal, 1/d for a one-hot vector
          hoyer      (L1 / (sqrt(d) * L2))^2; same 1 .. 1/d range
          topk1      share of squared mass held by the top int(1% of d) entries
          topk5      share of squared mass held by the top int(5% of d) entries
    """
    # Work in float64: the 1e-60 clamp floors below are not representable in
    # float32 (they would be cast to 0 and guard nothing), and fourth powers
    # of small float32 norms underflow to 0, producing 0/0 = NaN.
    g = g.detach().double()

    abs_g = g.abs()
    gsq = g ** 2
    energy = gsq.sum(-1)          # squared L2 norm
    sum_g4 = (gsq ** 2).sum(-1)   # fourth power of the L4 norm

    n2 = g.norm(dim=-1)
    ninf = abs_g.max(dim=-1).values
    n1 = abs_g.sum(-1)

    r_inf = ninf / n2.clamp(min=1e-30)
    pr = (energy ** 2 / sum_g4.clamp(min=1e-60)) / d
    hoyer = (n1 / (n2 * d ** 0.5).clamp(min=1e-30)) ** 2

    # Cumulative squared mass of the entries sorted from largest to smallest.
    sorted_sq, _ = gsq.sort(dim=-1, descending=True)
    cum_sq = sorted_sq.cumsum(-1)
    total = energy.clamp(min=1e-60)
    # Zero-based index of the last entry inside the top 1% / 5% (at least one).
    k1 = max(1, int(d * 0.01)) - 1
    k5 = max(1, int(d * 0.05)) - 1

    return {
        'grad_norm': n2,
        'r_inf': r_inf,
        'pr': pr,
        'hoyer': hoyer,
        'topk1': cum_sq[:, k1] / total,
        'topk5': cum_sq[:, k5] / total,
    }


def main():
    """Load one checkpoint, backprop one test batch, and dump gradient stats.

    Reads `--method`, `--seed`, `--gpu` and `--output_dir` from the command
    line, loads `results/confirmatory/checkpoints_A2/{method}_s{seed}.pt`,
    runs the first 256 CIFAR-10 test images through the model, takes the
    gradient of the cross-entropy loss with respect to the input of every
    residual block, and writes one CSV row per (layer, sample) to
    `{output_dir}/{method}_s{seed}.csv`.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--method', type=str, required=True)
    p.add_argument('--seed', type=int, required=True)
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--output_dir', type=str, default='results/confirmatory/persample')
    args = p.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device(f'cuda:{args.gpu}')

    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    tel = DataLoader(
        torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv),
        256, False, num_workers=0)
    # Only the first (unshuffled) test batch is used; images are flattened.
    x, y = next(iter(tel))
    x = x.view(x.size(0), -1).to(device)
    y = y.to(device)
    batch = x.size(0)

    L, d = 4, 256
    model = ResidualMLP(3072, d, 10, L).to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(
        f'results/confirmatory/checkpoints_A2/{args.method}_s{args.seed}.pt',
        map_location=device), strict=True)
    model.eval()

    # Re-run the forward pass by hand so that every block input is kept and
    # can be differentiated against.
    h0 = model.embed(x.detach())
    hs = [h0.clone().requires_grad_(True)]
    for b in model.blocks:
        hs.append(hs[-1] + b(hs[-1]))
    lo = model.out_head(model.out_ln(hs[-1]))
    # cross_entropy defaults to reduction='mean', so these gradients are of
    # the batch-mean loss. NOTE(review): if the model treats samples
    # independently, each row is the per-sample gradient scaled by 1/batch;
    # the ratio statistics are unaffected by that scale, grad_norm is not.
    loss = F.cross_entropy(lo, y)
    gs = torch.autograd.grad(loss, hs)

    rows = []
    # Only the L block inputs are reported; the final hidden state is skipped.
    for layer, grad in enumerate(gs[:L]):
        # One device-to-host transfer per statistic instead of one per value.
        stats = {name: t.tolist() for name, t in _layer_grad_stats(grad, d).items()}
        for i in range(batch):
            norm = stats['grad_norm'][i]
            rows.append({
                'method': args.method, 'seed': args.seed, 'layer': layer, 'sample_id': i,
                'grad_norm': norm,
                'log10_grad_norm': float(np.log10(max(norm, 1e-30))),
                'r_inf': stats['r_inf'][i],
                'pr': stats['pr'][i],
                'hoyer': stats['hoyer'][i],
                'topk1': stats['topk1'][i],
                'topk5': stats['topk5'][i],
            })

    out = os.path.join(args.output_dir, f'{args.method}_s{args.seed}.csv')
    with open(out, 'w', newline='') as f:
        w = csv.DictWriter(f, fieldnames=CSV_FIELDS)
        w.writeheader()
        w.writerows(rows)
    print(f"[{args.method} s={args.seed}] {len(rows)} rows -> {out}", flush=True)


if __name__ == '__main__':
    main()
