| field | value | date |
|---|---|---|
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
| commit | 00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch) | |
| tree | 77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files/experiments/lyapunov_diffonly_benchmark.py | |
| parent | c53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff) | |
| parent | cd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff) | |
Merge master into main
Diffstat (limited to 'files/experiments/lyapunov_diffonly_benchmark.py')

| mode | file | lines |
|---|---|---|
| -rw-r--r-- | files/experiments/lyapunov_diffonly_benchmark.py | 590 |

1 file changed, 590 insertions(+), 0 deletions(-)
diff --git a/files/experiments/lyapunov_diffonly_benchmark.py b/files/experiments/lyapunov_diffonly_benchmark.py
new file mode 100644
index 0000000..05dbcd2
--- /dev/null
+++ b/files/experiments/lyapunov_diffonly_benchmark.py
@@ -0,0 +1,590 @@
"""
Benchmark: Diff-only storage vs 2-trajectory storage for Lyapunov computation.

Optimization B: Instead of storing two full membrane trajectories:
    mems[i][0] = base trajectory
    mems[i][1] = perturbed trajectory

store only:
    base_mems[i]  = base trajectory
    delta_mems[i] = perturbation (perturbed - base)

Benefits:
    - Lower transient memory: renormalization rescales the small delta
      tensors instead of rebuilding stacked two-trajectory copies
    - Fewer memory reads/writes during renormalization
    - Better cache utilization
"""

import os
import sys
import time

import torch
import torch.nn as nn

_HERE = os.path.dirname(__file__)
_ROOT = os.path.dirname(os.path.dirname(_HERE))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

import snntorch as snn
from snntorch import surrogate


class SpikingVGGBlock(nn.Module):
    """Conv-BN-LIF block."""

    def __init__(self, in_ch, out_ch, beta=0.9, threshold=1.0, spike_grad=None):
        super().__init__()
        if spike_grad is None:
            spike_grad = surrogate.fast_sigmoid(slope=25)

        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.lif = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False)

    def forward(self, x, mem):
        h = self.bn(self.conv(x))
        spk, mem = self.lif(h, mem)
        return spk, mem


class SpikingVGG_Original(nn.Module):
    """Original implementation: stores 2 full trajectories with shape (P=2, B, C, H, W)."""

    def __init__(self, in_channels=3, num_classes=100, base_channels=64,
                 num_stages=3, blocks_per_stage=2, T=4, beta=0.9):
        super().__init__()
        self.T = T
        self.num_stages = num_stages
        self.blocks_per_stage = blocks_per_stage

        # Build stages
        self.stages = nn.ModuleList()
        self.pools = nn.ModuleList()

        in_ch = in_channels
        out_ch = base_channels
        current_size = 32  # CIFAR input resolution

        for stage in range(num_stages):
            stage_blocks = nn.ModuleList()
            for _ in range(blocks_per_stage):
                stage_blocks.append(SpikingVGGBlock(in_ch, out_ch, beta=beta))
                in_ch = out_ch
            self.stages.append(stage_blocks)
            self.pools.append(nn.AvgPool2d(2))
            current_size //= 2
            if stage < num_stages - 1:
                out_ch = min(out_ch * 2, 512)

        self.fc = nn.Linear(in_ch * current_size * current_size, num_classes)
        self._channel_sizes = self._compute_channel_sizes(base_channels)

    def _compute_channel_sizes(self, base):
        sizes = []
        ch = base
        for stage in range(self.num_stages):
            for _ in range(self.blocks_per_stage):
                sizes.append(ch)
            if stage < self.num_stages - 1:
                ch = min(ch * 2, 512)
        return sizes

    def _init_mems(self, batch_size, device, dtype, P=1):
        mems = []
        H, W = 32, 32
        for stage in range(self.num_stages):
            for block_idx in range(self.blocks_per_stage):
                layer_idx = stage * self.blocks_per_stage + block_idx
                ch = self._channel_sizes[layer_idx]
                mems.append(torch.zeros(P, batch_size, ch, H, W, device=device, dtype=dtype))
            H, W = H // 2, W // 2
        return mems

    def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
        B = x.size(0)
        device, dtype = x.device, x.dtype
        P = 2 if compute_lyapunov else 1

        mems = self._init_mems(B, device, dtype, P=P)

        if compute_lyapunov:
            # Seed the perturbed trajectory with eps-scaled Gaussian noise.
            for i in range(len(mems)):
                mems[i][1] = mems[i][0] + lyap_eps * torch.randn_like(mems[i][0])
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)
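
        # The time loop below implements a Benettin-style estimator: evolve the
        # base and perturbed trajectories side by side, accumulate
        # log(||delta_t|| / eps) after every timestep, then rescale the gap
        # back to eps, so that
        #     lambda ≈ (1/T) * Σ_t log(||delta_t|| / eps)
        # estimates the largest Lyapunov exponent per timestep (averaged over
        # the batch when reported).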

        spike_sum = None

        for t in range(self.T):
            mem_idx = 0
            new_mems = []
            is_first_block = True

            for stage_blocks, pool in zip(self.stages, self.pools):
                for block in stage_blocks:
                    if is_first_block:
                        h_conv = block.bn(block.conv(x))
                        h = h_conv.unsqueeze(0).expand(P, -1, -1, -1, -1)
                        h_flat = h.reshape(P * B, *h.shape[2:])
                        mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:])
                        spk_flat, mem_new_flat = block.lif(h_flat, mem_flat)
                        spk = spk_flat.view(P, B, *spk_flat.shape[1:])
                        mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:])
                        h = spk
                        new_mems.append(mem_new)
                        is_first_block = False
                    else:
                        h_flat = h.reshape(P * B, *h.shape[2:])
                        mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:])
                        h_conv = block.bn(block.conv(h_flat))
                        spk_flat, mem_new_flat = block.lif(h_conv, mem_flat)
                        spk = spk_flat.view(P, B, *spk_flat.shape[1:])
                        mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:])
                        h = spk
                        new_mems.append(mem_new)
                    mem_idx += 1

                h_flat = h.reshape(P * B, *h.shape[2:])
                h_pooled = pool(h_flat)
                h = h_pooled.view(P, B, *h_pooled.shape[1:])

            mems = new_mems

            h_orig = h[0].view(B, -1)
            if spike_sum is None:
                spike_sum = h_orig
            else:
                spike_sum = spike_sum + h_orig

            if compute_lyapunov:
                delta_sq = torch.zeros(B, device=device, dtype=dtype)
                for i in range(len(new_mems)):
                    diff = new_mems[i][1] - new_mems[i][0]
                    delta_sq = delta_sq + (diff ** 2).sum(dim=(1, 2, 3))

                delta = torch.sqrt(delta_sq + 1e-12)
                lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)

                scale = (lyap_eps / delta).view(B, 1, 1, 1)
                for i in range(len(new_mems)):
                    diff = new_mems[i][1] - new_mems[i][0]
                    mems[i] = torch.stack([
                        new_mems[i][0],
                        new_mems[i][0] + diff * scale
                    ], dim=0)

        logits = self.fc(spike_sum)
        lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None

        return logits, lyap_est


class SpikingVGG_DiffOnly(nn.Module):
    """
    Optimized implementation: stores base + diff instead of 2 full trajectories.

    Memory layout:
        base_mems[i]:  (B, C, H, W) - base trajectory membrane
        delta_mems[i]: (B, C, H, W) - perturbation vector

    Perturbed trajectory is materialized as (base + delta) only when needed.
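
    Illustrative sizing (assumed shapes, not measured): one fp32 membrane
    state of shape (64, 64, 32, 32) is 64*64*32*32*4 bytes ≈ 16.8 MB.
    Steady-state storage is comparable in both layouts (a stacked pair vs
    base + delta), but the stacked layout also materializes a fresh
    (2, B, C, H, W) tensor per layer at every renormalization step, which
    this layout avoids.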
    """

    def __init__(self, in_channels=3, num_classes=100, base_channels=64,
                 num_stages=3, blocks_per_stage=2, T=4, beta=0.9):
        super().__init__()
        self.T = T
        self.num_stages = num_stages
        self.blocks_per_stage = blocks_per_stage

        self.stages = nn.ModuleList()
        self.pools = nn.ModuleList()

        in_ch = in_channels
        out_ch = base_channels
        current_size = 32

        for stage in range(num_stages):
            stage_blocks = nn.ModuleList()
            for _ in range(blocks_per_stage):
                stage_blocks.append(SpikingVGGBlock(in_ch, out_ch, beta=beta))
                in_ch = out_ch
            self.stages.append(stage_blocks)
            self.pools.append(nn.AvgPool2d(2))
            current_size //= 2
            if stage < num_stages - 1:
                out_ch = min(out_ch * 2, 512)

        self.fc = nn.Linear(in_ch * current_size * current_size, num_classes)
        self._channel_sizes = self._compute_channel_sizes(base_channels)

    def _compute_channel_sizes(self, base):
        sizes = []
        ch = base
        for stage in range(self.num_stages):
            for _ in range(self.blocks_per_stage):
                sizes.append(ch)
            if stage < self.num_stages - 1:
                ch = min(ch * 2, 512)
        return sizes

    def _init_mems(self, batch_size, device, dtype):
        """Initialize base membrane states (B, C, H, W)."""
        base_mems = []
        H, W = 32, 32
        for stage in range(self.num_stages):
            for block_idx in range(self.blocks_per_stage):
                layer_idx = stage * self.blocks_per_stage + block_idx
                ch = self._channel_sizes[layer_idx]
                base_mems.append(torch.zeros(batch_size, ch, H, W, device=device, dtype=dtype))
            H, W = H // 2, W // 2
        return base_mems

    def _init_deltas(self, base_mems, lyap_eps):
        """Initialize perturbations δ as eps-scaled Gaussian noise (matching the
        original implementation); the global norm ||δ|| is renormalized to
        exactly eps after every timestep."""
        delta_mems = []
        for base in base_mems:
            delta_mems.append(lyap_eps * torch.randn_like(base))
        return delta_mems

    def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
        B = x.size(0)
        device, dtype = x.device, x.dtype

        # Initialize base membrane states
        base_mems = self._init_mems(B, device, dtype)

        # Initialize perturbations if computing Lyapunov
        if compute_lyapunov:
            delta_mems = self._init_deltas(base_mems, lyap_eps)
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)
        else:
            delta_mems = None

        spike_sum = None

        for t in range(self.T):
            mem_idx = 0
            new_base_mems = []
            new_delta_mems = [] if compute_lyapunov else None

            # Track spikes for base and perturbed (if computing Lyapunov)
            h_base = None
            h_delta = None  # Will store (h_perturbed - h_base)
            is_first_block = True

            for stage_blocks, pool in zip(self.stages, self.pools):
                for block in stage_blocks:
                    if is_first_block:
                        # First block: input x is the same for both trajectories
                        h_conv = block.bn(block.conv(x))  # (B, C, H, W)

                        # Base trajectory
                        spk_base, mem_base_new = block.lif(h_conv, base_mems[mem_idx])
                        new_base_mems.append(mem_base_new)
                        h_base = spk_base

                        if compute_lyapunov:
                            # Perturbed trajectory: mem = base + delta
                            mem_perturbed = base_mems[mem_idx] + delta_mems[mem_idx]
                            spk_perturbed, mem_perturbed_new = block.lif(h_conv, mem_perturbed)
                            # Store delta for new membrane
                            new_delta_mems.append(mem_perturbed_new - mem_base_new)
                            # Store spike difference for propagation
                            h_delta = spk_perturbed - spk_base

                        is_first_block = False
                    else:
                        # Subsequent blocks: inputs differ
                        # Base trajectory
                        h_conv_base = block.bn(block.conv(h_base))
                        spk_base, mem_base_new = block.lif(h_conv_base, base_mems[mem_idx])
                        new_base_mems.append(mem_base_new)
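
                        # LIF spiking is nonlinear, so the perturbed branch
                        # cannot be propagated as a linearized delta through
                        # conv/BN/LIF; it is rebuilt locally as (base + delta).
                        # Diff-only storage halves the state bookkeeping, not
                        # the compute.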
                        if compute_lyapunov:
                            # Perturbed trajectory: h_perturbed = h_base + h_delta
                            h_perturbed = h_base + h_delta
                            h_conv_perturbed = block.bn(block.conv(h_perturbed))
                            mem_perturbed = base_mems[mem_idx] + delta_mems[mem_idx]
                            spk_perturbed, mem_perturbed_new = block.lif(h_conv_perturbed, mem_perturbed)
                            new_delta_mems.append(mem_perturbed_new - mem_base_new)
                            h_delta = spk_perturbed - spk_base

                        h_base = spk_base

                    mem_idx += 1

                # Pooling
                h_base = pool(h_base)
                if compute_lyapunov:
                    # AvgPool2d is linear, so pool(base + delta) == pool(base) + pool(delta);
                    # the delta can therefore be pooled directly, without
                    # materializing the perturbed activations.
                    h_delta = pool(h_delta)

            # Update membrane states
            base_mems = new_base_mems

            # Accumulate spikes from base trajectory
            h_flat = h_base.view(B, -1)
            if spike_sum is None:
                spike_sum = h_flat
            else:
                spike_sum = spike_sum + h_flat

            # Lyapunov: compute global divergence and renormalize
            if compute_lyapunov:
                # Global norm of all deltas: ||δ||² = Σ_layers ||δ_layer||²
                delta_sq = torch.zeros(B, device=device, dtype=dtype)
                for delta in new_delta_mems:
                    delta_sq = delta_sq + (delta ** 2).sum(dim=(1, 2, 3))

                delta_norm = torch.sqrt(delta_sq + 1e-12)
                lyap_accum = lyap_accum + torch.log(delta_norm / lyap_eps + 1e-12)

                # Renormalize: scale all deltas so ||δ||_global = eps
                scale = (lyap_eps / delta_norm).view(B, 1, 1, 1)
                delta_mems = [delta * scale for delta in new_delta_mems]

        logits = self.fc(spike_sum)
        lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None

        return logits, lyap_est


def count_parameters(model):
    return sum(p.numel() for p in model.parameters())


def _sync(device):
    """Synchronize CUDA before/after timing; no-op on CPU."""
    if device.type == "cuda":
        torch.cuda.synchronize()


def benchmark_forward(model, x, compute_lyapunov, num_warmup=5, num_runs=20):
    """Benchmark forward pass time."""
    device = x.device

    # Warmup
    for _ in range(num_warmup):
        with torch.no_grad():
            _ = model(x, compute_lyapunov=compute_lyapunov)

    _sync(device)

    # Timed runs (under no_grad, matching the warmup and the reported label)
    times = []
    for _ in range(num_runs):
        _sync(device)
        start = time.perf_counter()

        with torch.no_grad():
            logits, lyap = model(x, compute_lyapunov=compute_lyapunov)

        _sync(device)
        end = time.perf_counter()
        times.append(end - start)

    return times, lyap


def benchmark_forward_backward(model, x, y, criterion, compute_lyapunov,
                               lambda_reg=0.3, num_warmup=5, num_runs=20):
    """Benchmark forward + backward pass time."""
    device = x.device

    # Warmup
    for _ in range(num_warmup):
        model.zero_grad()
        logits, lyap = model(x, compute_lyapunov=compute_lyapunov)
        loss = criterion(logits, y)
        if compute_lyapunov and lyap is not None:
            loss = loss + lambda_reg * (lyap ** 2)
        loss.backward()

    _sync(device)

    # Timed runs
    times = []
    for _ in range(num_runs):
        model.zero_grad()
        _sync(device)
        start = time.perf_counter()

        logits, lyap = model(x, compute_lyapunov=compute_lyapunov)
        loss = criterion(logits, y)
        if compute_lyapunov and lyap is not None:
            loss = loss + lambda_reg * (lyap ** 2)
        loss.backward()

        _sync(device)
        end = time.perf_counter()
        times.append(end - start)

    return times
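
# Note: with compute_lyapunov=True the regularizer lambda_reg * lyap**2
# backpropagates through all T timesteps of both trajectories via the LIF
# surrogate gradients, so benchmark_forward_backward times this extended
# graph as well as the storage layout under test.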

def measure_memory(model, x, compute_lyapunov):
    """Measure peak GPU memory during forward pass."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

    with torch.no_grad():
        _ = model(x, compute_lyapunov=compute_lyapunov)

    torch.cuda.synchronize()
    peak_mem = torch.cuda.max_memory_allocated() / 1024**2  # MB
    return peak_mem


def run_benchmark():
    print("=" * 70)
    print("LYAPUNOV COMPUTATION BENCHMARK: Original vs Diff-Only Storage")
    print("=" * 70)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    if device.type == "cuda":
        print(f"GPU: {torch.cuda.get_device_name()}")

    # Test configurations
    configs = [
        {"depth": 4, "blocks_per_stage": 1, "batch_size": 64},
        {"depth": 8, "blocks_per_stage": 2, "batch_size": 64},
        {"depth": 12, "blocks_per_stage": 3, "batch_size": 32},
    ]

    print("\n" + "=" * 70)

    for cfg in configs:
        depth = cfg["depth"]
        blocks = cfg["blocks_per_stage"]
        batch_size = cfg["batch_size"]

        print(f"\n{'='*70}")
        print(f"DEPTH = {depth} ({blocks} blocks/stage), Batch = {batch_size}")
        print(f"{'='*70}")

        # Create models
        model_orig = SpikingVGG_Original(
            blocks_per_stage=blocks, T=4
        ).to(device)

        model_diff = SpikingVGG_DiffOnly(
            blocks_per_stage=blocks, T=4
        ).to(device)

        # Copy weights from original to diff-only
        model_diff.load_state_dict(model_orig.state_dict())

        print(f"Parameters: {count_parameters(model_orig):,}")

        # Create input
        x = torch.randn(batch_size, 3, 32, 32, device=device)
        y = torch.randint(0, 100, (batch_size,), device=device)
        criterion = nn.CrossEntropyLoss()

        # ============================================================
        # Test 1: Verify outputs match
        # ============================================================
        print("\n--- Output Verification ---")
        model_orig.eval()
        model_diff.eval()

        torch.manual_seed(42)
        with torch.no_grad():
            logits_orig, lyap_orig = model_orig(x, compute_lyapunov=True, lyap_eps=1e-4)

        torch.manual_seed(42)
        with torch.no_grad():
            logits_diff, lyap_diff = model_diff(x, compute_lyapunov=True, lyap_eps=1e-4)

        logits_match = torch.allclose(logits_orig, logits_diff, rtol=1e-4, atol=1e-5)
        # The implementations are algebraically equivalent; the tolerance only
        # absorbs floating-point ordering differences.
        lyap_close = abs(lyap_orig.item() - lyap_diff.item()) < 0.1

        print(f"Logits match: {logits_match}")
        print(f"Lyapunov - Original: {lyap_orig.item():.4f}, Diff-only: {lyap_diff.item():.4f}")
        print(f"Lyapunov close (within 0.1): {lyap_close}")

        # ============================================================
        # Test 2: Forward-only speed (no grad)
        # ============================================================
        print("\n--- Forward Speed (no_grad) ---")
        model_orig.eval()
        model_diff.eval()

        # Without Lyapunov
        times_orig_noly, _ = benchmark_forward(model_orig, x, compute_lyapunov=False)
        times_diff_noly, _ = benchmark_forward(model_diff, x, compute_lyapunov=False)

        mean_orig = sum(times_orig_noly) / len(times_orig_noly) * 1000
        mean_diff = sum(times_diff_noly) / len(times_diff_noly) * 1000

        print(f"  Without Lyapunov:")
        print(f"    Original:  {mean_orig:.2f} ms")
        print(f"    Diff-only: {mean_diff:.2f} ms")

        # With Lyapunov
        times_orig_ly, _ = benchmark_forward(model_orig, x, compute_lyapunov=True)
        times_diff_ly, _ = benchmark_forward(model_diff, x, compute_lyapunov=True)

        mean_orig_ly = sum(times_orig_ly) / len(times_orig_ly) * 1000
        mean_diff_ly = sum(times_diff_ly) / len(times_diff_ly) * 1000
        speedup = mean_orig_ly / mean_diff_ly

        print(f"  With Lyapunov:")
        print(f"    Original:  {mean_orig_ly:.2f} ms")
        print(f"    Diff-only: {mean_diff_ly:.2f} ms")
        print(f"    Speedup: {speedup:.2f}x")

        # ============================================================
        # Test 3: Forward + Backward speed (training mode)
        # ============================================================
        print("\n--- Forward+Backward Speed (training) ---")
        model_orig.train()
        model_diff.train()

        times_orig_train = benchmark_forward_backward(
            model_orig, x, y, criterion, compute_lyapunov=True
        )
        times_diff_train = benchmark_forward_backward(
            model_diff, x, y, criterion, compute_lyapunov=True
        )

        mean_orig_train = sum(times_orig_train) / len(times_orig_train) * 1000
        mean_diff_train = sum(times_diff_train) / len(times_diff_train) * 1000
        speedup_train = mean_orig_train / mean_diff_train

        print(f"  With Lyapunov + backward:")
        print(f"    Original:  {mean_orig_train:.2f} ms")
        print(f"    Diff-only: {mean_diff_train:.2f} ms")
        print(f"    Speedup: {speedup_train:.2f}x")

        # ============================================================
        # Test 4: Memory usage
        # ============================================================
        if device.type == "cuda":
            print("\n--- Peak GPU Memory ---")

            mem_orig_noly = measure_memory(model_orig, x, compute_lyapunov=False)
            mem_diff_noly = measure_memory(model_diff, x, compute_lyapunov=False)

            mem_orig_ly = measure_memory(model_orig, x, compute_lyapunov=True)
            mem_diff_ly = measure_memory(model_diff, x, compute_lyapunov=True)

            print(f"  Without Lyapunov:")
            print(f"    Original:  {mem_orig_noly:.1f} MB")
            print(f"    Diff-only: {mem_diff_noly:.1f} MB")
            print(f"  With Lyapunov:")
            print(f"    Original:  {mem_orig_ly:.1f} MB")
            print(f"    Diff-only: {mem_diff_ly:.1f} MB")
            print(f"  Memory saved: {mem_orig_ly - mem_diff_ly:.1f} MB ({100*(mem_orig_ly - mem_diff_ly)/mem_orig_ly:.1f}%)")

        # Cleanup
        del model_orig, model_diff, x, y
        if device.type == "cuda":
            torch.cuda.empty_cache()

    print("\n" + "=" * 70)
    print("BENCHMARK COMPLETE")
    print("=" * 70)


if __name__ == "__main__":
    run_benchmark()
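
Usage: assuming snntorch is installed, the script runs standalone (the sys.path shim at the top resolves imports relative to the repository root): python files/experiments/lyapunov_diffonly_benchmark.py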
