summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-01 21:06:59 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-01 21:06:59 -0500
commitda057e5b827d33cc7ff1704a0da0fa9d3f6b7cb6 (patch)
tree32e91375a4d4c53b704d568615fd91d4b0a8f3aa /experiments
parentbd1ae4b38433358eb7ee2a7795a67ac53bfd43f3 (diff)
Add d512_sparsity.py: support sparsity for d=512 checkpoints
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/d512_sparsity.py67
1 files changed, 67 insertions, 0 deletions
diff --git a/experiments/d512_sparsity.py b/experiments/d512_sparsity.py
new file mode 100644
index 0000000..70fb75e
--- /dev/null
+++ b/experiments/d512_sparsity.py
@@ -0,0 +1,67 @@
+"""d=512 sparsity analysis. One method+seed per invocation."""
+import os, sys, json, argparse, numpy as np, torch, torch.nn.functional as F
+from torch.utils.data import DataLoader
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+import torchvision, torchvision.transforms as transforms
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument('--method', type=str, required=True)
+ p.add_argument('--seed', type=int, required=True)
+ p.add_argument('--gpu', type=int, default=0)
+ p.add_argument('--output_dir', type=str, default='results/confirmatory/d512_sparsity')
+ args = p.parse_args()
+ os.makedirs(args.output_dir, exist_ok=True)
+ device = torch.device(f'cuda:{args.gpu}')
+ L, d = 4, 512
+ thresholds = [1e-8, 1e-7, 1e-6, 1e-5, 1e-4]
+
+ tv = transforms.Compose([transforms.ToTensor(),
+ transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))])
+ tel = DataLoader(torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv),
+ 256, False, num_workers=0)
+ for x, y in tel: x = x.view(x.size(0),-1).to(device); y = y.to(device); break
+
+ model = ResidualMLP(3072, d, 10, L).to(device)
+ model.load_state_dict(torch.load(
+ f'results/confirmatory/cifar_d512/{args.method}_s{args.seed}.pt',
+ map_location=device), strict=True)
+ model.eval()
+
+ h0 = model.embed(x.detach())
+ hs = [h0.clone().requires_grad_(True)]
+ for b in model.blocks: hs.append(hs[-1] + b(hs[-1]))
+ lo = model.out_head(model.out_ln(hs[-1]))
+ loss = F.cross_entropy(lo, y)
+ acc = (lo.argmax(1) == y).float().mean().item()
+ gs = torch.autograd.grad(loss, hs)
+
+ result = {'method': args.method, 'seed': args.seed, 'd': d, 'acc': acc, 'per_layer': []}
+ for l in range(L):
+ g = gs[l].detach(); norms = g.norm(dim=-1)
+ ninf = g.abs().max(-1).values; n2 = norms.clamp(min=1e-30)
+ n4 = (g.abs()**4).sum(-1)**(1/4); n1 = g.abs().sum(-1)
+ r_inf = (ninf / n2); pr = (n2**4 / (n4**4).clamp(min=1e-60)) / d
+ hoyer = (n1 / (n2 * d**0.5).clamp(min=1e-30))**2
+ gsq = g**2; te = gsq.sum(-1, keepdim=True).clamp(min=1e-60)
+ ssq, _ = gsq.sort(dim=-1, descending=True); cs = ssq.cumsum(-1)
+ topk = {}
+ for k in [1, 5, 10, 25]:
+ idx = max(1, int(d * k / 100)) - 1
+ topk[str(k)] = (cs[:, idx:idx+1] / te).squeeze(-1).mean().item()
+ support = {str(tau): (norms > tau).float().mean().item() for tau in thresholds}
+ result['per_layer'].append({
+ 'layer': l, 'mean_norm': norms.mean().item(), 'median_norm': norms.median().item(),
+ 'support': support, 'r_inf_mean': r_inf.mean().item(), 'pr_mean': pr.mean().item(),
+ 'hoyer_mean': hoyer.mean().item(), 'topk_energy': topk})
+
+ out = os.path.join(args.output_dir, f'{args.method}_s{args.seed}.json')
+ with open(out, 'w') as f: json.dump(result, f, indent=2, default=float)
+ s16 = np.mean([ld['support']['1e-06'] for ld in result['per_layer']])
+ mn = np.mean([ld['mean_norm'] for ld in result['per_layer']])
+ ri = np.mean([ld['r_inf_mean'] for ld in result['per_layer']])
+ pr_v = np.mean([ld['pr_mean'] for ld in result['per_layer']])
+ print(f"[{args.method} s={args.seed}] acc={acc:.4f} s(1e-6)={s16:.4f} norm={mn:.2e} r_inf={ri:.4f} PR={pr_v:.4f}", flush=True)
+
+if __name__ == '__main__': main()