author    YurenHao0426 <blackhao0426@gmail.com>  2026-01-13 23:50:59 -0600
committer YurenHao0426 <blackhao0426@gmail.com>  2026-01-13 23:50:59 -0600
commit    00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch)
tree      77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files/experiments/lyapunov_diffonly_benchmark.py
parent    c53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff)
parent    cd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff)
Merge master into main
Diffstat (limited to 'files/experiments/lyapunov_diffonly_benchmark.py')
-rw-r--r--  files/experiments/lyapunov_diffonly_benchmark.py  590
1 file changed, 590 insertions(+), 0 deletions(-)
diff --git a/files/experiments/lyapunov_diffonly_benchmark.py b/files/experiments/lyapunov_diffonly_benchmark.py
new file mode 100644
index 0000000..05dbcd2
--- /dev/null
+++ b/files/experiments/lyapunov_diffonly_benchmark.py
@@ -0,0 +1,590 @@
+"""
+Benchmark: Diff-only storage vs 2-trajectory storage for Lyapunov computation.
+
+Optimization B: Instead of storing two full membrane trajectories:
+ mems[i][0] = base trajectory
+ mems[i][1] = perturbed trajectory
+
+Store only:
+ base_mems[i] = base trajectory
+ delta_mems[i] = perturbation (perturbed - base)
+
+Expected benefits (measured below):
+  - Lower peak memory: renormalization rescales delta directly instead of
+    rebuilding a stacked (2, B, C, H, W) tensor per layer
+  - Fewer memory reads/writes during renormalization
+  - Better cache utilization
+"""
+
+import os
+import sys
+import time
+import torch
+import torch.nn as nn
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+import snntorch as snn
+from snntorch import surrogate
+
+
+class SpikingVGGBlock(nn.Module):
+ """Conv-BN-LIF block."""
+
+ def __init__(self, in_ch, out_ch, beta=0.9, threshold=1.0, spike_grad=None):
+ super().__init__()
+ if spike_grad is None:
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
+ self.bn = nn.BatchNorm2d(out_ch)
+ self.lif = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False)
+
+ def forward(self, x, mem):
+ h = self.bn(self.conv(x))
+ spk, mem = self.lif(h, mem)
+ return spk, mem
+
+
+class SpikingVGG_Original(nn.Module):
+ """Original implementation: stores 2 full trajectories with shape (P=2, B, C, H, W)."""
+
+ def __init__(self, in_channels=3, num_classes=100, base_channels=64,
+ num_stages=3, blocks_per_stage=2, T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.num_stages = num_stages
+ self.blocks_per_stage = blocks_per_stage
+
+ # Build stages
+ self.stages = nn.ModuleList()
+ self.pools = nn.ModuleList()
+
+ in_ch = in_channels
+ out_ch = base_channels
+ current_size = 32 # CIFAR
+
+ for stage in range(num_stages):
+ stage_blocks = nn.ModuleList()
+ for _ in range(blocks_per_stage):
+ stage_blocks.append(SpikingVGGBlock(in_ch, out_ch, beta=beta))
+ in_ch = out_ch
+ self.stages.append(stage_blocks)
+ self.pools.append(nn.AvgPool2d(2))
+ current_size //= 2
+ if stage < num_stages - 1:
+ out_ch = min(out_ch * 2, 512)
+
+ self.fc = nn.Linear(in_ch * current_size * current_size, num_classes)
+ self._channel_sizes = self._compute_channel_sizes(base_channels)
+
+ def _compute_channel_sizes(self, base):
+ sizes = []
+ ch = base
+ for stage in range(self.num_stages):
+ for _ in range(self.blocks_per_stage):
+ sizes.append(ch)
+ if stage < self.num_stages - 1:
+ ch = min(ch * 2, 512)
+ return sizes
+
+ def _init_mems(self, batch_size, device, dtype, P=1):
+ mems = []
+ H, W = 32, 32
+ for stage in range(self.num_stages):
+ for block_idx in range(self.blocks_per_stage):
+ layer_idx = stage * self.blocks_per_stage + block_idx
+ ch = self._channel_sizes[layer_idx]
+ mems.append(torch.zeros(P, batch_size, ch, H, W, device=device, dtype=dtype))
+ H, W = H // 2, W // 2
+ return mems
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+ P = 2 if compute_lyapunov else 1
+
+ mems = self._init_mems(B, device, dtype, P=P)
+
+ if compute_lyapunov:
+ for i in range(len(mems)):
+ mems[i][1] = mems[i][0] + lyap_eps * torch.randn_like(mems[i][0])
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ spike_sum = None
+
+ for t in range(self.T):
+ mem_idx = 0
+ new_mems = []
+ is_first_block = True
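+            # Both trajectories run in one pass by folding the P axis into the
+            # batch: (P, B, C, H, W) -> (P*B, C, H, W) for conv/BN/LIF.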
+
+ for stage_idx, (stage_blocks, pool) in enumerate(zip(self.stages, self.pools)):
+ for block in stage_blocks:
+ if is_first_block:
+ h_conv = block.bn(block.conv(x))
+ h = h_conv.unsqueeze(0).expand(P, -1, -1, -1, -1)
+ h_flat = h.reshape(P * B, *h.shape[2:])
+ mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:])
+ spk_flat, mem_new_flat = block.lif(h_flat, mem_flat)
+ spk = spk_flat.view(P, B, *spk_flat.shape[1:])
+ mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:])
+ h = spk
+ new_mems.append(mem_new)
+ is_first_block = False
+ else:
+ h_flat = h.reshape(P * B, *h.shape[2:])
+ mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:])
+ h_conv = block.bn(block.conv(h_flat))
+ spk_flat, mem_new_flat = block.lif(h_conv, mem_flat)
+ spk = spk_flat.view(P, B, *spk_flat.shape[1:])
+ mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:])
+ h = spk
+ new_mems.append(mem_new)
+ mem_idx += 1
+
+ h_flat = h.reshape(P * B, *h.shape[2:])
+ h_pooled = pool(h_flat)
+ h = h_pooled.view(P, B, *h_pooled.shape[1:])
+
+ mems = new_mems
+
+ h_orig = h[0].view(B, -1)
+ if spike_sum is None:
+ spike_sum = h_orig
+ else:
+ spike_sum = spike_sum + h_orig
+
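+            # Benettin-style step: accumulate the log expansion of the global
+            # perturbation norm, then rescale the perturbation back to lyap_eps
+            # so it stays in the numerically linear regime.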
+ if compute_lyapunov:
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(len(new_mems)):
+ diff = new_mems[i][1] - new_mems[i][0]
+ delta_sq = delta_sq + (diff ** 2).sum(dim=(1, 2, 3))
+
+                delta_norm = torch.sqrt(delta_sq + 1e-12)
+                lyap_accum = lyap_accum + torch.log(delta_norm / lyap_eps + 1e-12)
+
+                scale = (lyap_eps / delta_norm).view(B, 1, 1, 1)
+ for i in range(len(new_mems)):
+ diff = new_mems[i][1] - new_mems[i][0]
+ mems[i] = torch.stack([
+ new_mems[i][0],
+ new_mems[i][0] + diff * scale
+ ], dim=0)
+
+ logits = self.fc(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+class SpikingVGG_DiffOnly(nn.Module):
+ """
+ Optimized implementation: stores base + diff instead of 2 full trajectories.
+
+ Memory layout:
+ base_mems[i]: (B, C, H, W) - base trajectory membrane
+ delta_mems[i]: (B, C, H, W) - perturbation vector
+
+ Perturbed trajectory is materialized as (base + delta) only when needed.
+ """
+
+ def __init__(self, in_channels=3, num_classes=100, base_channels=64,
+ num_stages=3, blocks_per_stage=2, T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.num_stages = num_stages
+ self.blocks_per_stage = blocks_per_stage
+
+ self.stages = nn.ModuleList()
+ self.pools = nn.ModuleList()
+
+ in_ch = in_channels
+ out_ch = base_channels
+ current_size = 32
+
+ for stage in range(num_stages):
+ stage_blocks = nn.ModuleList()
+ for _ in range(blocks_per_stage):
+ stage_blocks.append(SpikingVGGBlock(in_ch, out_ch, beta=beta))
+ in_ch = out_ch
+ self.stages.append(stage_blocks)
+ self.pools.append(nn.AvgPool2d(2))
+ current_size //= 2
+ if stage < num_stages - 1:
+ out_ch = min(out_ch * 2, 512)
+
+ self.fc = nn.Linear(in_ch * current_size * current_size, num_classes)
+ self._channel_sizes = self._compute_channel_sizes(base_channels)
+
+ def _compute_channel_sizes(self, base):
+ sizes = []
+ ch = base
+ for stage in range(self.num_stages):
+ for _ in range(self.blocks_per_stage):
+ sizes.append(ch)
+ if stage < self.num_stages - 1:
+ ch = min(ch * 2, 512)
+ return sizes
+
+ def _init_mems(self, batch_size, device, dtype):
+ """Initialize base membrane states (B, C, H, W)."""
+ base_mems = []
+ H, W = 32, 32
+ for stage in range(self.num_stages):
+ for block_idx in range(self.blocks_per_stage):
+ layer_idx = stage * self.blocks_per_stage + block_idx
+ ch = self._channel_sizes[layer_idx]
+ base_mems.append(torch.zeros(batch_size, ch, H, W, device=device, dtype=dtype))
+ H, W = H // 2, W // 2
+ return base_mems
+
+ def _init_deltas(self, base_mems, lyap_eps):
+ """Initialize perturbation vectors δ with ||δ||_global = eps."""
+ delta_mems = []
+ for base in base_mems:
+ delta_mems.append(lyap_eps * torch.randn_like(base))
+ return delta_mems
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+
+ # Initialize base membrane states
+ base_mems = self._init_mems(B, device, dtype)
+
+ # Initialize perturbations if computing Lyapunov
+ if compute_lyapunov:
+ delta_mems = self._init_deltas(base_mems, lyap_eps)
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+ else:
+ delta_mems = None
+
+ spike_sum = None
+
+ for t in range(self.T):
+ mem_idx = 0
+ new_base_mems = []
+ new_delta_mems = [] if compute_lyapunov else None
+
+ # Track spikes for base and perturbed (if computing Lyapunov)
+ h_base = None
+ h_delta = None # Will store (h_perturbed - h_base)
+ is_first_block = True
+
+ for stage_idx, (stage_blocks, pool) in enumerate(zip(self.stages, self.pools)):
+ for block in stage_blocks:
+ if is_first_block:
+ # First block: input x is same for both trajectories
+ h_conv = block.bn(block.conv(x)) # (B, C, H, W)
+
+ # Base trajectory
+ spk_base, mem_base_new = block.lif(h_conv, base_mems[mem_idx])
+ new_base_mems.append(mem_base_new)
+ h_base = spk_base
+
+ if compute_lyapunov:
+ # Perturbed trajectory: mem = base + delta
+ mem_perturbed = base_mems[mem_idx] + delta_mems[mem_idx]
+ spk_perturbed, mem_perturbed_new = block.lif(h_conv, mem_perturbed)
+ # Store delta for new membrane
+ new_delta_mems.append(mem_perturbed_new - mem_base_new)
+ # Store spike difference for propagation
+ h_delta = spk_perturbed - spk_base
+
+ is_first_block = False
+ else:
+ # Subsequent blocks: inputs differ
+ # Base trajectory
+ h_conv_base = block.bn(block.conv(h_base))
+ spk_base, mem_base_new = block.lif(h_conv_base, base_mems[mem_idx])
+ new_base_mems.append(mem_base_new)
+
+ if compute_lyapunov:
+ # Perturbed trajectory: h_perturbed = h_base + h_delta
+ h_perturbed = h_base + h_delta
+ h_conv_perturbed = block.bn(block.conv(h_perturbed))
+ mem_perturbed = base_mems[mem_idx] + delta_mems[mem_idx]
+ spk_perturbed, mem_perturbed_new = block.lif(h_conv_perturbed, mem_perturbed)
+ new_delta_mems.append(mem_perturbed_new - mem_base_new)
+ h_delta = spk_perturbed - spk_base
+
+ h_base = spk_base
+
+ mem_idx += 1
+
+                # Pooling: AvgPool2d is linear, so the pooled perturbed spikes
+                # equal pool(h_base) + pool(h_delta); pool the difference once,
+                # directly (pooling it twice would shrink its spatial dims twice).
+                h_base = pool(h_base)
+                if compute_lyapunov:
+                    h_delta = pool(h_delta)
+
+ # Update membrane states
+ base_mems = new_base_mems
+
+ # Accumulate spikes from base trajectory
+ h_flat = h_base.view(B, -1)
+ if spike_sum is None:
+ spike_sum = h_flat
+ else:
+ spike_sum = spike_sum + h_flat
+
+ # Lyapunov: compute global divergence and renormalize
+ if compute_lyapunov:
+ # Global norm of all deltas: ||δ||² = Σ_layers ||δ_layer||²
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for delta in new_delta_mems:
+ delta_sq = delta_sq + (delta ** 2).sum(dim=(1, 2, 3))
+
+ delta_norm = torch.sqrt(delta_sq + 1e-12)
+ lyap_accum = lyap_accum + torch.log(delta_norm / lyap_eps + 1e-12)
+
+ # Renormalize: scale all deltas so ||δ||_global = eps
+ scale = (lyap_eps / delta_norm).view(B, 1, 1, 1)
+ delta_mems = [delta * scale for delta in new_delta_mems]
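+                # After T steps, lyap_accum / T is a short-horizon, per-sample
+                # estimate of the largest Lyapunov exponent (nats per timestep);
+                # the batch mean is returned below.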
+
+ logits = self.fc(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+def count_parameters(model):
+ return sum(p.numel() for p in model.parameters())
+
+
+def _sync(device):
+    """Synchronize CUDA so perf_counter timings are accurate; no-op on CPU."""
+    if device.type == "cuda":
+        torch.cuda.synchronize()
+
+
+def benchmark_forward(model, x, compute_lyapunov, num_warmup=5, num_runs=20):
+    """Benchmark forward pass time; returns per-run times and the last Lyapunov value."""
+    device = x.device
+
+ # Warmup
+ for _ in range(num_warmup):
+ with torch.no_grad():
+ _ = model(x, compute_lyapunov=compute_lyapunov)
+
+    _sync(device)
+
+ # Timed runs
+ times = []
+ for _ in range(num_runs):
+        _sync(device)
+ start = time.perf_counter()
+
+ logits, lyap = model(x, compute_lyapunov=compute_lyapunov)
+
+        _sync(device)
+ end = time.perf_counter()
+ times.append(end - start)
+
+ return times, lyap
+
+
+def benchmark_forward_backward(model, x, y, criterion, compute_lyapunov,
+ lambda_reg=0.3, num_warmup=5, num_runs=20):
+ """Benchmark forward + backward pass time."""
+ device = x.device
+
+ # Warmup
+ for _ in range(num_warmup):
+ model.zero_grad()
+ logits, lyap = model(x, compute_lyapunov=compute_lyapunov)
+ loss = criterion(logits, y)
+ if compute_lyapunov and lyap is not None:
+ loss = loss + lambda_reg * (lyap ** 2)
+ loss.backward()
+
+    _sync(device)
+
+ # Timed runs
+ times = []
+ for _ in range(num_runs):
+ model.zero_grad()
+        _sync(device)
+ start = time.perf_counter()
+
+ logits, lyap = model(x, compute_lyapunov=compute_lyapunov)
+ loss = criterion(logits, y)
+ if compute_lyapunov and lyap is not None:
+ loss = loss + lambda_reg * (lyap ** 2)
+ loss.backward()
+
+        _sync(device)
+ end = time.perf_counter()
+ times.append(end - start)
+
+ return times
+
+
+def measure_memory(model, x, compute_lyapunov):
+ """Measure peak GPU memory during forward pass."""
+ torch.cuda.reset_peak_memory_stats()
+ torch.cuda.synchronize()
+
+ with torch.no_grad():
+ _ = model(x, compute_lyapunov=compute_lyapunov)
+
+ torch.cuda.synchronize()
+ peak_mem = torch.cuda.max_memory_allocated() / 1024**2 # MB
+ return peak_mem
+
+
+def run_benchmark():
+ print("=" * 70)
+ print("LYAPUNOV COMPUTATION BENCHMARK: Original vs Diff-Only Storage")
+ print("=" * 70)
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Device: {device}")
+
+ if device.type == "cuda":
+ print(f"GPU: {torch.cuda.get_device_name()}")
+
+ # Test configurations
+ configs = [
+ {"depth": 4, "blocks_per_stage": 1, "batch_size": 64},
+ {"depth": 8, "blocks_per_stage": 2, "batch_size": 64},
+ {"depth": 12, "blocks_per_stage": 3, "batch_size": 32},
+ ]
+
+ print("\n" + "=" * 70)
+
+ for cfg in configs:
+ depth = cfg["depth"]
+ blocks = cfg["blocks_per_stage"]
+ batch_size = cfg["batch_size"]
+
+ print(f"\n{'='*70}")
+ print(f"DEPTH = {depth} ({blocks} blocks/stage), Batch = {batch_size}")
+ print(f"{'='*70}")
+
+ # Create models
+ model_orig = SpikingVGG_Original(
+ blocks_per_stage=blocks, T=4
+ ).to(device)
+
+ model_diff = SpikingVGG_DiffOnly(
+ blocks_per_stage=blocks, T=4
+ ).to(device)
+
+ # Copy weights from original to diff-only
+ model_diff.load_state_dict(model_orig.state_dict())
+
+ print(f"Parameters: {count_parameters(model_orig):,}")
+
+ # Create input
+ x = torch.randn(batch_size, 3, 32, 32, device=device)
+ y = torch.randint(0, 100, (batch_size,), device=device)
+ criterion = nn.CrossEntropyLoss()
+
+ # ============================================================
+ # Test 1: Verify outputs match
+ # ============================================================
+ print("\n--- Output Verification ---")
+ model_orig.eval()
+ model_diff.eval()
+
+ torch.manual_seed(42)
+ with torch.no_grad():
+ logits_orig, lyap_orig = model_orig(x, compute_lyapunov=True, lyap_eps=1e-4)
+
+ torch.manual_seed(42)
+ with torch.no_grad():
+ logits_diff, lyap_diff = model_diff(x, compute_lyapunov=True, lyap_eps=1e-4)
+
+ logits_match = torch.allclose(logits_orig, logits_diff, rtol=1e-4, atol=1e-5)
+    lyap_close = abs(lyap_orig.item() - lyap_diff.item()) < 0.1  # small numeric drift expected: batched (P*B) vs separate kernels
+
+ print(f"Logits match: {logits_match}")
+ print(f"Lyapunov - Original: {lyap_orig.item():.4f}, Diff-only: {lyap_diff.item():.4f}")
+ print(f"Lyapunov close (within 0.1): {lyap_close}")
+
+ # ============================================================
+ # Test 2: Forward-only speed (no grad)
+ # ============================================================
+ print("\n--- Forward Speed (no_grad) ---")
+ model_orig.eval()
+ model_diff.eval()
+
+ # Without Lyapunov
+ times_orig_noly, _ = benchmark_forward(model_orig, x, compute_lyapunov=False)
+ times_diff_noly, _ = benchmark_forward(model_diff, x, compute_lyapunov=False)
+
+ mean_orig = sum(times_orig_noly) / len(times_orig_noly) * 1000
+ mean_diff = sum(times_diff_noly) / len(times_diff_noly) * 1000
+
+ print(f" Without Lyapunov:")
+ print(f" Original: {mean_orig:.2f} ms")
+ print(f" Diff-only: {mean_diff:.2f} ms")
+
+ # With Lyapunov
+ times_orig_ly, _ = benchmark_forward(model_orig, x, compute_lyapunov=True)
+ times_diff_ly, _ = benchmark_forward(model_diff, x, compute_lyapunov=True)
+
+ mean_orig_ly = sum(times_orig_ly) / len(times_orig_ly) * 1000
+ mean_diff_ly = sum(times_diff_ly) / len(times_diff_ly) * 1000
+ speedup = mean_orig_ly / mean_diff_ly
+
+ print(f" With Lyapunov:")
+ print(f" Original: {mean_orig_ly:.2f} ms")
+ print(f" Diff-only: {mean_diff_ly:.2f} ms")
+ print(f" Speedup: {speedup:.2f}x")
+
+ # ============================================================
+ # Test 3: Forward + Backward speed (training mode)
+ # ============================================================
+ print("\n--- Forward+Backward Speed (training) ---")
+ model_orig.train()
+ model_diff.train()
+
+ times_orig_train = benchmark_forward_backward(
+ model_orig, x, y, criterion, compute_lyapunov=True
+ )
+ times_diff_train = benchmark_forward_backward(
+ model_diff, x, y, criterion, compute_lyapunov=True
+ )
+
+ mean_orig_train = sum(times_orig_train) / len(times_orig_train) * 1000
+ mean_diff_train = sum(times_diff_train) / len(times_diff_train) * 1000
+ speedup_train = mean_orig_train / mean_diff_train
+
+ print(f" With Lyapunov + backward:")
+ print(f" Original: {mean_orig_train:.2f} ms")
+ print(f" Diff-only: {mean_diff_train:.2f} ms")
+ print(f" Speedup: {speedup_train:.2f}x")
+
+ # ============================================================
+ # Test 4: Memory usage
+ # ============================================================
+ if device.type == "cuda":
+ print("\n--- Peak GPU Memory ---")
+
+ mem_orig_noly = measure_memory(model_orig, x, compute_lyapunov=False)
+ mem_diff_noly = measure_memory(model_diff, x, compute_lyapunov=False)
+
+ mem_orig_ly = measure_memory(model_orig, x, compute_lyapunov=True)
+ mem_diff_ly = measure_memory(model_diff, x, compute_lyapunov=True)
+
+ print(f" Without Lyapunov:")
+ print(f" Original: {mem_orig_noly:.1f} MB")
+ print(f" Diff-only: {mem_diff_noly:.1f} MB")
+ print(f" With Lyapunov:")
+ print(f" Original: {mem_orig_ly:.1f} MB")
+ print(f" Diff-only: {mem_diff_ly:.1f} MB")
+ print(f" Memory saved: {mem_orig_ly - mem_diff_ly:.1f} MB ({100*(mem_orig_ly - mem_diff_ly)/mem_orig_ly:.1f}%)")
+
+ # Cleanup
+ del model_orig, model_diff, x, y
+ torch.cuda.empty_cache()
+
+ print("\n" + "=" * 70)
+ print("BENCHMARK COMPLETE")
+ print("=" * 70)
+
+
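+# Run directly, e.g.:  python files/experiments/lyapunov_diffonly_benchmark.py
+# (the peak-memory comparison in Test 4 only runs when CUDA is available)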
+if __name__ == "__main__":
+ run_benchmark()