author    YurenHao0426 <blackhao0426@gmail.com>  2026-01-13 23:49:05 -0600
committer YurenHao0426 <blackhao0426@gmail.com>  2026-01-13 23:49:05 -0600
commit    cd99d6b874d9d09b3bb87b8485cc787885af71f1 (patch)
tree      59a233959932ca0e4f12f196275e07fcf443b33f /files/experiments/lyapunov_speedup_benchmark.py
init commit
Diffstat (limited to 'files/experiments/lyapunov_speedup_benchmark.py')
-rw-r--r--  files/experiments/lyapunov_speedup_benchmark.py  638
1 file changed, 638 insertions, 0 deletions
diff --git a/files/experiments/lyapunov_speedup_benchmark.py b/files/experiments/lyapunov_speedup_benchmark.py
new file mode 100644
index 0000000..117009b
--- /dev/null
+++ b/files/experiments/lyapunov_speedup_benchmark.py
@@ -0,0 +1,638 @@
+"""
+Lyapunov Computation Speedup Benchmark
+
+Tests different optimization approaches for computing Lyapunov exponents
+during SNN training. All approaches are intended to produce equivalent
+estimates (up to numerical precision and renormalization convention), with
+different performance characteristics.
+
+Approaches tested:
+- Baseline: Current sequential implementation
+- Approach A: Trajectory-as-batch (P=2), share first Linear
+- Approach B: Global-norm divergence + single-scale renorm
+- Approach C: torch.compile the time loop
+- Combined: A + B + C together
+"""
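+
+# All variants implement the same finite-difference estimator: evolve a second
+# trajectory perturbed by eps, measure the membrane-state divergence delta_t
+# at each timestep, renormalize it back to eps, and average the log growth:
+#
+#     lambda_est = (1/T) * sum_t log(||delta_t|| / eps)
+#
+# A positive lambda_est indicates locally expanding dynamics.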
+
+import os
+import sys
+import time
+from typing import List
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import snntorch as snn
+from snntorch import surrogate
+
+# Ensure we can import from project
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+ sys.path.insert(0, _ROOT)
+
+
+# =============================================================================
+# Baseline Implementation (Current)
+# =============================================================================
+
+class BaselineSNN(nn.Module):
+ """Current implementation: sequential perturbed trajectory."""
+
+ def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.hidden_dims = hidden_dims
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ # Simple feedforward for benchmarking (not full VGG)
+ self.linears = nn.ModuleList()
+ self.lifs = nn.ModuleList()
+
+ dims = [in_channels * 32 * 32] + hidden_dims # Flattened input
+ for i in range(len(hidden_dims)):
+ self.linears.append(nn.Linear(dims[i], dims[i+1]))
+ self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
+ spike_grad=spike_grad, init_hidden=False))
+
+ self.readout = nn.Linear(hidden_dims[-1], 10)
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+ x = x.view(B, -1) # Flatten
+
+ # Init membrane potentials
+ mems = [torch.zeros(B, h, device=device, dtype=dtype) for h in self.hidden_dims]
+
+ if compute_lyapunov:
+ mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+ for t in range(self.T):
+ # Original trajectory
+ h = x
+ new_mems = []
+ for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
+ h = lin(h)
+ spk, mem = lif(h, mems[i])
+ new_mems.append(mem)
+ h = spk
+ mems = new_mems
+ spike_sum = spike_sum + h
+
+ if compute_lyapunov:
+ # Perturbed trajectory (SEPARATE PASS - this is slow)
+ h_p = x
+ new_mems_p = []
+ for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
+ h_p = lin(h_p)
+ spk_p, mem_p = lif(h_p, mems_p[i])
+ new_mems_p.append(mem_p)
+ h_p = spk_p
+
+                # Divergence: sum squared per-layer diffs, then one global
+                # sqrt (i.e. the norm of the concatenated membrane state)
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(len(self.hidden_dims)):
+ diff = new_mems_p[i] - new_mems[i]
+ delta_sq += (diff ** 2).sum(dim=1)
+
+ delta = torch.sqrt(delta_sq + 1e-12)
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+                # Renormalize per layer (SLOW: one norm() per layer; note this
+                # leaves the concatenated perturbation at ~eps*sqrt(n_layers),
+                # a systematic offset vs. Approach B's global-scale renorm)
+ for i in range(len(self.hidden_dims)):
+ diff = new_mems_p[i] - new_mems[i]
+ norm = torch.norm(diff, dim=1, keepdim=True) + 1e-12
+ new_mems_p[i] = new_mems[i] + lyap_eps * diff / norm
+
+ mems_p = new_mems_p
+
+ logits = self.readout(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+# =============================================================================
+# Approach A: Trajectory-as-batch (P=2), share first Linear
+# =============================================================================
+
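+# The trick, in a minimal sketch (illustrative only, not executed by the
+# benchmark): both trajectories receive the identical input x, so Linear_1(x)
+# is the same for both and needs only one GEMM; expand() then broadcasts a
+# zero-copy view across the trajectory dimension P:
+#
+#     h1 = linear1(x)                        # (B, H1), computed once
+#     h = h1.unsqueeze(0).expand(2, -1, -1)  # (2, B, H1), no extra memory
+#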
+class ApproachA_SNN(nn.Module):
+ """Batch both trajectories together, share Linear_1."""
+
+ def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.hidden_dims = hidden_dims
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.linears = nn.ModuleList()
+ self.lifs = nn.ModuleList()
+
+ dims = [in_channels * 32 * 32] + hidden_dims
+ for i in range(len(hidden_dims)):
+ self.linears.append(nn.Linear(dims[i], dims[i+1]))
+ self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
+ spike_grad=spike_grad, init_hidden=False))
+
+ self.readout = nn.Linear(hidden_dims[-1], 10)
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+ x = x.view(B, -1)
+
+ P = 2 if compute_lyapunov else 1
+
+ # State layout: (P, B, H) where P=2 for [original, perturbed]
+ mems = [torch.zeros(P, B, h, device=device, dtype=dtype) for h in self.hidden_dims]
+
+ if compute_lyapunov:
+ # Initialize perturbed state
+ for i in range(len(self.hidden_dims)):
+ mems[i][1] = mems[i][0] + lyap_eps * torch.randn(B, self.hidden_dims[i], device=device, dtype=dtype)
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+ for t in range(self.T):
+ # Layer 1: compute Linear ONCE, expand to (P, B, H1)
+ h1 = self.linears[0](x) # (B, H1) - computed ONCE
+
+ if compute_lyapunov:
+ h = h1.unsqueeze(0).expand(P, -1, -1) # (P, B, H1) - zero-copy view
+ else:
+ h = h1.unsqueeze(0) # (1, B, H1)
+
+ # LIF layer 1
+ spk, mems[0] = self.lifs[0](h, mems[0])
+ h = spk
+
+ # Layers 2+: different inputs for each trajectory
+ for i in range(1, len(self.hidden_dims)):
+ # Reshape to (P*B, H) for batched Linear
+ h_flat = h.reshape(P * B, -1)
+ h_lin = self.linears[i](h_flat).view(P, B, self.hidden_dims[i])
+ spk, mems[i] = self.lifs[i](h_lin, mems[i])
+ h = spk
+
+ # Accumulate spikes from original trajectory only
+ spike_sum = spike_sum + h[0]
+
+ if compute_lyapunov:
+ # Global divergence across all layers
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(len(self.hidden_dims)):
+ diff = mems[i][1] - mems[i][0] # (B, H_i)
+ delta_sq = delta_sq + diff.square().sum(dim=-1)
+
+ delta = (delta_sq + 1e-12).sqrt()
+ lyap_accum = lyap_accum + (delta / lyap_eps).log()
+
+                # Renormalize per layer (same scheme as the baseline;
+                # Approach B replaces this with a single global scale)
+ for i in range(len(self.hidden_dims)):
+ diff = mems[i][1] - mems[i][0]
+ norm = torch.norm(diff, dim=1, keepdim=True) + 1e-12
+ mems[i][1] = mems[i][0] + lyap_eps * diff / norm
+
+ logits = self.readout(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+# =============================================================================
+# Approach B: Global-norm divergence + single-scale renorm
+# =============================================================================
+
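+# Why a single scale works: stack all layer diffs into one concatenated
+# vector; its norm is delta = sqrt(sum_i ||diff_i||^2), which the code already
+# computes for the divergence. Multiplying every layer's diff by (eps / delta)
+# therefore rescales the whole concatenated perturbation to norm eps in one
+# shot, with no per-layer norm() calls in the renormalization loop.
+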
+class ApproachB_SNN(nn.Module):
+ """Global norm for divergence, single scale factor for renorm."""
+
+ def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.hidden_dims = hidden_dims
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.linears = nn.ModuleList()
+ self.lifs = nn.ModuleList()
+
+ dims = [in_channels * 32 * 32] + hidden_dims
+ for i in range(len(hidden_dims)):
+ self.linears.append(nn.Linear(dims[i], dims[i+1]))
+ self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
+ spike_grad=spike_grad, init_hidden=False))
+
+ self.readout = nn.Linear(hidden_dims[-1], 10)
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+ x = x.view(B, -1)
+
+ mems = [torch.zeros(B, h, device=device, dtype=dtype) for h in self.hidden_dims]
+
+ if compute_lyapunov:
+ mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+ for t in range(self.T):
+ # Original trajectory
+ h = x
+ new_mems = []
+ for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
+ h = lin(h)
+ spk, mem = lif(h, mems[i])
+ new_mems.append(mem)
+ h = spk
+ mems = new_mems
+ spike_sum = spike_sum + h
+
+ if compute_lyapunov:
+ # Perturbed trajectory
+ h_p = x
+ new_mems_p = []
+ for i, (lin, lif) in enumerate(zip(self.linears, self.lifs)):
+ h_p = lin(h_p)
+ spk_p, mem_p = lif(h_p, mems_p[i])
+ new_mems_p.append(mem_p)
+ h_p = spk_p
+
+ # GLOBAL divergence (one delta per batch element)
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(len(self.hidden_dims)):
+ diff = new_mems_p[i] - new_mems[i]
+ delta_sq = delta_sq + diff.square().sum(dim=-1)
+
+ delta = (delta_sq + 1e-12).sqrt()
+ lyap_accum = lyap_accum + (delta / lyap_eps).log()
+
+ # SINGLE SCALE renormalization (key optimization)
+ scale = (lyap_eps / delta).unsqueeze(-1) # (B, 1)
+ for i in range(len(self.hidden_dims)):
+ diff = new_mems_p[i] - new_mems[i]
+ new_mems_p[i] = new_mems[i] + diff * scale
+
+ mems_p = new_mems_p
+
+ logits = self.readout(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+# =============================================================================
+# Approach A+B Combined: Batched trajectories + global renorm
+# =============================================================================
+
+class ApproachAB_SNN(nn.Module):
+ """Combined: trajectory-as-batch + global-norm renorm."""
+
+ def __init__(self, in_channels=3, hidden_dims=[64, 128, 256], T=4, beta=0.9):
+ super().__init__()
+ self.T = T
+ self.hidden_dims = hidden_dims
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.linears = nn.ModuleList()
+ self.lifs = nn.ModuleList()
+
+ dims = [in_channels * 32 * 32] + hidden_dims
+ for i in range(len(hidden_dims)):
+ self.linears.append(nn.Linear(dims[i], dims[i+1]))
+ self.lifs.append(snn.Leaky(beta=beta, threshold=1.0,
+ spike_grad=spike_grad, init_hidden=False))
+
+ self.readout = nn.Linear(hidden_dims[-1], 10)
+
+ def forward(self, x, compute_lyapunov=False, lyap_eps=1e-4):
+ B = x.size(0)
+ device, dtype = x.device, x.dtype
+ x = x.view(B, -1)
+
+ P = 2 if compute_lyapunov else 1
+
+ # State: (P, B, H)
+ mems = [torch.zeros(P, B, h, device=device, dtype=dtype) for h in self.hidden_dims]
+
+ if compute_lyapunov:
+ for i in range(len(self.hidden_dims)):
+ mems[i][1] = lyap_eps * torch.randn(B, self.hidden_dims[i], device=device, dtype=dtype)
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+ for t in range(self.T):
+ # Layer 1: Linear computed ONCE
+ h1 = self.linears[0](x)
+ h = h1.unsqueeze(0).expand(P, -1, -1) if compute_lyapunov else h1.unsqueeze(0)
+
+ spk, mems[0] = self.lifs[0](h, mems[0])
+ h = spk
+
+ # Layers 2+
+ for i in range(1, len(self.hidden_dims)):
+ h_flat = h.reshape(P * B, -1)
+ h_lin = self.linears[i](h_flat).view(P, B, self.hidden_dims[i])
+ spk, mems[i] = self.lifs[i](h_lin, mems[i])
+ h = spk
+
+ spike_sum = spike_sum + h[0]
+
+ if compute_lyapunov:
+ # Global divergence
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(len(self.hidden_dims)):
+ diff = mems[i][1] - mems[i][0]
+ delta_sq = delta_sq + diff.square().sum(dim=-1)
+
+ delta = (delta_sq + 1e-12).sqrt()
+ lyap_accum = lyap_accum + (delta / lyap_eps).log()
+
+ # Global scale renorm
+ scale = (lyap_eps / delta).unsqueeze(-1)
+ for i in range(len(self.hidden_dims)):
+ diff = mems[i][1] - mems[i][0]
+ mems[i][1] = mems[i][0] + diff * scale
+
+ logits = self.readout(spike_sum)
+ lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est
+
+
+# =============================================================================
+# Approach C: torch.compile wrapper
+# =============================================================================
+
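+# Note: mode="reduce-overhead" can capture CUDA graphs, so the first few calls
+# pay compilation/capture cost; this is why the compiled configurations below
+# use a longer warmup (warmup_iters=10).
+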
+def make_compiled_model(model_class, *args, **kwargs):
+ """Create a model and compile its forward pass."""
+ model = model_class(*args, **kwargs)
+ # Compile the forward method
+ model.forward = torch.compile(model.forward, mode="reduce-overhead")
+ return model
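+
+# Example usage (a sketch; run_benchmarks() below inlines the same two lines
+# rather than calling this helper):
+#
+#     model = make_compiled_model(BaselineSNN, hidden_dims=[64, 128, 256], T=4)
+#     logits, lyap = model(x, compute_lyapunov=True)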
+
+
+# =============================================================================
+# Benchmarking
+# =============================================================================
+
+@dataclass
+class BenchmarkResult:
+ name: str
+ forward_time_ms: float
+ backward_time_ms: float
+ total_time_ms: float
+ lyap_value: float
+ memory_mb: float
+
+ def __str__(self):
+ return (f"{self.name:<25} | Fwd: {self.forward_time_ms:7.2f}ms | "
+ f"Bwd: {self.backward_time_ms:7.2f}ms | "
+ f"Total: {self.total_time_ms:7.2f}ms | "
+ f"λ: {self.lyap_value:+.4f} | Mem: {self.memory_mb:.1f}MB")
+
+
+def benchmark_model(
+ model: nn.Module,
+ x: torch.Tensor,
+ y: torch.Tensor,
+ name: str,
+ warmup_iters: int = 5,
+ bench_iters: int = 20,
+) -> BenchmarkResult:
+ """Benchmark a single model configuration."""
+
+    device = x.device
+    use_cuda = device.type == "cuda"
+    criterion = nn.CrossEntropyLoss()
+
+    def _sync():
+        # CUDA kernels launch asynchronously; synchronize before reading the
+        # clock so timings reflect completed work. No-op on CPU, so the CPU
+        # fallback in __main__ still runs.
+        if use_cuda:
+            torch.cuda.synchronize()
+
+ # Warmup
+ for _ in range(warmup_iters):
+ logits, lyap = model(x, compute_lyapunov=True)
+ loss = criterion(logits, y) + 0.3 * (lyap ** 2 if lyap is not None else 0)
+ loss.backward()
+ model.zero_grad()
+
+    _sync()
+    if use_cuda:
+        torch.cuda.reset_peak_memory_stats()
+
+ fwd_times = []
+ bwd_times = []
+ lyap_vals = []
+
+ for _ in range(bench_iters):
+ # Forward
+        _sync()
+ t0 = time.perf_counter()
+
+ logits, lyap = model(x, compute_lyapunov=True)
+ loss = criterion(logits, y) + 0.3 * (lyap ** 2 if lyap is not None else 0)
+
+        _sync()
+ t1 = time.perf_counter()
+
+ # Backward
+ loss.backward()
+
+        _sync()
+ t2 = time.perf_counter()
+
+ fwd_times.append((t1 - t0) * 1000)
+ bwd_times.append((t2 - t1) * 1000)
+ if lyap is not None:
+ lyap_vals.append(lyap.item())
+
+ model.zero_grad()
+
+    peak_mem = torch.cuda.max_memory_allocated() / 1024 / 1024 if use_cuda else 0.0
+
+ return BenchmarkResult(
+ name=name,
+ forward_time_ms=sum(fwd_times) / len(fwd_times),
+ backward_time_ms=sum(bwd_times) / len(bwd_times),
+ total_time_ms=sum(fwd_times) / len(fwd_times) + sum(bwd_times) / len(bwd_times),
+ lyap_value=sum(lyap_vals) / len(lyap_vals) if lyap_vals else 0.0,
+ memory_mb=peak_mem,
+ )
+
+
+def run_benchmarks(
+ batch_size: int = 64,
+ T: int = 4,
+ hidden_dims: List[int] = [64, 128, 256],
+ device: str = "cuda",
+):
+ """Run all benchmarks and compare."""
+
+ print("=" * 80)
+ print("LYAPUNOV COMPUTATION SPEEDUP BENCHMARK")
+ print("=" * 80)
+ print(f"Batch size: {batch_size}")
+ print(f"Timesteps: {T}")
+ print(f"Hidden dims: {hidden_dims}")
+ print(f"Device: {device}")
+ print("=" * 80)
+
+ # Create dummy data
+ x = torch.randn(batch_size, 3, 32, 32, device=device)
+ y = torch.randint(0, 10, (batch_size,), device=device)
+
+ results = []
+
+ # 1. Baseline
+ print("\n[1/6] Benchmarking Baseline...")
+ model = BaselineSNN(hidden_dims=hidden_dims, T=T).to(device)
+ results.append(benchmark_model(model, x, y, "Baseline"))
+ del model
+ torch.cuda.empty_cache()
+
+ # 2. Approach A (batched trajectories)
+ print("[2/6] Benchmarking Approach A (batched)...")
+ model = ApproachA_SNN(hidden_dims=hidden_dims, T=T).to(device)
+ results.append(benchmark_model(model, x, y, "A: Batched trajectories"))
+ del model
+ torch.cuda.empty_cache()
+
+ # 3. Approach B (global renorm)
+ print("[3/6] Benchmarking Approach B (global renorm)...")
+ model = ApproachB_SNN(hidden_dims=hidden_dims, T=T).to(device)
+ results.append(benchmark_model(model, x, y, "B: Global renorm"))
+ del model
+ torch.cuda.empty_cache()
+
+ # 4. Approach A+B combined
+ print("[4/6] Benchmarking Approach A+B (combined)...")
+ model = ApproachAB_SNN(hidden_dims=hidden_dims, T=T).to(device)
+ results.append(benchmark_model(model, x, y, "A+B: Combined"))
+ del model
+ torch.cuda.empty_cache()
+
+ # 5. Approach C (torch.compile on baseline)
+ print("[5/6] Benchmarking Approach C (compiled baseline)...")
+ try:
+ model = BaselineSNN(hidden_dims=hidden_dims, T=T).to(device)
+ model.forward = torch.compile(model.forward, mode="reduce-overhead")
+ results.append(benchmark_model(model, x, y, "C: Compiled baseline", warmup_iters=10))
+ del model
+ torch.cuda.empty_cache()
+ except Exception as e:
+ print(f" torch.compile failed: {e}")
+ results.append(BenchmarkResult("C: Compiled baseline", 0, 0, 0, 0, 0))
+
+ # 6. A+B+C (all combined)
+ print("[6/6] Benchmarking A+B+C (all optimizations)...")
+ try:
+ model = ApproachAB_SNN(hidden_dims=hidden_dims, T=T).to(device)
+ model.forward = torch.compile(model.forward, mode="reduce-overhead")
+ results.append(benchmark_model(model, x, y, "A+B+C: All optimized", warmup_iters=10))
+ del model
+ torch.cuda.empty_cache()
+ except Exception as e:
+ print(f" torch.compile failed: {e}")
+ results.append(BenchmarkResult("A+B+C: All optimized", 0, 0, 0, 0, 0))
+
+ # Print results
+ print("\n" + "=" * 80)
+ print("RESULTS")
+ print("=" * 80)
+
+ baseline_time = results[0].total_time_ms
+
+ for r in results:
+ print(r)
+
+ print("\n" + "-" * 80)
+ print("SPEEDUP vs BASELINE:")
+ print("-" * 80)
+
+ for r in results[1:]:
+ if r.total_time_ms > 0:
+ speedup = baseline_time / r.total_time_ms
+ print(f" {r.name:<25}: {speedup:.2f}x")
+
+ # Verify Lyapunov values are consistent
+ print("\n" + "-" * 80)
+ print("LYAPUNOV VALUE CONSISTENCY CHECK:")
+ print("-" * 80)
+
+ base_lyap = results[0].lyap_value
+ for r in results[1:]:
+ if r.lyap_value != 0:
+ diff = abs(r.lyap_value - base_lyap)
+ status = "✓" if diff < 0.1 else "✗"
+ print(f" {r.name:<25}: λ={r.lyap_value:+.4f} (diff={diff:.4f}) {status}")
+
+ return results
+
+
+def run_scaling_test(device: str = "cuda"):
+ """Test how approaches scale with batch size and timesteps."""
+
+ print("\n" + "=" * 80)
+ print("SCALING TESTS")
+ print("=" * 80)
+
+ configs = [
+ {"batch_size": 32, "T": 4, "hidden_dims": [64, 128, 256]},
+ {"batch_size": 64, "T": 4, "hidden_dims": [64, 128, 256]},
+ {"batch_size": 128, "T": 4, "hidden_dims": [64, 128, 256]},
+ {"batch_size": 64, "T": 8, "hidden_dims": [64, 128, 256]},
+ {"batch_size": 64, "T": 16, "hidden_dims": [64, 128, 256]},
+ {"batch_size": 64, "T": 4, "hidden_dims": [128, 256, 512]}, # Larger model
+ ]
+
+ print(f"{'Config':<40} | {'Baseline':<12} | {'A+B':<12} | {'Speedup':<8}")
+ print("-" * 80)
+
+ for cfg in configs:
+ x = torch.randn(cfg["batch_size"], 3, 32, 32, device=device)
+ y = torch.randint(0, 10, (cfg["batch_size"],), device=device)
+
+        # Baseline (cfg also carries batch_size, which the model constructors
+        # do not accept, so pass only the model-relevant keys)
+        model_kwargs = {"hidden_dims": cfg["hidden_dims"], "T": cfg["T"]}
+        model_base = BaselineSNN(**model_kwargs).to(device)
+ r_base = benchmark_model(model_base, x, y, "base", warmup_iters=3, bench_iters=10)
+ del model_base
+
+ # A+B
+        model_ab = ApproachAB_SNN(**model_kwargs).to(device)
+ r_ab = benchmark_model(model_ab, x, y, "a+b", warmup_iters=3, bench_iters=10)
+ del model_ab
+
+ torch.cuda.empty_cache()
+
+ speedup = r_base.total_time_ms / r_ab.total_time_ms if r_ab.total_time_ms > 0 else 0
+
+ cfg_str = f"B={cfg['batch_size']}, T={cfg['T']}, H={cfg['hidden_dims']}"
+ print(f"{cfg_str:<40} | {r_base.total_time_ms:>10.2f}ms | {r_ab.total_time_ms:>10.2f}ms | {speedup:>6.2f}x")
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--batch_size", type=int, default=64)
+ parser.add_argument("--T", type=int, default=4)
+ parser.add_argument("--hidden_dims", type=int, nargs="+", default=[64, 128, 256])
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--scaling", action="store_true", help="Run scaling tests")
+ args = parser.parse_args()
+
+ if not torch.cuda.is_available():
+ print("CUDA not available, using CPU (results will not be representative)")
+ args.device = "cpu"
+
+ # Main benchmark
+ results = run_benchmarks(
+ batch_size=args.batch_size,
+ T=args.T,
+ hidden_dims=args.hidden_dims,
+ device=args.device,
+ )
+
+ # Scaling tests
+ if args.scaling:
+ run_scaling_test(args.device)
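+
+# Example invocations (assuming a CUDA-capable machine; flags as defined above):
+#   python files/experiments/lyapunov_speedup_benchmark.py
+#   python files/experiments/lyapunov_speedup_benchmark.py --batch_size 128 --T 8 --scaling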