""" Clean full sparsity analysis — one method+seed per invocation. Usage: python clean_sparsity_full.py --dataset cifar --method bp --seed 42 --gpu 0 python clean_sparsity_full.py --dataset synth --method bp --seed 42 --alpha 1.0 --depth 4 --gpu 0 """ import os, sys, json, argparse, numpy as np, torch, torch.nn.functional as F from torch.utils.data import DataLoader, TensorDataset sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.residual_mlp import ResidualMLP import torchvision, torchvision.transforms as transforms class StudentBlock(torch.nn.Module): def __init__(self, d, alpha=1.0): super().__init__() self.ln=torch.nn.LayerNorm(d);self.w=torch.nn.Linear(d,d,bias=False) torch.nn.init.normal_(self.w.weight,std=0.01);self.alpha=alpha def forward(self, h): return self.w(((1-self.alpha)*self.ln(h)+self.alpha*torch.tanh(self.ln(h)))) class StudentNet(torch.nn.Module): def __init__(self, d, C, L, alpha=1.0): super().__init__() self.blocks=torch.nn.ModuleList([StudentBlock(d,alpha) for _ in range(L)]) self.out_head=torch.nn.Linear(d,C);self.num_blocks=L;self.d_hidden=d def forward(self, x, return_hidden=False): h=x;hi=[h] if return_hidden else None for b in self.blocks: h=h+b(h) if return_hidden:hi.append(h) lo=self.out_head(h) return (lo,hi) if return_hidden else lo class TeacherNet(torch.nn.Module): def __init__(self, d, C, L, alpha=1.0, seed=0): super().__init__() self.alpha=alpha;rng=torch.Generator().manual_seed(seed) self.Ws=torch.nn.ParameterList() for _ in range(L): W=torch.randn(d,d,generator=rng)*0.3/(d**0.5) U,S,Vh=torch.linalg.svd(W,full_matrices=False) self.Ws.append(torch.nn.Parameter(U@torch.diag(S.clamp(max=0.3))@Vh,requires_grad=False)) self.U=torch.nn.Parameter(torch.randn(C,d,generator=rng)/(d**0.5),requires_grad=False) def forward(self, x): h=x for W in self.Ws:h=h+((1-self.alpha)*h+self.alpha*torch.tanh(h))@W.T return h@self.U.T def main(): p = argparse.ArgumentParser() p.add_argument('--dataset', type=str, required=True, choices=['cifar','synth']) p.add_argument('--method', type=str, required=True) p.add_argument('--seed', type=int, required=True) p.add_argument('--alpha', type=float, default=1.0) p.add_argument('--depth', type=int, default=4) p.add_argument('--gpu', type=int, default=0) p.add_argument('--output_dir', type=str, default='results/confirmatory/clean_sparsity') args = p.parse_args() os.makedirs(args.output_dir, exist_ok=True) device = torch.device(f'cuda:{args.gpu}') thresholds = [1e-8, 1e-7, 1e-6, 1e-5, 1e-4] if args.dataset == 'cifar': L, d, C = 4, 256, 10 tv = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914,0.4822,0.4465),(0.2470,0.2435,0.2616))]) tel = DataLoader(torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv), 256, False, num_workers=0) for x, y in tel: x = x.view(x.size(0),-1).to(device); y = y.to(device); break ckpt = f'results/confirmatory/checkpoints_A2/{args.method}_s{args.seed}.pt' model = ResidualMLP(3072, d, C, L).to(device) model.load_state_dict(torch.load(ckpt, map_location=device), strict=True) model.eval() is_cifar = True else: L, d, C = args.depth, 128, 10 teacher = TeacherNet(d, C, L, args.alpha, seed=0).to(device) torch.manual_seed(args.seed + 10000) x = torch.randn(512, d, device=device) with torch.no_grad(): y = teacher(x).argmax(-1) ckpt = f'results/confirmatory/checkpoints_A1/a{args.alpha}_L{L}_{args.method}_s{args.seed}.pt' torch.manual_seed(args.seed) model = StudentNet(d, C, L, args.alpha).to(device) model.load_state_dict(torch.load(ckpt, map_location=device), strict=True) model.eval() is_cifar = False batch = x.size(0) Lm = model.num_blocks dm = d if not is_cifar else 256 # BP gradients if is_cifar: h0 = model.embed(x.detach()) else: h0 = x.detach() hs = [h0.clone().requires_grad_(True)] for b in model.blocks: hs.append(hs[-1] + b(hs[-1])) if is_cifar: lo = model.out_head(model.out_ln(hs[-1])) else: lo = model.out_head(hs[-1]) loss = F.cross_entropy(lo, y) acc = (lo.argmax(1) == y).float().mean().item() gs = torch.autograd.grad(loss, hs) bp = {l: gs[l].detach() for l in range(Lm)} # DFA Bs (for Gamma) if is_cifar: torch.manual_seed(args.seed); _ = ResidualMLP(3072, dm, C, Lm) else: torch.manual_seed(args.seed); _ = StudentNet(d, C, Lm, args.alpha) dfa_Bs = [torch.randn(dm, C, device=device)/np.sqrt(C) for _ in range(Lm)] with torch.no_grad(): if is_cifar: logits = model(x) else: logits = model(x) e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1 result = { 'dataset': args.dataset, 'method': args.method, 'seed': args.seed, 'alpha': args.alpha if args.dataset == 'synth' else None, 'depth': L, 'batch': batch, 'loss': loss.item(), 'acc': acc, 'per_layer': [] } for l in range(Lm): g = bp[l] norms = g.norm(dim=-1) # (batch,) log_norms = torch.log10(norms.clamp(min=1e-30)).cpu().numpy() # Support fractions support = {} for tau in thresholds: support[str(tau)] = (norms > tau).float().mean().item() # Element-wise concentration ninf = g.abs().max(dim=-1).values n2 = norms.clamp(min=1e-30) n4 = (g.abs()**4).sum(-1)**(1/4) n1 = g.abs().sum(-1) r_inf = (ninf / n2) pr = (n2**4 / (n4**4).clamp(min=1e-60)) / dm hoyer = (n1 / (n2 * dm**0.5).clamp(min=1e-30))**2 eff_dim = n1**2 / (n.pow(2).sum(-1) * dm).clamp(min=1e-60) if False else n1**2 / ((g**2).sum(-1) * dm).clamp(min=1e-60) gsq = g**2; te = gsq.sum(-1, keepdim=True).clamp(min=1e-60) ssq, _ = gsq.sort(dim=-1, descending=True); cs = ssq.cumsum(-1) topk = {} for k in [1, 5, 10, 25]: idx = max(1, int(dm * k / 100)) - 1 topk[str(k)] = (cs[:, idx:idx+1] / te).squeeze(-1).mean().item() # Gamma (DFA vs BP) — active subset a_dfa = (e_T @ dfa_Bs[l].T).detach() cos_all = F.cosine_similarity(a_dfa, g, dim=-1) gamma_raw = cos_all.mean().item() gamma_active = {}; gamma_ew = {} for tau in thresholds: mask = norms > tau gamma_active[str(tau)] = cos_all[mask].mean().item() if mask.sum() > 0 else None w = norms**2 gamma_ew[str(tau)] = (cos_all * w).sum().item() / (w.sum().item() + 1e-20) layer_data = { 'layer': l, 'mean_norm': norms.mean().item(), 'median_norm': norms.median().item(), 'max_norm': norms.max().item(), 'min_norm': norms.min().item(), 'support': support, 'log_norms_percentiles': {str(p): float(np.percentile(log_norms, p)) for p in [1,5,10,25,50,75,90,95,99]}, 'log_norms_histogram': np.histogram(log_norms, bins=50)[0].tolist(), 'log_norms_bin_edges': np.histogram(log_norms, bins=50)[1].tolist(), 'r_inf_mean': r_inf.mean().item(), 'r_inf_median': r_inf.median().item(), 'pr_mean': pr.mean().item(), 'pr_median': pr.median().item(), 'hoyer_mean': hoyer.mean().item(), 'eff_dim_mean': eff_dim.mean().item(), 'topk_energy': topk, 'gamma_raw': gamma_raw, 'gamma_active': gamma_active, 'gamma_energy_weighted': gamma_ew, } result['per_layer'].append(layer_data) # Summary print tag = f"{args.dataset}_{args.method}_s{args.seed}" if args.dataset == 'synth': tag += f"_a{args.alpha}_L{L}" print(f"[{tag}] acc={acc:.4f} loss={loss.item():.4f}", flush=True) for ld in result['per_layer']: l = ld['layer'] print(f" L{l}: norm={ld['mean_norm']:.2e} s(1e-6)={ld['support']['1e-06']:.4f} " f"r_inf={ld['r_inf_mean']:.4f} PR={ld['pr_mean']:.4f} " f"top1%={ld['topk_energy']['1']:.4f} Gr={ld['gamma_raw']:.4f}", flush=True) out = os.path.join(args.output_dir, f'{tag}.json') with open(out, 'w') as f: json.dump(result, f, indent=2, default=float) print(f" -> {out}", flush=True) if __name__ == '__main__': main()