summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-01 12:59:01 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-01 12:59:01 -0500
commit8eb04f011e7092a8d2e6c89800721a00112fd384 (patch)
tree8f68872c79b70a28b3ee7f1f1191e3a00c7c8053 /experiments
parentd5326053a2e9ce37dd61606aa37fa8f563481f44 (diff)
Add clean_sparsity_full.py: independent-process full sparsity analysis
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/clean_sparsity_full.py202
1 files changed, 202 insertions, 0 deletions
diff --git a/experiments/clean_sparsity_full.py b/experiments/clean_sparsity_full.py
new file mode 100644
index 0000000..7a97bda
--- /dev/null
+++ b/experiments/clean_sparsity_full.py
@@ -0,0 +1,202 @@
+"""
+Clean full sparsity analysis — one method+seed per invocation.
+Usage: python clean_sparsity_full.py --dataset cifar --method bp --seed 42 --gpu 0
+ python clean_sparsity_full.py --dataset synth --method bp --seed 42 --alpha 1.0 --depth 4 --gpu 0
+"""
+import os, sys, json, argparse, numpy as np, torch, torch.nn.functional as F
+from torch.utils.data import DataLoader, TensorDataset
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+import torchvision, torchvision.transforms as transforms
+
class StudentBlock(torch.nn.Module):
    """Residual-branch block: LayerNorm, blended activation, bias-free linear map.

    The branch output is ``w(act(ln(h)))`` where
    ``act(z) = (1 - alpha) * z + alpha * tanh(z)``; ``alpha=0`` gives a purely
    linear branch and ``alpha=1`` a purely ``tanh`` branch.  The skip
    connection is added by the caller, not here.

    Args:
        d: Feature width of the input and output.
        alpha: Mixing weight between the identity and ``tanh`` activations.
    """

    def __init__(self, d, alpha=1.0):
        super().__init__()
        # Attribute names `ln` and `w` are part of the state-dict keys; keep
        # them stable so existing checkpoints still load with strict=True.
        self.ln = torch.nn.LayerNorm(d)
        self.w = torch.nn.Linear(d, d, bias=False)
        # Overwrite the default Linear init with a small Gaussian.
        torch.nn.init.normal_(self.w.weight, std=0.01)
        self.alpha = alpha

    def forward(self, h):
        """Return the branch output for ``h`` (same trailing width ``d``)."""
        # Normalise once and reuse: both activation terms share the same input.
        z = self.ln(h)
        return self.w((1 - self.alpha) * z + self.alpha * torch.tanh(z))
+
class StudentNet(torch.nn.Module):
    """Stack of ``L`` residual ``StudentBlock``s followed by a linear read-out.

    Args:
        d: Hidden width (the input is expected to already have this width).
        C: Number of output logits.
        L: Number of residual blocks.
        alpha: Activation mixing weight forwarded to every block.
    """

    def __init__(self, d, C, L, alpha=1.0):
        super().__init__()
        # Blocks are created before the head so parameter-init RNG draws
        # happen in the same order as before.
        stack = []
        for _ in range(L):
            stack.append(StudentBlock(d, alpha))
        self.blocks = torch.nn.ModuleList(stack)
        self.out_head = torch.nn.Linear(d, C)
        self.num_blocks = L
        self.d_hidden = d

    def forward(self, x, return_hidden=False):
        """Run the network.

        Returns the logits, or ``(logits, hidden)`` when ``return_hidden`` is
        true, where ``hidden`` is the list ``[x, h_1, ..., h_L]`` of residual
        stream states (input included).
        """
        state = x
        trace = None
        if return_hidden:
            trace = [state]
        for block in self.blocks:
            state = state + block(state)
            if return_hidden:
                trace.append(state)
        logits = self.out_head(state)
        if return_hidden:
            return (logits, trace)
        return logits
+
class TeacherNet(torch.nn.Module):
    """Frozen random residual network used to label synthetic inputs.

    All weights are drawn from a private ``torch.Generator`` seeded with
    ``seed`` and registered as parameters with ``requires_grad=False``.
    Each of the ``L`` layer matrices has its singular values clipped to at
    most 0.3; the read-out ``U`` has shape ``(C, d)``.

    Args:
        d: Feature width.
        C: Number of output logits.
        L: Number of residual layers.
        alpha: Mixing weight between identity and ``tanh`` activations.
        seed: Seed for the private weight generator.
    """

    def __init__(self, d, C, L, alpha=1.0, seed=0):
        super().__init__()
        self.alpha = alpha
        gen = torch.Generator().manual_seed(seed)
        root_d = d ** 0.5

        layers = []
        for _ in range(L):
            raw = torch.randn(d, d, generator=gen) * 0.3 / root_d
            left, sing, right_h = torch.linalg.svd(raw, full_matrices=False)
            # Rebuild the matrix with its spectrum capped at 0.3.
            clipped = left @ torch.diag(sing.clamp(max=0.3)) @ right_h
            layers.append(torch.nn.Parameter(clipped, requires_grad=False))
        self.Ws = torch.nn.ParameterList(layers)

        # The read-out is drawn after all layer matrices from the same generator.
        readout = torch.randn(C, d, generator=gen) / root_d
        self.U = torch.nn.Parameter(readout, requires_grad=False)

    def forward(self, x):
        """Return teacher logits for ``x`` (trailing width ``d`` -> ``C``)."""
        state = x
        for weight in self.Ws:
            mixed = (1 - self.alpha) * state + self.alpha * torch.tanh(state)
            state = state + mixed @ weight.T
        return state @ self.U.T
