Diffstat (limited to 'files/experiments/lyapunov_speedup_benchmark.py')
| -rw-r--r-- | files/experiments/lyapunov_speedup_benchmark.py | 638 |
1 files changed, 638 insertions, 0 deletions
diff --git a/files/experiments/lyapunov_speedup_benchmark.py b/files/experiments/lyapunov_speedup_benchmark.py
new file mode 100644
index 0000000..117009b
--- /dev/null
+++ b/files/experiments/lyapunov_speedup_benchmark.py
@@ -0,0 +1,638 @@
+"""
+Lyapunov Computation Speedup Benchmark
+
+Tests different optimization approaches for computing Lyapunov exponents
+during SNN training. All approaches should produce equivalent results
+(within numerical precision) but with different performance characteristics.
+
+Approaches tested:
+- Baseline: Current sequential implementation
+- Approach A: Trajectory-as-batch (P=2), share first Linear
+- Approach B: Global-norm divergence + single-scale renorm
+- Approach C: torch.compile the time loop
+- Combined: A + B + C together
+"""
+
+import os
+import sys
+import time
+from typing import Tuple, Optional, List
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import snntorch as snn
+from snntorch import surrogate
+
+# Ensure we can import from project
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+    sys.path.insert(0, _ROOT)
+
+
+# =============================================================================
+# Baseline Implementation (Current)
+# =============================================================================
+
+class BaselineSNN(nn.Module):
+    """Current implementation: sequential perturbed trajectory."""
+
+    def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
+        super().__init__()
+        self.T = T
+        self.hidden_dims = hidden_dims
+        spike_grad = surrogate.fast_sigmoid(slope=25)
+
+        # Simple feedforward for benchmarking (not full VGG)
+        self.linears = nn.ModuleList()
+        self.lifs = nn.ModuleList()
+
+        dims = [in_channels * 32 * 32] + hidden_dims  # Flattened input
+        for i in range(len(hidden_dims)):
+            self.linears.append(nn.Linear(dims[i], dims[i+1]))
+            self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
+                                       spike_grad=spike_grad, init_hidden=False))
+
+        self.readout = nn.Linear(hidden_dims[-1], 10)
+
+    def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+        B = x.size(0)
+        device, dtype = x.device, x.dtype
+        x = x.view(B, -1)  # Flatten
+
+        # Init membrane potentials
+        mems = [torch.zeros(B, h, device=device, dtype=dtype) for h in self.hidden_dims]
+
+        if compute_lyapunov:
+            mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+            lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+        spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+        for t in range(self.T):
+            # Original trajectory
+            h = x
+            new_mems = []
+            for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
+                h = lin(h)
+                spk, mem = lif(h, mems[i])
+                new_mems.append(mem)
+                h = spk
+            mems = new_mems
+            spike_sum = spike_sum + h
+
+            if compute_lyapunov:
+                # Perturbed trajectory (SEPARATE PASS - this is slow)
+                h_p = x
+                new_mems_p = []
+                for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
+                    h_p = lin(h_p)
+                    spk_p, mem_p = lif(h_p, mems_p[i])
+                    new_mems_p.append(mem_p)
+                    h_p = spk_p
+
+                # Divergence (per-layer norms, then sum)
+                delta_sq = torch.zeros(B, device=device, dtype=dtype)
+                for i in range(len(self.hidden_dims)):
+                    diff = new_mems_p[i] - new_mems[i]
+                    delta_sq += (diff ** 2).sum(dim=1)
+
+                delta = torch.sqrt(delta_sq + 1e-12)
+                lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+                # Renormalize (per-layer - SLOW)
+                for i in range(len(self.hidden_dims)):
+                    diff = new_mems_p[i] - new_mems[i]
+                    norm = torch.norm(diff, dim=1, keepdim=True) + 1e-12
+                    new_mems_p[i] = new_mems[i] + lyap_eps * diff / norm
+
+                mems_p = new_mems_p
+
+        logits = self.readout(spike_sum)
+        lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+        return logits, lyap_est
+
+
+# =============================================================================
+# Approach A: Trajectory-as-batch (P=2), share first Linear
+# =============================================================================
+
+class ApproachA_SNN(nn.Module):
+    """Batch both trajectories together, share Linear_1."""
+
+    def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
+        super().__init__()
+        self.T = T
+        self.hidden_dims = hidden_dims
+        spike_grad = surrogate.fast_sigmoid(slope=25)
+
+        self.linears = nn.ModuleList()
+        self.lifs = nn.ModuleList()
+
+        dims = [in_channels * 32 * 32] + hidden_dims
+        for i in range(len(hidden_dims)):
+            self.linears.append(nn.Linear(dims[i], dims[i+1]))
+            self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
+                                       spike_grad=spike_grad, init_hidden=False))
+
+        self.readout = nn.Linear(hidden_dims[-1], 10)
+
+    def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+        B = x.size(0)
+        device, dtype = x.device, x.dtype
+        x = x.view(B, -1)
+
+        P = 2 if compute_lyapunov else 1
+
+        # State layout: (P, B, H) where P=2 for [original, perturbed]
+        mems = [torch.zeros(P, B, h, device=device, dtype=dtype) for h in self.hidden_dims]
+
+        if compute_lyapunov:
+            # Initialize perturbed state
+            for i in range(len(self.hidden_dims)):
+                mems[i][1] = mems[i][0] + lyap_eps * torch.randn(B, self.hidden_dims[i], device=device, dtype=dtype)
+            lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+        spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+        for t in range(self.T):
+            # Layer 1: compute Linear ONCE, expand to (P, B, H1)
+            h1 = self.linears[0](x)  # (B, H1) - computed ONCE
+
+            if compute_lyapunov:
+                h = h1.unsqueeze(0).expand(P, -1, -1)  # (P, B, H1) - zero-copy view
+            else:
+                h = h1.unsqueeze(0)  # (1, B, H1)
+
+            # LIF layer 1
+            spk, mems[0] = self.lifs[0](h, mems[0])
+            h = spk
+
+            # Layers 2+: different inputs for each trajectory
+            for i in range(1, len(self.hidden_dims)):
+                # Reshape to (P*B, H) for batched Linear
+                h_flat = h.reshape(P * B, -1)
+                h_lin = self.linears[i](h_flat).view(P, B, self.hidden_dims[i])
+                spk, mems[i] = self.lifs[i](h_lin, mems[i])
+                h = spk
+
+            # Accumulate spikes from original trajectory only
+            spike_sum = spike_sum + h[0]
+
+            if compute_lyapunov:
+                # Global divergence across all layers
+                delta_sq = torch.zeros(B, device=device, dtype=dtype)
+                for i in range(len(self.hidden_dims)):
+                    diff = mems[i][1] - mems[i][0]  # (B, H_i)
+                    delta_sq = delta_sq + diff.square().sum(dim=-1)
+
+                delta = (delta_sq + 1e-12).sqrt()
+                lyap_accum = lyap_accum + (delta / lyap_eps).log()
+
+                # Renormalize (still per-layer norms here; the global-scale renorm is Approach B)
+                for i in range(len(self.hidden_dims)):
+                    diff = mems[i][1] - mems[i][0]
+                    norm = torch.norm(diff, dim=1, keepdim=True) + 1e-12
+                    mems[i][1] = mems[i][0] + lyap_eps * diff / norm
+
+        logits = self.readout(spike_sum)
+        lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+        return logits, lyap_est
+
+
+# =============================================================================
+# Approach B: Global-norm divergence + single-scale renorm
+# =============================================================================
+
+class ApproachB_SNN(nn.Module):
+    """Global norm for divergence, single scale factor for renorm."""
+
+    def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
+        super().__init__()
+        self.T = T
+        self.hidden_dims = hidden_dims
+        spike_grad = surrogate.fast_sigmoid(slope=25)
+
+        self.linears = nn.ModuleList()
+        self.lifs = nn.ModuleList()
+
+        dims = [in_channels * 32 * 32] + hidden_dims
+        for i in range(len(hidden_dims)):
+            self.linears.append(nn.Linear(dims[i], dims[i+1]))
+            self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
+                                       spike_grad=spike_grad, init_hidden=False))
+
+        self.readout = nn.Linear(hidden_dims[-1], 10)
+
+    def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+        B = x.size(0)
+        device, dtype = x.device, x.dtype
+        x = x.view(B, -1)
+
+        mems = [torch.zeros(B, h, device=device, dtype=dtype) for h in self.hidden_dims]
+
+        if compute_lyapunov:
+            mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+            lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+        spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+        for t in range(self.T):
+            # Original trajectory
+            h = x
+            new_mems = []
+            for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
+                h = lin(h)
+                spk, mem = lif(h, mems[i])
+                new_mems.append(mem)
+                h = spk
+            mems = new_mems
+            spike_sum = spike_sum + h
+
+            if compute_lyapunov:
+                # Perturbed trajectory
+                h_p = x
+                new_mems_p = []
+                for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
+                    h_p = lin(h_p)
+                    spk_p, mem_p = lif(h_p, mems_p[i])
+                    new_mems_p.append(mem_p)
+                    h_p = spk_p
+
+                # GLOBAL divergence (one delta per batch element)
+                delta_sq = torch.zeros(B, device=device, dtype=dtype)
+                for i in range(len(self.hidden_dims)):
+                    diff = new_mems_p[i] - new_mems[i]
+                    delta_sq = delta_sq + diff.square().sum(dim=-1)
+
+                delta = (delta_sq + 1e-12).sqrt()
+                lyap_accum = lyap_accum + (delta / lyap_eps).log()
+
+                # SINGLE SCALE renormalization (key optimization)
+                scale = (lyap_eps / delta).unsqueeze(-1)  # (B, 1)
+                for i in range(len(self.hidden_dims)):
+                    diff = new_mems_p[i] - new_mems[i]
+                    new_mems_p[i] = new_mems[i] + diff * scale
+
+                mems_p = new_mems_p
+
+        logits = self.readout(spike_sum)
+        lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+        return logits, lyap_est
+
+
+# =============================================================================
+# Approach A+B Combined: Batched trajectories + global renorm
+# =============================================================================
+
+class ApproachAB_SNN(nn.Module):
+    """Combined: trajectory-as-batch + global-norm renorm."""
+
+    def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
+        super().__init__()
+        self.T = T
+        self.hidden_dims = hidden_dims
+        spike_grad = surrogate.fast_sigmoid(slope=25)
+
+        self.linears = nn.ModuleList()
+        self.lifs = nn.ModuleList()
+
+        dims = [in_channels * 32 * 32] + hidden_dims
+        for i in range(len(hidden_dims)):
+            self.linears.append(nn.Linear(dims[i], dims[i+1]))
+            self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
+                                       spike_grad=spike_grad, init_hidden=False))
+
+        self.readout = nn.Linear(hidden_dims[-1], 10)
+
+    def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+        B = x.size(0)
+        device, dtype = x.device, x.dtype
+        x = x.view(B, -1)
+
+        P = 2 if compute_lyapunov else 1
+
+        # State: (P, B, H)
+        mems = [torch.zeros(P, B, h, device=device, dtype=dtype) for h in self.hidden_dims]
+
+        if compute_lyapunov:
+            for i in range(len(self.hidden_dims)):
+                mems[i][1] = lyap_eps * torch.randn(B, self.hidden_dims[i], device=device, dtype=dtype)
+            lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+        spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+        for t in range(self.T):
+            # Layer 1: Linear computed ONCE
+            h1 = self.linears[0](x)
+            h = h1.unsqueeze(0).expand(P, -1, -1) if compute_lyapunov else h1.unsqueeze(0)
+
+            spk, mems[0] = self.lifs[0](h, mems[0])
+            h = spk
+
+            # Layers 2+
+            for i in range(1, len(self.hidden_dims)):
+                h_flat = h.reshape(P * B, -1)
+                h_lin = self.linears[i](h_flat).view(P, B, self.hidden_dims[i])
+                spk, mems[i] = self.lifs[i](h_lin, mems[i])
+                h = spk
+
+            spike_sum = spike_sum + h[0]
+
+            if compute_lyapunov:
+                # Global divergence
+                delta_sq = torch.zeros(B, device=device, dtype=dtype)
+                for i in range(len(self.hidden_dims)):
+                    diff = mems[i][1] - mems[i][0]
+                    delta_sq = delta_sq + diff.square().sum(dim=-1)
+
+                delta = (delta_sq + 1e-12).sqrt()
+                lyap_accum = lyap_accum + (delta / lyap_eps).log()
+
+                # Global scale renorm
+                scale = (lyap_eps / delta).unsqueeze(-1)
+                for i in range(len(self.hidden_dims)):
+                    diff = mems[i][1] - mems[i][0]
+                    mems[i][1] = mems[i][0] + diff * scale
+
+        logits = self.readout(spike_sum)
+        lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+        return logits, lyap_est
+
+
+# =============================================================================
+# Approach C: torch.compile wrapper
+# =============================================================================
+
+def make_compiled_model(model_class, *args, **kwargs):
+    """Create a model and compile its forward pass."""
+    model = model_class(*args, **kwargs)
+    # Compile the forward method
+    model.forward = torch.compile(model.forward, mode="reduce-overhead")
+    return model
+
+
+# =============================================================================
+# Benchmarking
+# =============================================================================
+
+@dataclass
+class BenchmarkResult:
+    name: str
+    forward_time_ms: float
+    backward_time_ms: float
+    total_time_ms: float
+    lyap_value: float
+    memory_mb: float
+
+    def __str__(self):
+        return (f"{self.name:<25} | Fwd: {self.forward_time_ms:7.2f}ms | "
+                f"Bwd: {self.backward_time_ms:7.2f}ms | "
+                f"Total: {self.total_time_ms:7.2f}ms | "
+                f"λ: {self.lyap_value:+.4f} | Mem: {self.memory_mb:.1f}MB")
+
+
+def benchmark_model(
+    model: nn.Module,
+    x: torch.Tensor,
+    y: torch.Tensor,
+    name: str,
+    warmup_iters: int = 5,
+    bench_iters: int = 20,
+) -> BenchmarkResult:
+    """Benchmark a single model configuration."""
+
+    device = x.device
+    criterion = nn.CrossEntropyLoss()
+
+    # Warmup
+    for _ in range(warmup_iters):
+        logits, lyap = model(x, compute_lyapunov=True)
+        loss = criterion(logits, y) + 0.3 * (lyap ** 2 if lyap is not None else 0)
+        loss.backward()
+        model.zero_grad()
+
+    torch.cuda.synchronize()
+    torch.cuda.reset_peak_memory_stats()
+
+    fwd_times = []
+    bwd_times = []
+    lyap_vals = []
+
+    for _ in range(bench_iters):
+        # Forward
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+
+        logits, lyap = model(x, compute_lyapunov=True)
+        loss = criterion(logits, y) + 0.3 * (lyap ** 2 if lyap is not None else 0)
+
+        torch.cuda.synchronize()
+        t1 = time.perf_counter()
+
+        # Backward
+        loss.backward()
+
+        torch.cuda.synchronize()
+        t2 = time.perf_counter()
+
+        fwd_times.append((t1 - t0) * 1000)
+        bwd_times.append((t2 - t1) * 1000)
+        if lyap is not None:
+            lyap_vals.append(lyap.item())
+
+        model.zero_grad()
+
+    peak_mem = torch.cuda.max_memory_allocated() / 1024 / 1024
+
+    return BenchmarkResult(
+        name=name,
+        forward_time_ms=sum(fwd_times) / len(fwd_times),
+        backward_time_ms=sum(bwd_times) / len(bwd_times),
+        total_time_ms=sum(fwd_times) / len(fwd_times) + sum(bwd_times) / len(bwd_times),
+        lyap_value=sum(lyap_vals) / len(lyap_vals) if lyap_vals else 0.0,
+        memory_mb=peak_mem,
+    )
+
+
+def run_benchmarks(
+    batch_size: int = 64,
+    T: int = 4,
+    hidden_dims: List[int] = [64, 128, 256],
+    device: str = "cuda",
+):
+    """Run all benchmarks and compare."""
+
+    print("=" * 80)
+    print("LYAPUNOV COMPUTATION SPEEDUP BENCHMARK")
+    print("=" * 80)
+    print(f"Batch size: {batch_size}")
+    print(f"Timesteps: {T}")
+    print(f"Hidden dims: {hidden_dims}")
+    print(f"Device: {device}")
+    print("=" * 80)
+
+    # Create dummy data
+    x = torch.randn(batch_size, 3, 32, 32, device=device)
+    y = torch.randint(0, 10, (batch_size,), device=device)
+
+    results = []
+
+    # 1. Baseline
+    print("\n[1/6] Benchmarking Baseline...")
+    model = BaselineSNN(hidden_dims=hidden_dims, T=T).to(device)
+    results.append(benchmark_model(model, x, y, "Baseline"))
+    del model
+    torch.cuda.empty_cache()
+
+    # 2. Approach A (batched trajectories)
+    print("[2/6] Benchmarking Approach A (batched)...")
+    model = ApproachA_SNN(hidden_dims=hidden_dims, T=T).to(device)
+    results.append(benchmark_model(model, x, y, "A: Batched trajectories"))
+    del model
+    torch.cuda.empty_cache()
+
+    # 3. Approach B (global renorm)
+    print("[3/6] Benchmarking Approach B (global renorm)...")
+    model = ApproachB_SNN(hidden_dims=hidden_dims, T=T).to(device)
+    results.append(benchmark_model(model, x, y, "B: Global renorm"))
+    del model
+    torch.cuda.empty_cache()
+
+    # 4. Approach A+B combined
+    print("[4/6] Benchmarking Approach A+B (combined)...")
+    model = ApproachAB_SNN(hidden_dims=hidden_dims, T=T).to(device)
+    results.append(benchmark_model(model, x, y, "A+B: Combined"))
+    del model
+    torch.cuda.empty_cache()
+
+    # 5. Approach C (torch.compile on baseline)
+    print("[5/6] Benchmarking Approach C (compiled baseline)...")
+    try:
+        model = BaselineSNN(hidden_dims=hidden_dims, T=T).to(device)
+        model.forward = torch.compile(model.forward, mode="reduce-overhead")
+        results.append(benchmark_model(model, x, y, "C: Compiled baseline", warmup_iters=10))
+        del model
+        torch.cuda.empty_cache()
+    except Exception as e:
+        print(f"  torch.compile failed: {e}")
+        results.append(BenchmarkResult("C: Compiled baseline", 0, 0, 0, 0, 0))
+
+    # 6. A+B+C (all combined)
+    print("[6/6] Benchmarking A+B+C (all optimizations)...")
+    try:
+        model = ApproachAB_SNN(hidden_dims=hidden_dims, T=T).to(device)
+        model.forward = torch.compile(model.forward, mode="reduce-overhead")
+        results.append(benchmark_model(model, x, y, "A+B+C: All optimized", warmup_iters=10))
+        del model
+        torch.cuda.empty_cache()
+    except Exception as e:
+        print(f"  torch.compile failed: {e}")
+        results.append(BenchmarkResult("A+B+C: All optimized", 0, 0, 0, 0, 0))
+
+    # Print results
+    print("\n" + "=" * 80)
+    print("RESULTS")
+    print("=" * 80)
+
+    baseline_time = results[0].total_time_ms
+
+    for r in results:
+        print(r)
+
+    print("\n" + "-" * 80)
+    print("SPEEDUP vs BASELINE:")
+    print("-" * 80)
+
+    for r in results[1:]:
+        if r.total_time_ms > 0:
+            speedup = baseline_time / r.total_time_ms
+            print(f"  {r.name:<25}: {speedup:.2f}x")
+
+    # Verify Lyapunov values are consistent
+    print("\n" + "-" * 80)
+    print("LYAPUNOV VALUE CONSISTENCY CHECK:")
+    print("-" * 80)
+
+    base_lyap = results[0].lyap_value
+    for r in results[1:]:
+        if r.lyap_value != 0:
+            diff = abs(r.lyap_value - base_lyap)
+            status = "✓" if diff < 0.1 else "✗"
+            print(f"  {r.name:<25}: λ={r.lyap_value:+.4f} (diff={diff:.4f}) {status}")
+
+    return results
+
+
+def run_scaling_test(device: str = "cuda"):
+    """Test how approaches scale with batch size and timesteps."""
+
+    print("\n" + "=" * 80)
+    print("SCALING TESTS")
+    print("=" * 80)
+
+    configs = [
+        {"batch_size": 32, "T": 4, "hidden_dims": [64, 128, 256]},
+        {"batch_size": 64, "T": 4, "hidden_dims": [64, 128, 256]},
+        {"batch_size": 128, "T": 4, "hidden_dims": [64, 128, 256]},
+        {"batch_size": 64, "T": 8, "hidden_dims": [64, 128, 256]},
+        {"batch_size": 64, "T": 16, "hidden_dims": [64, 128, 256]},
+        {"batch_size": 64, "T": 4, "hidden_dims": [128, 256, 512]},  # Larger model
+    ]
+
+    print(f"{'Config':<40} | {'Baseline':<12} | {'A+B':<12} | {'Speedup':<8}")
+    print("-" * 80)
+
+    for cfg in configs:
+        x = torch.randn(cfg["batch_size"], 3, 32, 32, device=device)
+        y = torch.randint(0, 10, (cfg["batch_size"],), device=device)
+
+        # Baseline (pass only the kwargs the model constructor accepts)
+        model_base = BaselineSNN(hidden_dims=cfg["hidden_dims"], T=cfg["T"]).to(device)
+        r_base = benchmark_model(model_base, x, y, "base", warmup_iters=3, bench_iters=10)
+        del model_base
+
+        # A+B
+        model_ab = ApproachAB_SNN(hidden_dims=cfg["hidden_dims"], T=cfg["T"]).to(device)
+        r_ab = benchmark_model(model_ab, x, y, "a+b", warmup_iters=3, bench_iters=10)
+        del model_ab
+
+        torch.cuda.empty_cache()
+
+        speedup = r_base.total_time_ms / r_ab.total_time_ms if r_ab.total_time_ms > 0 else 0
+
+        cfg_str = f"B={cfg['batch_size']}, T={cfg['T']}, H={cfg['hidden_dims']}"
+        print(f"{cfg_str:<40} | {r_base.total_time_ms:>10.2f}ms | {r_ab.total_time_ms:>10.2f}ms | {speedup:>6.2f}x")
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--batch_size", type=int, default=64)
+    parser.add_argument("--T", type=int, default=4)
+    parser.add_argument("--hidden_dims", type=int, nargs="+", default=[64, 128, 256])
+    parser.add_argument("--device", type=str, default="cuda")
+    parser.add_argument("--scaling", action="store_true", help="Run scaling tests")
+    args = parser.parse_args()
+
+    if not torch.cuda.is_available():
+        print("CUDA not available, using CPU (results will not be representative)")
+        args.device = "cpu"
+
+    # Main benchmark
+    results = run_benchmarks(
+        batch_size=args.batch_size,
+        T=args.T,
+        hidden_dims=args.hidden_dims,
+        device=args.device,
+    )
+
+    # Scaling tests
+    if args.scaling:
+        run_scaling_test(args.device)
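For reference, every variant above estimates the same finite-time Lyapunov exponent, λ ≈ (1/T) Σ_t log(‖δ_t‖ / ε), by renormalizing the perturbed state back to distance ε from the reference state after each timestep. A minimal standalone Python sketch of that recipe (illustrative only, not part of the file above; the function name is hypothetical), applied to the scalar map x ↦ a·x, whose true exponent is log|a|:

import math

def lyapunov_scalar_map(a: float, T: int = 100, eps: float = 1e-4) -> float:
    # Reference and perturbed trajectories, initially separated by eps
    x, x_p = 0.5, 0.5 + eps
    acc = 0.0
    for _ in range(T):
        x, x_p = a * x, a * x_p            # advance both trajectories one step
        delta = abs(x_p - x)               # divergence after the step
        acc += math.log(delta / eps)       # log growth factor for this step
        x_p = x + eps * (x_p - x) / delta  # renormalize back to distance eps
    return acc / T                         # finite-time Lyapunov estimate

print(lyapunov_scalar_map(1.3))  # ≈ math.log(1.3) ≈ +0.2624 (expanding)
print(lyapunov_scalar_map(0.8))  # ≈ math.log(0.8) ≈ -0.2231 (contracting)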
