""" Benchmark: Diff-only storage vs 2-trajectory storage for Lyapunov computation. Optimization B: Instead of storing two full membrane trajectories: mems[i][0] = base trajectory mems[i][1] = perturbed trajectory Store only: base_mems[i] = base trajectory delta_mems[i] = perturbation (perturbed - base) Benefits: - ~2x less memory for membrane states - Fewer memory reads/writes during renormalization - Better cache utilization """ import os import sys import time import torch import torch.nn as nn from typing import Tuple, Optional, List _HERE = os.path.dirname(__file__) _ROOT = os.path.dirname(os.path.dirname(_HERE)) if _ROOT not in sys.path: sys.path.insert(0, _ROOT) import snntorch as snn from snntorch import surrogate class SpikingVGGBlock(nn.Module): """Conv-BN-LIF block.""" def __init__(self, in_ch, out_ch, beta=0.9, threshold=1.0, spike_grad=None): super().__init__() if spike_grad is None: spike_grad = surrogate.fast_sigmoid(slope=25) self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False) self.bn = nn.BatchNorm2d(out_ch) self.lif = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False) def forward(self, x, mem): h = self.bn(self.conv(x)) spk, mem = self.lif(h, mem) return spk, mem class SpikingVGG_Original(nn.Module): """Original implementation: stores 2 full trajectories with shape (P=2, B, C, H, W).""" def __init__(self, in_channels=3, num_classes=100, base_channels=64, num_stages=3, blocks_per_stage=2, T=4, beta=0.9): super().__init__() self.T = T self.num_stages = num_stages self.blocks_per_stage = blocks_per_stage # Build stages self.stages = nn.ModuleList() self.pools = nn.ModuleList() in_ch = in_channels out_ch = base_channels current_size = 32 # CIFAR for stage in range(num_stages): stage_blocks = nn.ModuleList() for _ in range(blocks_per_stage): stage_blocks.append(SpikingVGGBlock(in_ch, out_ch, beta=beta)) in_ch = out_ch self.stages.append(stage_blocks) self.pools.append(nn.AvgPool2d(2)) current_size //= 2 if stage < num_stages - 1: out_ch = min(out_ch * 2, 512) self.fc = nn.Linear(in_ch * current_size * current_size, num_classes) self._channel_sizes = self._compute_channel_sizes(base_channels) def _compute_channel_sizes(self, base): sizes = [] ch = base for stage in range(self.num_stages): for _ in range(self.blocks_per_stage): sizes.append(ch) if stage < self.num_stages - 1: ch = min(ch * 2, 512) return sizes def _init_mems(self, batch_size, device, dtype, P=1): mems = [] H, W = 32, 32 for stage in range(self.num_stages): for block_idx in range(self.blocks_per_stage): layer_idx = stage * self.blocks_per_stage + block_idx ch = self._channel_sizes[layer_idx] mems.append(torch.zeros(P, batch_size, ch, H, W, device=device, dtype=dtype)) H, W = H // 2, W // 2 return mems def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4): B = x.size(0) device, dtype = x.device, x.dtype P = 2 if compute_lyapunov else 1 mems = self._init_mems(B, device, dtype, P=P) if compute_lyapunov: for i in range(len(mems)): mems[i][1] = mems[i][0] + lyap_eps * torch.randn_like(mems[i][0]) lyap_accum = torch.zeros(B, device=device, dtype=dtype) spike_sum = None for t in range(self.T): mem_idx = 0 new_mems = [] is_first_block = True for stage_idx, (stage_blocks, pool) in enumerate(zip(self.stages, self.pools)): for block in stage_blocks: if is_first_block: h_conv = block.bn(block.conv(x)) h = h_conv.unsqueeze(0).expand(P, -1, -1, -1, -1) h_flat = h.reshape(P * B, *h.shape[2:]) mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:]) spk_flat, mem_new_flat = block.lif(h_flat, mem_flat) spk = spk_flat.view(P, B, *spk_flat.shape[1:]) mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:]) h = spk new_mems.append(mem_new) is_first_block = False else: h_flat = h.reshape(P * B, *h.shape[2:]) mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:]) h_conv = block.bn(block.conv(h_flat)) spk_flat, mem_new_flat = block.lif(h_conv, mem_flat) spk = spk_flat.view(P, B, *spk_flat.shape[1:]) mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:]) h = spk new_mems.append(mem_new) mem_idx += 1 h_flat = h.reshape(P * B, *h.shape[2:]) h_pooled = pool(h_flat) h = h_pooled.view(P, B, *h_pooled.shape[1:]) mems = new_mems h_orig = h[0].view(B, -1) if spike_sum is None: spike_sum = h_orig else: spike_sum = spike_sum + h_orig if compute_lyapunov: delta_sq = torch.zeros(B, device=device, dtype=dtype) for i in range(len(new_mems)): diff = new_mems[i][1] - new_mems[i][0] delta_sq = delta_sq + (diff ** 2).sum(dim=(1, 2, 3)) delta = torch.sqrt(delta_sq + 1e-12) lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12) scale = (lyap_eps / delta).view(B, 1, 1, 1) for i in range(len(new_mems)): diff = new_mems[i][1] - new_mems[i][0] mems[i] = torch.stack([ new_mems[i][0], new_mems[i][0] + diff * scale ], dim=0) logits = self.fc(spike_sum) lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None return logits, lyap_est class SpikingVGG_DiffOnly(nn.Module): """ Optimized implementation: stores base + diff instead of 2 full trajectories. Memory layout: base_mems[i]: (B, C, H, W) - base trajectory membrane delta_mems[i]: (B, C, H, W) - perturbation vector Perturbed trajectory is materialized as (base + delta) only when needed. """ def __init__(self, in_channels=3, num_classes=100, base_channels=64, num_stages=3, blocks_per_stage=2, T=4, beta=0.9): super().__init__() self.T = T self.num_stages = num_stages self.blocks_per_stage = blocks_per_stage self.stages = nn.ModuleList() self.pools = nn.ModuleList() in_ch = in_channels out_ch = base_channels current_size = 32 for stage in range(num_stages): stage_blocks = nn.ModuleList() for _ in range(blocks_per_stage): stage_blocks.append(SpikingVGGBlock(in_ch, out_ch, beta=beta)) in_ch = out_ch self.stages.append(stage_blocks) self.pools.append(nn.AvgPool2d(2)) current_size //= 2 if stage < num_stages - 1: out_ch = min(out_ch * 2, 512) self.fc = nn.Linear(in_ch * current_size * current_size, num_classes) self._channel_sizes = self._compute_channel_sizes(base_channels) def _compute_channel_sizes(self, base): sizes = [] ch = base for stage in range(self.num_stages): for _ in range(self.blocks_per_stage): sizes.append(ch) if stage < self.num_stages - 1: ch = min(ch * 2, 512) return sizes def _init_mems(self, batch_size, device, dtype): """Initialize base membrane states (B, C, H, W).""" base_mems = [] H, W = 32, 32 for stage in range(self.num_stages): for block_idx in range(self.blocks_per_stage): layer_idx = stage * self.blocks_per_stage + block_idx ch = self._channel_sizes[layer_idx] base_mems.append(torch.zeros(batch_size, ch, H, W, device=device, dtype=dtype)) H, W = H // 2, W // 2 return base_mems def _init_deltas(self, base_mems, lyap_eps): """Initialize perturbation vectors δ with ||δ||_global = eps.""" delta_mems = [] for base in base_mems: delta_mems.append(lyap_eps * torch.randn_like(base)) return delta_mems def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4): B = x.size(0) device, dtype = x.device, x.dtype # Initialize base membrane states base_mems = self._init_mems(B, device, dtype) # Initialize perturbations if computing Lyapunov if compute_lyapunov: delta_mems = self._init_deltas(base_mems, lyap_eps) lyap_accum = torch.zeros(B, device=device, dtype=dtype) else: delta_mems = None spike_sum = None for t in range(self.T): mem_idx = 0 new_base_mems = [] new_delta_mems = [] if compute_lyapunov else None # Track spikes for base and perturbed (if computing Lyapunov) h_base = None h_delta = None # Will store (h_perturbed - h_base) is_first_block = True for stage_idx, (stage_blocks, pool) in enumerate(zip(self.stages, self.pools)): for block in stage_blocks: if is_first_block: # First block: input x is same for both trajectories h_conv = block.bn(block.conv(x)) # (B, C, H, W) # Base trajectory spk_base, mem_base_new = block.lif(h_conv, base_mems[mem_idx]) new_base_mems.append(mem_base_new) h_base = spk_base if compute_lyapunov: # Perturbed trajectory: mem = base + delta mem_perturbed = base_mems[mem_idx] + delta_mems[mem_idx] spk_perturbed, mem_perturbed_new = block.lif(h_conv, mem_perturbed) # Store delta for new membrane new_delta_mems.append(mem_perturbed_new - mem_base_new) # Store spike difference for propagation h_delta = spk_perturbed - spk_base is_first_block = False else: # Subsequent blocks: inputs differ # Base trajectory h_conv_base = block.bn(block.conv(h_base)) spk_base, mem_base_new = block.lif(h_conv_base, base_mems[mem_idx]) new_base_mems.append(mem_base_new) if compute_lyapunov: # Perturbed trajectory: h_perturbed = h_base + h_delta h_perturbed = h_base + h_delta h_conv_perturbed = block.bn(block.conv(h_perturbed)) mem_perturbed = base_mems[mem_idx] + delta_mems[mem_idx] spk_perturbed, mem_perturbed_new = block.lif(h_conv_perturbed, mem_perturbed) new_delta_mems.append(mem_perturbed_new - mem_base_new) h_delta = spk_perturbed - spk_base h_base = spk_base mem_idx += 1 # Pooling h_base = pool(h_base) if compute_lyapunov: # Pool both and compute new delta h_perturbed = h_base + pool(h_delta) # Note: pool(base+delta) ≠ pool(base) + pool(delta) in general # But for AvgPool, it's linear so this is fine h_delta = h_perturbed - h_base # This simplifies to pool(h_delta) for AvgPool h_delta = pool(h_delta) # Actually just pool the delta directly (AvgPool is linear) # Update membrane states base_mems = new_base_mems # Accumulate spikes from base trajectory h_flat = h_base.view(B, -1) if spike_sum is None: spike_sum = h_flat else: spike_sum = spike_sum + h_flat # Lyapunov: compute global divergence and renormalize if compute_lyapunov: # Global norm of all deltas: ||δ||² = Σ_layers ||δ_layer||² delta_sq = torch.zeros(B, device=device, dtype=dtype) for delta in new_delta_mems: delta_sq = delta_sq + (delta ** 2).sum(dim=(1, 2, 3)) delta_norm = torch.sqrt(delta_sq + 1e-12) lyap_accum = lyap_accum + torch.log(delta_norm / lyap_eps + 1e-12) # Renormalize: scale all deltas so ||δ||_global = eps scale = (lyap_eps / delta_norm).view(B, 1, 1, 1) delta_mems = [delta * scale for delta in new_delta_mems] logits = self.fc(spike_sum) lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None return logits, lyap_est def count_parameters(model): return sum(p.numel() for p in model.parameters()) def benchmark_forward(model, x, compute_lyapunov, num_warmup=5, num_runs=20): """Benchmark forward pass time.""" device = x.device # Warmup for _ in range(num_warmup): with torch.no_grad(): _ = model(x, compute_lyapunov=compute_lyapunov) torch.cuda.synchronize() # Timed runs times = [] for _ in range(num_runs): torch.cuda.synchronize() start = time.perf_counter() logits, lyap = model(x, compute_lyapunov=compute_lyapunov) torch.cuda.synchronize() end = time.perf_counter() times.append(end - start) return times, lyap def benchmark_forward_backward(model, x, y, criterion, compute_lyapunov, lambda_reg=0.3, num_warmup=5, num_runs=20): """Benchmark forward + backward pass time.""" device = x.device # Warmup for _ in range(num_warmup): model.zero_grad() logits, lyap = model(x, compute_lyapunov=compute_lyapunov) loss = criterion(logits, y) if compute_lyapunov and lyap is not None: loss = loss + lambda_reg * (lyap ** 2) loss.backward() torch.cuda.synchronize() # Timed runs times = [] for _ in range(num_runs): model.zero_grad() torch.cuda.synchronize() start = time.perf_counter() logits, lyap = model(x, compute_lyapunov=compute_lyapunov) loss = criterion(logits, y) if compute_lyapunov and lyap is not None: loss = loss + lambda_reg * (lyap ** 2) loss.backward() torch.cuda.synchronize() end = time.perf_counter() times.append(end - start) return times def measure_memory(model, x, compute_lyapunov): """Measure peak GPU memory during forward pass.""" torch.cuda.reset_peak_memory_stats() torch.cuda.synchronize() with torch.no_grad(): _ = model(x, compute_lyapunov=compute_lyapunov) torch.cuda.synchronize() peak_mem = torch.cuda.max_memory_allocated() / 1024**2 # MB return peak_mem def run_benchmark(): print("=" * 70) print("LYAPUNOV COMPUTATION BENCHMARK: Original vs Diff-Only Storage") print("=" * 70) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Device: {device}") if device.type == "cuda": print(f"GPU: {torch.cuda.get_device_name()}") # Test configurations configs = [ {"depth": 4, "blocks_per_stage": 1, "batch_size": 64}, {"depth": 8, "blocks_per_stage": 2, "batch_size": 64}, {"depth": 12, "blocks_per_stage": 3, "batch_size": 32}, ] print("\n" + "=" * 70) for cfg in configs: depth = cfg["depth"] blocks = cfg["blocks_per_stage"] batch_size = cfg["batch_size"] print(f"\n{'='*70}") print(f"DEPTH = {depth} ({blocks} blocks/stage), Batch = {batch_size}") print(f"{'='*70}") # Create models model_orig = SpikingVGG_Original( blocks_per_stage=blocks, T=4 ).to(device) model_diff = SpikingVGG_DiffOnly( blocks_per_stage=blocks, T=4 ).to(device) # Copy weights from original to diff-only model_diff.load_state_dict(model_orig.state_dict()) print(f"Parameters: {count_parameters(model_orig):,}") # Create input x = torch.randn(batch_size, 3, 32, 32, device=device) y = torch.randint(0, 100, (batch_size,), device=device) criterion = nn.CrossEntropyLoss() # ============================================================ # Test 1: Verify outputs match # ============================================================ print("\n--- Output Verification ---") model_orig.eval() model_diff.eval() torch.manual_seed(42) with torch.no_grad(): logits_orig, lyap_orig = model_orig(x, compute_lyapunov=True, lyap_eps=1e-4) torch.manual_seed(42) with torch.no_grad(): logits_diff, lyap_diff = model_diff(x, compute_lyapunov=True, lyap_eps=1e-4) logits_match = torch.allclose(logits_orig, logits_diff, rtol=1e-4, atol=1e-5) lyap_close = abs(lyap_orig.item() - lyap_diff.item()) < 0.1 # Allow some difference due to different implementations print(f"Logits match: {logits_match}") print(f"Lyapunov - Original: {lyap_orig.item():.4f}, Diff-only: {lyap_diff.item():.4f}") print(f"Lyapunov close (within 0.1): {lyap_close}") # ============================================================ # Test 2: Forward-only speed (no grad) # ============================================================ print("\n--- Forward Speed (no_grad) ---") model_orig.eval() model_diff.eval() # Without Lyapunov times_orig_noly, _ = benchmark_forward(model_orig, x, compute_lyapunov=False) times_diff_noly, _ = benchmark_forward(model_diff, x, compute_lyapunov=False) mean_orig = sum(times_orig_noly) / len(times_orig_noly) * 1000 mean_diff = sum(times_diff_noly) / len(times_diff_noly) * 1000 print(f" Without Lyapunov:") print(f" Original: {mean_orig:.2f} ms") print(f" Diff-only: {mean_diff:.2f} ms") # With Lyapunov times_orig_ly, _ = benchmark_forward(model_orig, x, compute_lyapunov=True) times_diff_ly, _ = benchmark_forward(model_diff, x, compute_lyapunov=True) mean_orig_ly = sum(times_orig_ly) / len(times_orig_ly) * 1000 mean_diff_ly = sum(times_diff_ly) / len(times_diff_ly) * 1000 speedup = mean_orig_ly / mean_diff_ly print(f" With Lyapunov:") print(f" Original: {mean_orig_ly:.2f} ms") print(f" Diff-only: {mean_diff_ly:.2f} ms") print(f" Speedup: {speedup:.2f}x") # ============================================================ # Test 3: Forward + Backward speed (training mode) # ============================================================ print("\n--- Forward+Backward Speed (training) ---") model_orig.train() model_diff.train() times_orig_train = benchmark_forward_backward( model_orig, x, y, criterion, compute_lyapunov=True ) times_diff_train = benchmark_forward_backward( model_diff, x, y, criterion, compute_lyapunov=True ) mean_orig_train = sum(times_orig_train) / len(times_orig_train) * 1000 mean_diff_train = sum(times_diff_train) / len(times_diff_train) * 1000 speedup_train = mean_orig_train / mean_diff_train print(f" With Lyapunov + backward:") print(f" Original: {mean_orig_train:.2f} ms") print(f" Diff-only: {mean_diff_train:.2f} ms") print(f" Speedup: {speedup_train:.2f}x") # ============================================================ # Test 4: Memory usage # ============================================================ if device.type == "cuda": print("\n--- Peak GPU Memory ---") mem_orig_noly = measure_memory(model_orig, x, compute_lyapunov=False) mem_diff_noly = measure_memory(model_diff, x, compute_lyapunov=False) mem_orig_ly = measure_memory(model_orig, x, compute_lyapunov=True) mem_diff_ly = measure_memory(model_diff, x, compute_lyapunov=True) print(f" Without Lyapunov:") print(f" Original: {mem_orig_noly:.1f} MB") print(f" Diff-only: {mem_diff_noly:.1f} MB") print(f" With Lyapunov:") print(f" Original: {mem_orig_ly:.1f} MB") print(f" Diff-only: {mem_diff_ly:.1f} MB") print(f" Memory saved: {mem_orig_ly - mem_diff_ly:.1f} MB ({100*(mem_orig_ly - mem_diff_ly)/mem_orig_ly:.1f}%)") # Cleanup del model_orig, model_diff, x, y torch.cuda.empty_cache() print("\n" + "=" * 70) print("BENCHMARK COMPLETE") print("=" * 70) if __name__ == "__main__": run_benchmark()