+
def _layer_stats(layer, g, a_dfa, thresholds, dm):
    """Summarise one layer's per-sample BP gradients and their DFA alignment.

    Args:
        layer: Layer index, stored under the ``'layer'`` key.
        g: Per-sample gradient tensor; statistics reduce over the last axis,
            whose size is taken to be ``dm``.
        a_dfa: DFA feedback signal, same shape as ``g``.
        thresholds: Norm thresholds ``tau`` defining the "active" subset.
        dm: Feature width used to normalise the concentration metrics.

    Returns:
        A JSON-serialisable dict with norm statistics, support fractions,
        element-wise concentration metrics and cosine alignments.
    """
    norms = g.norm(dim=-1)  # one norm per sample
    log_norms = torch.log10(norms.clamp(min=1e-30)).cpu().numpy()

    # Fraction of samples whose gradient norm exceeds each threshold.
    support = {str(tau): (norms > tau).float().mean().item() for tau in thresholds}

    # Element-wise concentration.  Every metric below is a scale-invariant
    # ratio, so each row is first divided by its max-abs entry and the
    # arithmetic is done in float64.  In float32 the 1e-60 guards round to
    # zero and |g|**4 underflows for small-norm rows, yielding NaN/inf.
    g64 = g.double()
    peak = g64.abs().max(dim=-1, keepdim=True).values
    u = g64 / peak.clamp(min=torch.finfo(torch.float64).tiny)
    u_abs = u.abs()
    u_sq = u * u
    e2 = u_sq.sum(-1)              # squared L2 norm of the rescaled row
    l2 = e2.sqrt()
    l1 = u_abs.sum(-1)
    e4 = (u_sq * u_sq).sum(-1)     # fourth power of the L4 norm
    # All-zero rows give 0 for every metric (numerator 0, guarded denominator).
    r_inf = u_abs.max(dim=-1).values / l2.clamp(min=1e-30)
    pr = e2 ** 2 / e4.clamp(min=1e-60) / dm
    # Reported as 'hoyer': the squared L1/L2 ratio normalised by dm.
    hoyer = (l1 / (l2 * dm ** 0.5).clamp(min=1e-30)) ** 2
    eff_dim = l1 ** 2 / (e2 * dm).clamp(min=1e-60)

    # Share of squared-gradient energy held by the largest k% of coordinates.
    sorted_sq, _ = u_sq.sort(dim=-1, descending=True)
    cum_energy = sorted_sq.cumsum(-1)
    total_energy = e2.clamp(min=1e-60)
    topk = {}
    for k in [1, 5, 10, 25]:
        idx = max(1, int(dm * k / 100)) - 1
        topk[str(k)] = (cum_energy[:, idx] / total_energy).mean().item()

    # Gamma: cosine alignment between the DFA signal and the BP gradient.
    cos_all = F.cosine_similarity(a_dfa, g, dim=-1)
    gamma_raw = cos_all.mean().item()
    # The energy-weighted alignment does not depend on tau; compute it once
    # and repeat it under every threshold key to keep the output schema.
    w = norms ** 2
    energy_weighted = (cos_all * w).sum().item() / (w.sum().item() + 1e-20)
    gamma_active = {}
    gamma_ew = {}
    for tau in thresholds:
        mask = norms > tau
        gamma_active[str(tau)] = cos_all[mask].mean().item() if mask.sum() > 0 else None
        gamma_ew[str(tau)] = energy_weighted

    hist_counts, hist_edges = np.histogram(log_norms, bins=50)

    return {
        'layer': layer,
        'mean_norm': norms.mean().item(),
        'median_norm': norms.median().item(),
        'max_norm': norms.max().item(),
        'min_norm': norms.min().item(),
        'support': support,
        'log_norms_percentiles': {
            str(q): float(np.percentile(log_norms, q))
            for q in [1, 5, 10, 25, 50, 75, 90, 95, 99]
        },
        'log_norms_histogram': hist_counts.tolist(),
        'log_norms_bin_edges': hist_edges.tolist(),
        'r_inf_mean': r_inf.mean().item(), 'r_inf_median': r_inf.median().item(),
        'pr_mean': pr.mean().item(), 'pr_median': pr.median().item(),
        'hoyer_mean': hoyer.mean().item(),
        'eff_dim_mean': eff_dim.mean().item(),
        'topk_energy': topk,
        'gamma_raw': gamma_raw,
        'gamma_active': gamma_active,
        'gamma_energy_weighted': gamma_ew,
    }


