diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-01 12:59:01 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-01 12:59:01 -0500 |
| commit | 8eb04f011e7092a8d2e6c89800721a00112fd384 (patch) | |
| tree | 8f68872c79b70a28b3ee7f1f1191e3a00c7c8053 /experiments | |
| parent | d5326053a2e9ce37dd61606aa37fa8f563481f44 (diff) | |
Add clean_sparsity_full.py: independent-process full sparsity analysis
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/clean_sparsity_full.py | 202 |
1 files changed, 202 insertions, 0 deletions
diff --git a/experiments/clean_sparsity_full.py b/experiments/clean_sparsity_full.py new file mode 100644 index 0000000..7a97bda --- /dev/null +++ b/experiments/clean_sparsity_full.py @@ -0,0 +1,202 @@ +""" +Clean full sparsity analysis — one method+seed per invocation. +Usage: python clean_sparsity_full.py --dataset cifar --method bp --seed 42 --gpu 0 + python clean_sparsity_full.py --dataset synth --method bp --seed 42 --alpha 1.0 --depth 4 --gpu 0 +""" +import os, sys, json, argparse, numpy as np, torch, torch.nn.functional as F +from torch.utils.data import DataLoader, TensorDataset +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from models.residual_mlp import ResidualMLP +import torchvision, torchvision.transforms as transforms + +class StudentBlock(torch.nn.Module): + def __init__(self, d, alpha=1.0): + super().__init__() + self.ln=torch.nn.LayerNorm(d);self.w=torch.nn.Linear(d,d,bias=False) + torch.nn.init.normal_(self.w.weight,std=0.01);self.alpha=alpha + def forward(self, h): + return self.w(((1-self.alpha)*self.ln(h)+self.alpha*torch.tanh(self.ln(h)))) + +class StudentNet(torch.nn.Module): + def __init__(self, d, C, L, alpha=1.0): + super().__init__() + self.blocks=torch.nn.ModuleList([StudentBlock(d,alpha) for _ in range(L)]) + self.out_head=torch.nn.Linear(d,C);self.num_blocks=L;self.d_hidden=d + def forward(self, x, return_hidden=False): + h=x;hi=[h] if return_hidden else None + for b in self.blocks: + h=h+b(h) + if return_hidden:hi.append(h) + lo=self.out_head(h) + return (lo,hi) if return_hidden else lo + +class TeacherNet(torch.nn.Module): + def __init__(self, d, C, L, alpha=1.0, seed=0): + super().__init__() + self.alpha=alpha;rng=torch.Generator().manual_seed(seed) + self.Ws=torch.nn.ParameterList() + for _ in range(L): + W=torch.randn(d,d,generator=rng)*0.3/(d**0.5) + U,S,Vh=torch.linalg.svd(W,full_matrices=False) + self.Ws.append(torch.nn.Parameter(U@torch.diag(S.clamp(max=0.3))@Vh,requires_grad=False)) + self.U=torch.nn.Parameter(torch.randn(C,d,generator=rng)/(d**0.5),requires_grad=False) + def forward(self, x): + h=x + for W in self.Ws:h=h+((1-self.alpha)*h+self.alpha*torch.tanh(h))@W.T + return h@self.U.T + +def main(): + p = argparse.ArgumentParser() + p.add_argument('--dataset', type=str, required=True, choices=['cifar','synth']) + p.add_argument('--method', type=str, required=True) + p.add_argument('--seed', type=int, required=True) + p.add_argument('--alpha', type=float, default=1.0) + p.add_argument('--depth', type=int, default=4) + p.add_argument('--gpu', type=int, default=0) + p.add_argument('--output_dir', type=str, default='results/confirmatory/clean_sparsity') + args = p.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + device = torch.device(f'cuda:{args.gpu}') + thresholds = [1e-8, 1e-7, 1e-6, 1e-5, 1e-4] + + if args.dataset == 'cifar': + L, d, C = 4, 256, 10 + tv = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))]) + tel = DataLoader(torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv), + 256, False, num_workers=0) + for x, y in tel: x = x.view(x.size(0),-1).to(device); y = y.to(device); break + ckpt = f'results/confirmatory/checkpoints_A2/{args.method}_s{args.seed}.pt' + model = ResidualMLP(3072, d, C, L).to(device) + model.load_state_dict(torch.load(ckpt, map_location=device), strict=True) + model.eval() + is_cifar = True + else: + L, d, C = args.depth, 128, 10 + teacher = TeacherNet(d, C, L, args.alpha, seed=0).to(device) + torch.manual_seed(args.seed + 10000) + x = torch.randn(512, d, device=device) + with torch.no_grad(): y = teacher(x).argmax(-1) + ckpt = f'results/confirmatory/checkpoints_A1/a{args.alpha}_L{L}_{args.method}_s{args.seed}.pt' + torch.manual_seed(args.seed) + model = StudentNet(d, C, L, args.alpha).to(device) + model.load_state_dict(torch.load(ckpt, map_location=device), strict=True) + model.eval() + is_cifar = False + + batch = x.size(0) + Lm = model.num_blocks + dm = d if not is_cifar else 256 + + # BP gradients + if is_cifar: + h0 = model.embed(x.detach()) + else: + h0 = x.detach() + hs = [h0.clone().requires_grad_(True)] + for b in model.blocks: hs.append(hs[-1] + b(hs[-1])) + if is_cifar: + lo = model.out_head(model.out_ln(hs[-1])) + else: + lo = model.out_head(hs[-1]) + loss = F.cross_entropy(lo, y) + acc = (lo.argmax(1) == y).float().mean().item() + gs = torch.autograd.grad(loss, hs) + bp = {l: gs[l].detach() for l in range(Lm)} + + # DFA Bs (for Gamma) + if is_cifar: + torch.manual_seed(args.seed); _ = ResidualMLP(3072, dm, C, Lm) + else: + torch.manual_seed(args.seed); _ = StudentNet(d, C, Lm, args.alpha) + dfa_Bs = [torch.randn(dm, C, device=device)/np.sqrt(C) for _ in range(Lm)] + with torch.no_grad(): + if is_cifar: + logits = model(x) + else: + logits = model(x) + e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 + + result = { + 'dataset': args.dataset, 'method': args.method, 'seed': args.seed, + 'alpha': args.alpha if args.dataset == 'synth' else None, + 'depth': L, 'batch': batch, 'loss': loss.item(), 'acc': acc, + 'per_layer': [] + } + + for l in range(Lm): + g = bp[l] + norms = g.norm(dim=-1) # (batch,) + log_norms = torch.log10(norms.clamp(min=1e-30)).cpu().numpy() + + # Support fractions + support = {} + for tau in thresholds: + support[str(tau)] = (norms > tau).float().mean().item() + + # Element-wise concentration + ninf = g.abs().max(dim=-1).values + n2 = norms.clamp(min=1e-30) + n4 = (g.abs()**4).sum(-1)**(1/4) + n1 = g.abs().sum(-1) + r_inf = (ninf / n2) + pr = (n2**4 / (n4**4).clamp(min=1e-60)) / dm + hoyer = (n1 / (n2 * dm**0.5).clamp(min=1e-30))**2 + eff_dim = n1**2 / (n.pow(2).sum(-1) * dm).clamp(min=1e-60) if False else n1**2 / ((g**2).sum(-1) * dm).clamp(min=1e-60) + gsq = g**2; te = gsq.sum(-1, keepdim=True).clamp(min=1e-60) + ssq, _ = gsq.sort(dim=-1, descending=True); cs = ssq.cumsum(-1) + topk = {} + for k in [1, 5, 10, 25]: + idx = max(1, int(dm * k / 100)) - 1 + topk[str(k)] = (cs[:, idx:idx+1] / te).squeeze(-1).mean().item() + + # Gamma (DFA vs BP) — active subset + a_dfa = (e_T @ dfa_Bs[l].T).detach() + cos_all = F.cosine_similarity(a_dfa, g, dim=-1) + gamma_raw = cos_all.mean().item() + gamma_active = {}; gamma_ew = {} + for tau in thresholds: + mask = norms > tau + gamma_active[str(tau)] = cos_all[mask].mean().item() if mask.sum() > 0 else None + w = norms**2 + gamma_ew[str(tau)] = (cos_all * w).sum().item() / (w.sum().item() + 1e-20) + + layer_data = { + 'layer': l, + 'mean_norm': norms.mean().item(), + 'median_norm': norms.median().item(), + 'max_norm': norms.max().item(), + 'min_norm': norms.min().item(), + 'support': support, + 'log_norms_percentiles': {str(p): float(np.percentile(log_norms, p)) for p in [1,5,10,25,50,75,90,95,99]}, + 'log_norms_histogram': np.histogram(log_norms, bins=50)[0].tolist(), + 'log_norms_bin_edges': np.histogram(log_norms, bins=50)[1].tolist(), + 'r_inf_mean': r_inf.mean().item(), 'r_inf_median': r_inf.median().item(), + 'pr_mean': pr.mean().item(), 'pr_median': pr.median().item(), + 'hoyer_mean': hoyer.mean().item(), + 'eff_dim_mean': eff_dim.mean().item(), + 'topk_energy': topk, + 'gamma_raw': gamma_raw, + 'gamma_active': gamma_active, + 'gamma_energy_weighted': gamma_ew, + } + result['per_layer'].append(layer_data) + + # Summary print + tag = f"{args.dataset}_{args.method}_s{args.seed}" + if args.dataset == 'synth': tag += f"_a{args.alpha}_L{L}" + print(f"[{tag}] acc={acc:.4f} loss={loss.item():.4f}", flush=True) + for ld in result['per_layer']: + l = ld['layer'] + print(f" L{l}: norm={ld['mean_norm']:.2e} s(1e-6)={ld['support']['1e-06']:.4f} " + f"r_inf={ld['r_inf_mean']:.4f} PR={ld['pr_mean']:.4f} " + f"top1%={ld['topk_energy']['1']:.4f} Gr={ld['gamma_raw']:.4f}", flush=True) + + out = os.path.join(args.output_dir, f'{tag}.json') + with open(out, 'w') as f: + json.dump(result, f, indent=2, default=float) + print(f" -> {out}", flush=True) + +if __name__ == '__main__': + main() |
