""" Lyapunov Computation Speedup Benchmark Tests different optimization approaches for computing Lyapunov exponents during SNN training. All approaches should produce equivalent results (within numerical precision) but with different performance characteristics. Approaches tested: - Baseline: Current sequential implementation - Approach A: Trajectory-as-batch (P=2), share first Linear - Approach B: Global-norm divergence + single-scale renorm - Approach C: torch.compile the time loop - Combined: A + B + C together """ import os import sys import time from typing import Tuple, Optional, List from dataclasses import dataclass import torch import torch.nn as nn import snntorch as snn from snntorch import surrogate # Ensure we can import from project _HERE = os.path.dirname(__file__) _ROOT = os.path.dirname(os.path.dirname(_HERE)) if _ROOT not in sys.path: sys.path.insert(0, _ROOT) # ============================================================================= # Baseline Implementation (Current) # ============================================================================= class BaselineSNN(nn.Module): """Current implementation: sequential perturbed trajectory.""" def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9): super().__init__() self.T = T self.hidden_dims = hidden_dims spike_grad = surrogate.fast_sigmoid(slope=25) # Simple feedforward for benchmarking (not full VGG) self.linears = nn.ModuleList() self.lifs = nn.ModuleList() dims = [in_channels * 32 * 32] + hidden_dims # Flattened input for i in range(len(hidden_dims)): self.linears.append(nn.Linear(dims[i], dims[i+1])) self.lifs.append(snn.Leaky(beta=beta, threshold=1.0, spike_grad=spike_grad, init_hidden=False)) self.readout = nn.Linear(hidden_dims[-1], 10) def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4): B = x.size(0) device, dtype = x.device, x.dtype x = x.view(B, -1) # Flatten # Init membrane potentials mems = [torch.zeros(B, h, device=device, dtype=dtype) for h in self.hidden_dims] if compute_lyapunov: mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems] lyap_accum = torch.zeros(B, device=device, dtype=dtype) spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype) for t in range(self.T): # Original trajectory h = x new_mems = [] for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)): h = lin(h) spk, mem = lif(h, mems[i]) new_mems.append(mem) h = spk mems = new_mems spike_sum = spike_sum + h if compute_lyapunov: # Perturbed trajectory (SEPARATE PASS - this is slow) h_p = x new_mems_p = [] for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)): h_p = lin(h_p) spk_p, mem_p = lif(h_p, mems_p[i]) new_mems_p.append(mem_p) h_p = spk_p # Divergence (per-layer norms, then sum) delta_sq = torch.zeros(B, device=device, dtype=dtype) for i in range(len(self.hidden_dims)): diff = new_mems_p[i] - new_mems[i] delta_sq += (diff ** 2).sum(dim=1) delta = torch.sqrt(delta_sq + 1e-12) lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12) # Renormalize (per-layer - SLOW) for i in range(len(self.hidden_dims)): diff = new_mems_p[i] - new_mems[i] norm = torch.norm(diff, dim=1, keepdim=True) + 1e-12 new_mems_p[i] = new_mems[i] + lyap_eps * diff / norm mems_p = new_mems_p logits = self.readout(spike_sum) lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None return logits, lyap_est # ============================================================================= # Approach A: Trajectory-as-batch (P=2), share first Linear # 
# =============================================================================
# Baseline Implementation (Current)
# =============================================================================

class BaselineSNN(nn.Module):
    """Current implementation: sequential perturbed trajectory."""

    def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
        super().__init__()
        self.T = T
        self.hidden_dims = hidden_dims
        spike_grad = surrogate.fast_sigmoid(slope=25)

        # Simple feedforward for benchmarking (not full VGG)
        self.linears = nn.ModuleList()
        self.lifs = nn.ModuleList()
        dims = [in_channels * 32 * 32] + hidden_dims  # Flattened input
        for i in range(len(hidden_dims)):
            self.linears.append(nn.Linear(dims[i], dims[i + 1]))
            self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
                                       spike_grad=spike_grad, init_hidden=False))
        self.readout = nn.Linear(hidden_dims[-1], 10)

    def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
        B = x.size(0)
        device, dtype = x.device, x.dtype
        x = x.view(B, -1)  # Flatten

        # Init membrane potentials
        mems = [torch.zeros(B, h, device=device, dtype=dtype) for h in self.hidden_dims]
        if compute_lyapunov:
            mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)

        spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)

        for t in range(self.T):
            # Original trajectory
            h = x
            new_mems = []
            for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
                h = lin(h)
                spk, mem = lif(h, mems[i])
                new_mems.append(mem)
                h = spk
            mems = new_mems
            spike_sum = spike_sum + h

            if compute_lyapunov:
                # Perturbed trajectory (SEPARATE PASS - this is slow)
                h_p = x
                new_mems_p = []
                for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
                    h_p = lin(h_p)
                    spk_p, mem_p = lif(h_p, mems_p[i])
                    new_mems_p.append(mem_p)
                    h_p = spk_p

                # Divergence: global norm built from per-layer squared sums
                delta_sq = torch.zeros(B, device=device, dtype=dtype)
                for i in range(len(self.hidden_dims)):
                    diff = new_mems_p[i] - new_mems[i]
                    delta_sq += (diff ** 2).sum(dim=1)
                delta = torch.sqrt(delta_sq + 1e-12)
                lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)

                # Renormalize (per-layer - SLOW)
                for i in range(len(self.hidden_dims)):
                    diff = new_mems_p[i] - new_mems[i]
                    norm = torch.norm(diff, dim=1, keepdim=True) + 1e-12
                    new_mems_p[i] = new_mems[i] + lyap_eps * diff / norm
                mems_p = new_mems_p

        logits = self.readout(spike_sum)
        lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
        return logits, lyap_est


# =============================================================================
# Approach A: Trajectory-as-batch (P=2), share first Linear
# =============================================================================

class ApproachA_SNN(nn.Module):
    """Batch both trajectories together, share Linear_1."""

    def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
        super().__init__()
        self.T = T
        self.hidden_dims = hidden_dims
        spike_grad = surrogate.fast_sigmoid(slope=25)

        self.linears = nn.ModuleList()
        self.lifs = nn.ModuleList()
        dims = [in_channels * 32 * 32] + hidden_dims
        for i in range(len(hidden_dims)):
            self.linears.append(nn.Linear(dims[i], dims[i + 1]))
            self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
                                       spike_grad=spike_grad, init_hidden=False))
        self.readout = nn.Linear(hidden_dims[-1], 10)

    def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
        B = x.size(0)
        device, dtype = x.device, x.dtype
        x = x.view(B, -1)
        P = 2 if compute_lyapunov else 1

        # State layout: (P, B, H) where P=2 for [original, perturbed]
        mems = [torch.zeros(P, B, h, device=device, dtype=dtype) for h in self.hidden_dims]
        if compute_lyapunov:
            # Initialize perturbed state
            for i in range(len(self.hidden_dims)):
                mems[i][1] = mems[i][0] + lyap_eps * torch.randn(
                    B, self.hidden_dims[i], device=device, dtype=dtype)
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)

        spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)

        for t in range(self.T):
            # Layer 1: compute Linear ONCE, expand to (P, B, H1)
            h1 = self.linears[0](x)  # (B, H1) - computed ONCE
            if compute_lyapunov:
                h = h1.unsqueeze(0).expand(P, -1, -1)  # (P, B, H1) - zero-copy view
            else:
                h = h1.unsqueeze(0)  # (1, B, H1)

            # LIF layer 1
            spk, mems[0] = self.lifs[0](h, mems[0])
            h = spk

            # Layers 2+: different inputs for each trajectory
            for i in range(1, len(self.hidden_dims)):
                # Reshape to (P*B, H) for batched Linear
                h_flat = h.reshape(P * B, -1)
                h_lin = self.linears[i](h_flat).view(P, B, self.hidden_dims[i])
                spk, mems[i] = self.lifs[i](h_lin, mems[i])
                h = spk

            # Accumulate spikes from original trajectory only
            spike_sum = spike_sum + h[0]

            if compute_lyapunov:
                # Global divergence across all layers
                delta_sq = torch.zeros(B, device=device, dtype=dtype)
                for i in range(len(self.hidden_dims)):
                    diff = mems[i][1] - mems[i][0]  # (B, H_i)
                    delta_sq = delta_sq + diff.square().sum(dim=-1)
                delta = (delta_sq + 1e-12).sqrt()
                lyap_accum = lyap_accum + (delta / lyap_eps).log()

                # Renormalize per layer (same scheme as the baseline;
                # Approach B replaces this with a single global scale)
                for i in range(len(self.hidden_dims)):
                    diff = mems[i][1] - mems[i][0]
                    norm = torch.norm(diff, dim=1, keepdim=True) + 1e-12
                    mems[i][1] = mems[i][0] + lyap_eps * diff / norm

        logits = self.readout(spike_sum)
        lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
        return logits, lyap_est
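
# Why Approach A helps, in miniature (illustrative sketch, not called by the
# benchmark): stacking the two trajectories along a leading P dimension turns
# two small GEMMs per layer into one larger GEMM, and `expand` reuses the
# shared first-layer pre-activations without copying them.
def _expand_is_zero_copy_demo():
    h1 = torch.randn(8, 16)
    h = h1.unsqueeze(0).expand(2, -1, -1)  # (P=2, 8, 16) view, no allocation
    assert h.data_ptr() == h1.data_ptr()   # same underlying storage
    assert torch.equal(h[0], h[1])         # both trajectories see the same h1
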
# =============================================================================
# Approach B: Global-norm divergence + single-scale renorm
# =============================================================================

class ApproachB_SNN(nn.Module):
    """Global norm for divergence, single scale factor for renorm."""

    def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
        super().__init__()
        self.T = T
        self.hidden_dims = hidden_dims
        spike_grad = surrogate.fast_sigmoid(slope=25)

        self.linears = nn.ModuleList()
        self.lifs = nn.ModuleList()
        dims = [in_channels * 32 * 32] + hidden_dims
        for i in range(len(hidden_dims)):
            self.linears.append(nn.Linear(dims[i], dims[i + 1]))
            self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
                                       spike_grad=spike_grad, init_hidden=False))
        self.readout = nn.Linear(hidden_dims[-1], 10)

    def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
        B = x.size(0)
        device, dtype = x.device, x.dtype
        x = x.view(B, -1)

        mems = [torch.zeros(B, h, device=device, dtype=dtype) for h in self.hidden_dims]
        if compute_lyapunov:
            mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)

        spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)

        for t in range(self.T):
            # Original trajectory
            h = x
            new_mems = []
            for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
                h = lin(h)
                spk, mem = lif(h, mems[i])
                new_mems.append(mem)
                h = spk
            mems = new_mems
            spike_sum = spike_sum + h

            if compute_lyapunov:
                # Perturbed trajectory
                h_p = x
                new_mems_p = []
                for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
                    h_p = lin(h_p)
                    spk_p, mem_p = lif(h_p, mems_p[i])
                    new_mems_p.append(mem_p)
                    h_p = spk_p

                # GLOBAL divergence (one delta per batch element)
                delta_sq = torch.zeros(B, device=device, dtype=dtype)
                for i in range(len(self.hidden_dims)):
                    diff = new_mems_p[i] - new_mems[i]
                    delta_sq = delta_sq + diff.square().sum(dim=-1)
                delta = (delta_sq + 1e-12).sqrt()
                lyap_accum = lyap_accum + (delta / lyap_eps).log()

                # SINGLE SCALE renormalization (key optimization)
                scale = (lyap_eps / delta).unsqueeze(-1)  # (B, 1)
                for i in range(len(self.hidden_dims)):
                    diff = new_mems_p[i] - new_mems[i]
                    new_mems_p[i] = new_mems[i] + diff * scale
                mems_p = new_mems_p

        logits = self.readout(spike_sum)
        lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
        return logits, lyap_est
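
# Caveat worth checking (a hedged observation, not asserted by the benchmark):
# per-layer renorm (Baseline, A) resets each layer's divergence to eps, so the
# *global* norm restarts at eps * sqrt(L) for L layers, while the single-scale
# renorm (B) restarts the global norm at exactly eps. The two schemes can
# therefore shift the accumulated log-divergence by up to ~log(sqrt(L)) per
# step; run_benchmarks checks consistency with a 0.1 tolerance. A tiny demo:
def _renorm_scheme_demo(L: int = 3, eps: float = 1e-4):
    diffs = [torch.randn(4, 8) for _ in range(L)]  # fake per-layer divergences
    per_layer = [eps * d / d.norm(dim=1, keepdim=True) for d in diffs]
    global_norm = torch.sqrt(sum((d ** 2).sum(dim=1) for d in per_layer))
    print(global_norm / eps)  # sqrt(L) under per-layer renorm, 1.0 under global
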
# =============================================================================
# Approach A+B Combined: Batched trajectories + global renorm
# =============================================================================

class ApproachAB_SNN(nn.Module):
    """Combined: trajectory-as-batch + global-norm renorm."""

    def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
        super().__init__()
        self.T = T
        self.hidden_dims = hidden_dims
        spike_grad = surrogate.fast_sigmoid(slope=25)

        self.linears = nn.ModuleList()
        self.lifs = nn.ModuleList()
        dims = [in_channels * 32 * 32] + hidden_dims
        for i in range(len(hidden_dims)):
            self.linears.append(nn.Linear(dims[i], dims[i + 1]))
            self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
                                       spike_grad=spike_grad, init_hidden=False))
        self.readout = nn.Linear(hidden_dims[-1], 10)

    def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
        B = x.size(0)
        device, dtype = x.device, x.dtype
        x = x.view(B, -1)
        P = 2 if compute_lyapunov else 1

        # State: (P, B, H)
        mems = [torch.zeros(P, B, h, device=device, dtype=dtype) for h in self.hidden_dims]
        if compute_lyapunov:
            for i in range(len(self.hidden_dims)):
                # mems[i][0] is zero here, so this matches Approach A's init
                mems[i][1] = lyap_eps * torch.randn(
                    B, self.hidden_dims[i], device=device, dtype=dtype)
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)

        spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)

        for t in range(self.T):
            # Layer 1: Linear computed ONCE
            h1 = self.linears[0](x)
            h = h1.unsqueeze(0).expand(P, -1, -1) if compute_lyapunov else h1.unsqueeze(0)
            spk, mems[0] = self.lifs[0](h, mems[0])
            h = spk

            # Layers 2+
            for i in range(1, len(self.hidden_dims)):
                h_flat = h.reshape(P * B, -1)
                h_lin = self.linears[i](h_flat).view(P, B, self.hidden_dims[i])
                spk, mems[i] = self.lifs[i](h_lin, mems[i])
                h = spk

            spike_sum = spike_sum + h[0]

            if compute_lyapunov:
                # Global divergence
                delta_sq = torch.zeros(B, device=device, dtype=dtype)
                for i in range(len(self.hidden_dims)):
                    diff = mems[i][1] - mems[i][0]
                    delta_sq = delta_sq + diff.square().sum(dim=-1)
                delta = (delta_sq + 1e-12).sqrt()
                lyap_accum = lyap_accum + (delta / lyap_eps).log()

                # Global scale renorm
                scale = (lyap_eps / delta).unsqueeze(-1)
                for i in range(len(self.hidden_dims)):
                    diff = mems[i][1] - mems[i][0]
                    mems[i][1] = mems[i][0] + diff * scale

        logits = self.readout(spike_sum)
        lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
        return logits, lyap_est


# =============================================================================
# Approach C: torch.compile wrapper
# =============================================================================

def make_compiled_model(model_class, *args, **kwargs):
    """Create a model and compile its forward pass."""
    model = model_class(*args, **kwargs)
    # Compile the forward method
    model.forward = torch.compile(model.forward, mode="reduce-overhead")
    return model
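
# Hedged note on mode="reduce-overhead": it enables CUDA-graph capture in the
# inductor backend, which helps most when many small per-timestep kernels make
# the Python time loop launch-bound. Because Dynamo guards on Python arguments,
# the compute_lyapunov flag compiles one variant per value; a wrapper that
# fixes the flag (illustrative helper, not used by the benchmark) keeps the
# variants separate:
def make_compiled_lyapunov_fn(model: nn.Module):
    def _fwd(x):
        return model(x, compute_lyapunov=True)
    return torch.compile(_fwd, mode="reduce-overhead")
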
# =============================================================================
# Benchmarking
# =============================================================================

@dataclass
class BenchmarkResult:
    name: str
    forward_time_ms: float
    backward_time_ms: float
    total_time_ms: float
    lyap_value: float
    memory_mb: float

    def __str__(self):
        return (f"{self.name:<25} | Fwd: {self.forward_time_ms:7.2f}ms | "
                f"Bwd: {self.backward_time_ms:7.2f}ms | "
                f"Total: {self.total_time_ms:7.2f}ms | "
                f"λ: {self.lyap_value:+.4f} | Mem: {self.memory_mb:.1f}MB")


def benchmark_model(
    model: nn.Module,
    x: torch.Tensor,
    y: torch.Tensor,
    name: str,
    warmup_iters: int = 5,
    bench_iters: int = 20,
) -> BenchmarkResult:
    """Benchmark a single model configuration."""
    use_cuda = x.is_cuda  # guard CUDA-only calls so the CPU fallback works

    def _sync():
        if use_cuda:
            torch.cuda.synchronize()

    criterion = nn.CrossEntropyLoss()

    # Warmup
    for _ in range(warmup_iters):
        logits, lyap = model(x, compute_lyapunov=True)
        loss = criterion(logits, y) + 0.3 * (lyap ** 2 if lyap is not None else 0)
        loss.backward()
        model.zero_grad()

    _sync()
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()

    fwd_times = []
    bwd_times = []
    lyap_vals = []

    for _ in range(bench_iters):
        # Forward
        _sync()
        t0 = time.perf_counter()
        logits, lyap = model(x, compute_lyapunov=True)
        loss = criterion(logits, y) + 0.3 * (lyap ** 2 if lyap is not None else 0)
        _sync()
        t1 = time.perf_counter()

        # Backward
        loss.backward()
        _sync()
        t2 = time.perf_counter()

        fwd_times.append((t1 - t0) * 1000)
        bwd_times.append((t2 - t1) * 1000)
        if lyap is not None:
            lyap_vals.append(lyap.item())
        model.zero_grad()

    peak_mem = torch.cuda.max_memory_allocated() / 1024 / 1024 if use_cuda else 0.0

    return BenchmarkResult(
        name=name,
        forward_time_ms=sum(fwd_times) / len(fwd_times),
        backward_time_ms=sum(bwd_times) / len(bwd_times),
        total_time_ms=sum(fwd_times) / len(fwd_times) + sum(bwd_times) / len(bwd_times),
        lyap_value=sum(lyap_vals) / len(lyap_vals) if lyap_vals else 0.0,
        memory_mb=peak_mem,
    )
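
# Example usage (illustrative; timings depend on hardware, PyTorch version,
# and whether a CUDA device is available):
#   x = torch.randn(64, 3, 32, 32, device="cuda")
#   y = torch.randint(0, 10, (64,), device="cuda")
#   model = BaselineSNN().to("cuda")
#   print(benchmark_model(model, x, y, "Baseline"))
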
def run_benchmarks(
    batch_size: int = 64,
    T: int = 4,
    hidden_dims: List[int] = [64, 128, 256],
    device: str = "cuda",
):
    """Run all benchmarks and compare."""
    print("=" * 80)
    print("LYAPUNOV COMPUTATION SPEEDUP BENCHMARK")
    print("=" * 80)
    print(f"Batch size: {batch_size}")
    print(f"Timesteps: {T}")
    print(f"Hidden dims: {hidden_dims}")
    print(f"Device: {device}")
    print("=" * 80)

    # Create dummy data
    x = torch.randn(batch_size, 3, 32, 32, device=device)
    y = torch.randint(0, 10, (batch_size,), device=device)

    results = []

    # 1. Baseline
    print("\n[1/6] Benchmarking Baseline...")
    model = BaselineSNN(hidden_dims=hidden_dims, T=T).to(device)
    results.append(benchmark_model(model, x, y, "Baseline"))
    del model
    torch.cuda.empty_cache()

    # 2. Approach A (batched trajectories)
    print("[2/6] Benchmarking Approach A (batched)...")
    model = ApproachA_SNN(hidden_dims=hidden_dims, T=T).to(device)
    results.append(benchmark_model(model, x, y, "A: Batched trajectories"))
    del model
    torch.cuda.empty_cache()

    # 3. Approach B (global renorm)
    print("[3/6] Benchmarking Approach B (global renorm)...")
    model = ApproachB_SNN(hidden_dims=hidden_dims, T=T).to(device)
    results.append(benchmark_model(model, x, y, "B: Global renorm"))
    del model
    torch.cuda.empty_cache()

    # 4. Approach A+B combined
    print("[4/6] Benchmarking Approach A+B (combined)...")
    model = ApproachAB_SNN(hidden_dims=hidden_dims, T=T).to(device)
    results.append(benchmark_model(model, x, y, "A+B: Combined"))
    del model
    torch.cuda.empty_cache()

    # 5. Approach C (torch.compile on baseline)
    print("[5/6] Benchmarking Approach C (compiled baseline)...")
    try:
        model = BaselineSNN(hidden_dims=hidden_dims, T=T).to(device)
        model.forward = torch.compile(model.forward, mode="reduce-overhead")
        results.append(benchmark_model(model, x, y, "C: Compiled baseline", warmup_iters=10))
        del model
        torch.cuda.empty_cache()
    except Exception as e:
        print(f"  torch.compile failed: {e}")
        results.append(BenchmarkResult("C: Compiled baseline", 0, 0, 0, 0, 0))

    # 6. A+B+C (all combined)
    print("[6/6] Benchmarking A+B+C (all optimizations)...")
    try:
        model = ApproachAB_SNN(hidden_dims=hidden_dims, T=T).to(device)
        model.forward = torch.compile(model.forward, mode="reduce-overhead")
        results.append(benchmark_model(model, x, y, "A+B+C: All optimized", warmup_iters=10))
        del model
        torch.cuda.empty_cache()
    except Exception as e:
        print(f"  torch.compile failed: {e}")
        results.append(BenchmarkResult("A+B+C: All optimized", 0, 0, 0, 0, 0))

    # Print results
    print("\n" + "=" * 80)
    print("RESULTS")
    print("=" * 80)
    baseline_time = results[0].total_time_ms
    for r in results:
        print(r)

    print("\n" + "-" * 80)
    print("SPEEDUP vs BASELINE:")
    print("-" * 80)
    for r in results[1:]:
        if r.total_time_ms > 0:
            speedup = baseline_time / r.total_time_ms
            print(f"  {r.name:<25}: {speedup:.2f}x")

    # Verify Lyapunov values are consistent
    print("\n" + "-" * 80)
    print("LYAPUNOV VALUE CONSISTENCY CHECK:")
    print("-" * 80)
    base_lyap = results[0].lyap_value
    for r in results[1:]:
        if r.lyap_value != 0:
            diff = abs(r.lyap_value - base_lyap)
            status = "✓" if diff < 0.1 else "✗"
            print(f"  {r.name:<25}: λ={r.lyap_value:+.4f} (diff={diff:.4f}) {status}")

    return results


def run_scaling_test(device: str = "cuda"):
    """Test how approaches scale with batch size and timesteps."""
    print("\n" + "=" * 80)
    print("SCALING TESTS")
    print("=" * 80)

    configs = [
        {"batch_size": 32, "T": 4, "hidden_dims": [64, 128, 256]},
        {"batch_size": 64, "T": 4, "hidden_dims": [64, 128, 256]},
        {"batch_size": 128, "T": 4, "hidden_dims": [64, 128, 256]},
        {"batch_size": 64, "T": 8, "hidden_dims": [64, 128, 256]},
        {"batch_size": 64, "T": 16, "hidden_dims": [64, 128, 256]},
        {"batch_size": 64, "T": 4, "hidden_dims": [128, 256, 512]},  # Larger model
    ]

    print(f"{'Config':<40} | {'Baseline':<12} | {'A+B':<12} | {'Speedup':<8}")
    print("-" * 80)

    for cfg in configs:
        x = torch.randn(cfg["batch_size"], 3, 32, 32, device=device)
        y = torch.randint(0, 10, (cfg["batch_size"],), device=device)
        # batch_size is not a model kwarg; pass only the model parameters
        model_kwargs = {"hidden_dims": cfg["hidden_dims"], "T": cfg["T"]}

        # Baseline
        model_base = BaselineSNN(**model_kwargs).to(device)
        r_base = benchmark_model(model_base, x, y, "base", warmup_iters=3, bench_iters=10)
        del model_base

        # A+B
        model_ab = ApproachAB_SNN(**model_kwargs).to(device)
        r_ab = benchmark_model(model_ab, x, y, "a+b", warmup_iters=3, bench_iters=10)
        del model_ab
        torch.cuda.empty_cache()

        speedup = r_base.total_time_ms / r_ab.total_time_ms if r_ab.total_time_ms > 0 else 0
        cfg_str = f"B={cfg['batch_size']}, T={cfg['T']}, H={cfg['hidden_dims']}"
        print(f"{cfg_str:<40} | {r_base.total_time_ms:>10.2f}ms | {r_ab.total_time_ms:>10.2f}ms | {speedup:>6.2f}x")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--T", type=int, default=4)
    parser.add_argument("--hidden_dims", type=int, nargs="+", default=[64, 128, 256])
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--scaling", action="store_true", help="Run scaling tests")
    args = parser.parse_args()

    if not torch.cuda.is_available():
        print("CUDA not available, using CPU (results will not be representative)")
        args.device = "cpu"

    # Main benchmark
    results = run_benchmarks(
        batch_size=args.batch_size,
        T=args.T,
        hidden_dims=args.hidden_dims,
        device=args.device,
    )

    # Scaling tests
    if args.scaling:
        run_scaling_test(args.device)
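
# Example invocations (the script name below is illustrative):
#   python lyapunov_speedup_benchmark.py
#   python lyapunov_speedup_benchmark.py --batch_size 128 --T 8
#   python lyapunov_speedup_benchmark.py --scaling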