def main():
    """Run the sparsity analysis for one (dataset, method, seed) checkpoint.

    Loads a trained checkpoint, computes back-propagated gradients with
    respect to the residual stream at each block input on a single batch,
    summarises their sparsity/concentration and their alignment with random
    DFA feedback, prints a short summary and writes ``<tag>.json`` into
    ``--output_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, required=True, choices=['cifar', 'synth'])
    parser.add_argument('--method', type=str, required=True)
    parser.add_argument('--seed', type=int, required=True)
    parser.add_argument('--alpha', type=float, default=1.0)
    parser.add_argument('--depth', type=int, default=4)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--output_dir', type=str, default='results/confirmatory/clean_sparsity')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device(f'cuda:{args.gpu}')
    thresholds = [1e-8, 1e-7, 1e-6, 1e-5, 1e-4]
    is_cifar = args.dataset == 'cifar'

    if is_cifar:
        L, d, C = 4, 256, 10
        tv = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
        ])
        tel = DataLoader(
            torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv),
            256, False, num_workers=0)
        # Only the first (unshuffled) test batch is analysed.
        x, y = next(iter(tel))
        x = x.view(x.size(0), -1).to(device)
        y = y.to(device)
        ckpt = f'results/confirmatory/checkpoints_A2/{args.method}_s{args.seed}.pt'
        model = ResidualMLP(3072, d, C, L).to(device)
    else:
        L, d, C = args.depth, 128, 10
        teacher = TeacherNet(d, C, L, args.alpha, seed=0).to(device)
        torch.manual_seed(args.seed + 10000)
        x = torch.randn(512, d, device=device)
        with torch.no_grad():
            y = teacher(x).argmax(-1)
        ckpt = f'results/confirmatory/checkpoints_A1/a{args.alpha}_L{L}_{args.method}_s{args.seed}.pt'
        torch.manual_seed(args.seed)
        model = StudentNet(d, C, L, args.alpha).to(device)
    # NOTE(security): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True if the installed torch has it).
    model.load_state_dict(torch.load(ckpt, map_location=device), strict=True)
    model.eval()

    batch = x.size(0)
    Lm = model.num_blocks
    dm = d

    # BP gradients w.r.t. the residual stream at every block boundary.
    # The CIFAR model is used through its embed / blocks / out_ln / out_head
    # attributes; the synthetic student has no embedding or output norm.
    h0 = model.embed(x.detach()) if is_cifar else x.detach()
    hs = [h0.clone().requires_grad_(True)]
    for b in model.blocks:
        hs.append(hs[-1] + b(hs[-1]))
    lo = model.out_head(model.out_ln(hs[-1])) if is_cifar else model.out_head(hs[-1])
    loss = F.cross_entropy(lo, y)
    acc = (lo.argmax(1) == y).float().mean().item()
    gs = torch.autograd.grad(loss, hs)
    # Only the Lm block inputs are kept; the final state's gradient is unused.
    bp = {l: gs[l].detach() for l in range(Lm)}

    # DFA Bs (for Gamma).
    # NOTE(review): the throwaway model is built after re-seeding, presumably
    # to replay the training script's RNG consumption so the same feedback
    # matrices are regenerated -- confirm against the trainer.
    torch.manual_seed(args.seed)
    if is_cifar:
        _ = ResidualMLP(3072, dm, C, Lm)
    else:
        _ = StudentNet(d, C, Lm, args.alpha)
    dfa_Bs = [torch.randn(dm, C, device=device) / np.sqrt(C) for _ in range(Lm)]
    with torch.no_grad():
        logits = model(x)
        # Output error signal: softmax minus one-hot target.
        e_T = logits.softmax(-1)
        e_T[torch.arange(batch, device=device), y] -= 1

    result = {
        'dataset': args.dataset, 'method': args.method, 'seed': args.seed,
        'alpha': args.alpha if args.dataset == 'synth' else None,
        'depth': L, 'batch': batch, 'loss': loss.item(), 'acc': acc,
        'per_layer': []
    }

    for l in range(Lm):
        a_dfa = (e_T @ dfa_Bs[l].T).detach()
        result['per_layer'].append(_layer_stats(l, bp[l], a_dfa, thresholds, dm))

    # Summary print
    tag = f"{args.dataset}_{args.method}_s{args.seed}"
    if args.dataset == 'synth':
        tag += f"_a{args.alpha}_L{L}"
    print(f"[{tag}] acc={acc:.4f} loss={loss.item():.4f}", flush=True)
    for ld in result['per_layer']:
        l = ld['layer']
        print(f" L{l}: norm={ld['mean_norm']:.2e} s(1e-6)={ld['support']['1e-06']:.4f} "
              f"r_inf={ld['r_inf_mean']:.4f} PR={ld['pr_mean']:.4f} "
              f"top1%={ld['topk_energy']['1']:.4f} Gr={ld['gamma_raw']:.4f}", flush=True)

    out = os.path.join(args.output_dir, f'{tag}.json')
    with open(out, 'w') as f:
        json.dump(result, f, indent=2, default=float)
    print(f" -> {out}", flush=True)


if __name__ == '__main__':
    main()