summary | refs | log | tree | commit | diff
path: root/files/experiments/depth_scaling_benchmark.py
diff options
context:
space:
mode:
Diffstat (limited to 'files/experiments/depth_scaling_benchmark.py')
-rw-r--r-- files/experiments/depth_scaling_benchmark.py 1035
1 files changed, 1035 insertions, 0 deletions
diff --git a/files/experiments/depth_scaling_benchmark.py b/files/experiments/depth_scaling_benchmark.py
new file mode 100644
index 0000000..efab140
--- /dev/null
+++ b/files/experiments/depth_scaling_benchmark.py
@@ -0,0 +1,1035 @@
+"""
+Depth Scaling Benchmark: Demonstrate the value of Lyapunov regularization for deep SNNs.
+
+Goal: Show that on complex tasks, shallow SNNs plateau while regulated deep SNNs improve.
+
+Key hypothesis (from literature):
+- Shallow SNNs saturate on complex tasks (CIFAR-100, TinyImageNet)
+- Deep SNNs without regularization fail to train (gradient issues)
+- Deep SNNs WITH Lyapunov regularization achieve higher accuracy
+
+Reference results:
+- Spiking VGG on CIFAR-10: 7 layers ~88%, 13 layers ~91.6% (MDPI)
+- SEW-ResNet-152 on ImageNet: ~69.3% top-1 (NeurIPS)
+- Spikformer on ImageNet: ~74.8% top-1 (arXiv)
+
+Usage:
+ python files/experiments/depth_scaling_benchmark.py --dataset cifar100 --depths 4 8 12 16
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
# Make the repository root importable when this script is run directly.
# abspath() anchors the computation: with a relative __file__ (for example
# "files/experiments/x.py") the nested dirname() calls would otherwise collapse
# to "", which silently means "whatever the current working directory is".
_HERE = os.path.dirname(os.path.abspath(__file__))
_ROOT = os.path.dirname(os.path.dirname(_HERE))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+from tqdm.auto import tqdm
+
+import snntorch as snn
+from snntorch import surrogate
+
+
+# =============================================================================
+# VGG-style Spiking Network (scalable depth)
+# =============================================================================
+
class SpikingVGGBlock(nn.Module):
    """One VGG building block: 3x3 convolution, batch norm, then a leaky LIF layer."""

    def __init__(self, in_ch, out_ch, beta=0.9, threshold=1.0, spike_grad=None):
        super().__init__()
        # Fall back to the fast-sigmoid surrogate gradient when the caller gives none.
        surrogate_fn = surrogate.fast_sigmoid(slope=25) if spike_grad is None else spike_grad

        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.lif = snn.Leaky(
            beta=beta,
            threshold=threshold,
            spike_grad=surrogate_fn,
            init_hidden=False,
        )

    def forward(self, x, mem):
        """Run one timestep.

        Args:
            x: Input to the convolution.
            mem: Membrane state handed to the LIF layer.

        Returns:
            (spikes, updated membrane state) as produced by the LIF layer.
        """
        current = self.bn(self.conv(x))
        spikes, next_mem = self.lif(current, mem)
        return spikes, next_mem
+
+
class SpikingVGG(nn.Module):
    """
    Scalable VGG-style spiking neural network for square static images.

    The network is a stack of ``num_stages`` stages. Each stage holds
    ``blocks_per_stage`` Conv-BN-LIF blocks followed by 2x2 average pooling.
    The channel count starts at ``base_channels`` and doubles after every
    stage, capped at 512. The static input is presented for ``T`` timesteps;
    the pooled spikes of the last stage are summed over time, passed through
    dropout and classified by a single linear layer.

    Args:
        in_channels: Input channels (3 for RGB)
        num_classes: Output classes
        base_channels: Channel count of the first stage (doubled each stage)
        num_stages: Number of pooling stages (3-4 typical)
        blocks_per_stage: Conv blocks per stage (controls depth)
        T: Number of timesteps
        beta: LIF membrane decay
        threshold: LIF firing threshold
        dropout: Dropout probability applied before the classifier
        stable_init: Use the scaled-down initialization instead of the default one
        input_size: Height/width of the (square) input images
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 10,
        base_channels: int = 64,
        num_stages: int = 3,
        blocks_per_stage: int = 2,
        T: int = 4,
        beta: float = 0.9,
        threshold: float = 1.0,
        dropout: float = 0.25,
        stable_init: bool = False,
        input_size: int = 32,
    ):
        super().__init__()

        self.T = T
        self.num_stages = num_stages
        self.blocks_per_stage = blocks_per_stage
        self.total_conv_layers = num_stages * blocks_per_stage
        self.stable_init = stable_init
        # Remembered so that _init_mems() allocates states matching the layers
        # actually built here instead of assuming 64 channels / 32x32 inputs.
        self.base_channels = base_channels
        self.input_size = input_size

        spike_grad = surrogate.fast_sigmoid(slope=25)

        # Build stages
        self.stages = nn.ModuleList()
        self.pools = nn.ModuleList()

        ch_in = in_channels
        ch_out = base_channels

        for _ in range(num_stages):
            stage_blocks = nn.ModuleList()
            for b in range(blocks_per_stage):
                # Only the first block of a stage changes the channel count.
                block_in = ch_in if b == 0 else ch_out
                stage_blocks.append(
                    SpikingVGGBlock(block_in, ch_out, beta, threshold, spike_grad)
                )
            self.stages.append(stage_blocks)
            self.pools.append(nn.AvgPool2d(2))
            ch_in = ch_out
            ch_out = min(ch_out * 2, 512)  # Cap at 512

        # After the loop ch_in is the channel count produced by the last stage;
        # using it keeps the classifier consistent with the recurrence above.
        final_spatial = input_size // (2 ** num_stages)
        final_channels = ch_in
        fc_input = final_channels * final_spatial * final_spatial

        # Classifier
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(fc_input, num_classes)

        if stable_init:
            self._init_weights_stable()
        else:
            self._init_weights()

    def _init_weights(self):
        """Default init: Kaiming-normal (fan_out) convs, Xavier-uniform linear, zero bias."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _init_weights_stable(self):
        """
        Stability-aware initialization with reduced weight magnitudes.

        Strategy, as implemented:
        - Conv layers: Kaiming-normal (fan_out, ReLU gain) scaled down by 0.5
        - Linear layers: orthogonal init with gain 0.5, zero bias

        The intent stated by the experiment is that smaller initial weights
        give less chaotic initial dynamics (Lyapunov exponent closer to 0).
        """
        scale_factor = 0.5  # Reduce weight magnitudes

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                with torch.no_grad():
                    m.weight.mul_(scale_factor)
            elif isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=scale_factor)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _init_mems(self, batch_size, device, dtype, P=1):
        """Initialize membrane potentials for all LIF layers.

        Args:
            batch_size: Batch size B
            device: torch device
            dtype: torch dtype
            P: Number of trajectories (1=normal, 2=with perturbed for Lyapunov)

        Returns:
            List of zero tensors with shape (P, B, C, H, W), one per conv
            block, in forward order.
        """
        mems = []
        H = W = self.input_size
        ch = self.base_channels

        for _ in range(self.num_stages):
            for _ in range(self.blocks_per_stage):
                mems.append(torch.zeros(P, batch_size, ch, H, W, device=device, dtype=dtype))
            # Mirror the 2x2 pooling and channel doubling done in __init__.
            H, W = H // 2, W // 2
            ch = min(ch * 2, 512)

        return mems

    def forward(
        self,
        x: torch.Tensor,
        compute_lyapunov: bool = False,
        lyap_eps: float = 1e-4,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]:
        """
        Forward pass, optionally estimating the largest Lyapunov exponent.

        When compute_lyapunov=True a second, perturbed trajectory is carried
        along a leading dimension P=2 and pushed through the same layers. The
        first conv+bn is evaluated once because both trajectories see the same
        input x. After every timestep the separation of the two trajectories is
        measured over ALL membrane states jointly and rescaled back to
        ``lyap_eps`` (Benettin-style renormalization).

        Args:
            x: (B, C, H, W) static image, presented at each of the T timesteps
            compute_lyapunov: Whether to compute the Lyapunov exponent estimate
            lyap_eps: Global norm of the perturbation

        Returns:
            (logits, lyap_est, None): ``lyap_est`` is a scalar tensor (mean over
            the batch of the time-averaged log growth) or None when
            compute_lyapunov is False. The third slot is always None.

        Raises:
            ValueError: if the spatial size of x differs from ``input_size``.
        """
        if tuple(x.shape[-2:]) != (self.input_size, self.input_size):
            raise ValueError(
                f"Expected {self.input_size}x{self.input_size} inputs, "
                f"got {tuple(x.shape[-2:])}"
            )

        B = x.size(0)
        device, dtype = x.device, x.dtype

        # P = number of trajectories: 1 for normal, 2 for Lyapunov (original + perturbed)
        P = 2 if compute_lyapunov else 1

        # Membrane potentials, each with shape (P, B, C, H, W)
        mems = self._init_mems(B, device, dtype, P=P)

        if compute_lyapunov:
            # Draw one noise tensor per layer, then rescale them TOGETHER so the
            # concatenated perturbation has norm lyap_eps for every sample. The
            # first log(delta / lyap_eps) term below assumes exactly that.
            noises = [torch.randn_like(m[0]) for m in mems]
            noise_sq = torch.zeros(B, device=device, dtype=dtype)
            for noise in noises:
                noise_sq = noise_sq + (noise ** 2).sum(dim=(1, 2, 3))
            init_scale = (lyap_eps / torch.sqrt(noise_sq + 1e-12)).view(B, 1, 1, 1)
            for m, noise in zip(mems, noises):
                m[1] = m[0] + noise * init_scale
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)

        spike_sum = None

        # Time loop - the same static input is fed at every step
        for _ in range(self.T):
            mem_idx = 0
            new_mems = []
            h = None  # (P, B, C, H, W) spikes feeding the next block

            for stage_blocks, pool in zip(self.stages, self.pools):
                for block in stage_blocks:
                    # (P, B, C, H, W) -> (P*B, C, H, W) so the LIF sees a plain batch
                    mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:])

                    if h is None:
                        # First block: x is shared, so run conv+bn once and
                        # broadcast the result over the trajectory dimension.
                        h_conv = block.bn(block.conv(x))  # (B, C, H, W)
                        h_shared = h_conv.unsqueeze(0).expand(P, -1, -1, -1, -1)
                        current = h_shared.reshape(P * B, *h_shared.shape[2:])
                    else:
                        # Later blocks: trajectories differ, batch them together.
                        h_flat = h.reshape(P * B, *h.shape[2:])
                        current = block.bn(block.conv(h_flat))

                    spk_flat, mem_new_flat = block.lif(current, mem_flat)

                    # Back to (P, B, C, H, W)
                    h = spk_flat.view(P, B, *spk_flat.shape[1:])
                    new_mems.append(mem_new_flat.view(P, B, *mem_new_flat.shape[1:]))
                    mem_idx += 1

                # Pool the batched spikes of this stage
                h_pooled = pool(h.reshape(P * B, *h.shape[2:]))
                h = h_pooled.view(P, B, *h_pooled.shape[1:])

            mems = new_mems

            # Only the ORIGINAL trajectory (index 0) feeds the classifier
            h_orig = h[0].reshape(B, -1)  # (B, C*H*W)
            spike_sum = h_orig if spike_sum is None else spike_sum + h_orig

            if compute_lyapunov:
                # GLOBAL divergence: the perturbation is one vector in the
                # concatenated state space of all layers.
                diffs = [m[1] - m[0] for m in new_mems]  # each (B, C, H, W)
                delta_sq = torch.zeros(B, device=device, dtype=dtype)
                for diff in diffs:
                    delta_sq = delta_sq + (diff ** 2).sum(dim=(1, 2, 3))

                delta = torch.sqrt(delta_sq + 1e-12)
                lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)

                # GLOBAL renormalization: one scale factor per sample shared by
                # all layers, so the perturbation norm is lyap_eps again.
                scale = (lyap_eps / delta).view(B, 1, 1, 1)
                mems = [
                    torch.stack([m[0], m[0] + diff * scale], dim=0)
                    for m, diff in zip(new_mems, diffs)
                ]

        # Readout
        out = self.dropout(spike_sum)
        logits = self.fc(out)

        lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None

        return logits, lyap_est, None

    @property
    def depth(self):
        """Total number of conv layers (num_stages * blocks_per_stage)."""
        return self.total_conv_layers
+
+
+# =============================================================================
+# Dataset Loading
+# =============================================================================
+
def get_dataset(
    name: str,
    data_dir: str = './data',
    batch_size: int = 128,
    num_workers: int = 4,
) -> Tuple[DataLoader, DataLoader, int, Tuple[int, int, int]]:
    """
    Build train/test loaders for one of the supported torchvision datasets.

    Grayscale datasets (MNIST, FashionMNIST) are resized to 32x32 and use the
    same transform for both splits. CIFAR datasets use random crop + horizontal
    flip augmentation for training and a plain tensor conversion for testing.

    Returns:
        train_loader, test_loader, num_classes, input_shape

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    # name -> (torchvision dataset class name, number of classes, channels)
    catalog = {
        'mnist': ('MNIST', 10, 1),
        'fashion_mnist': ('FashionMNIST', 10, 1),
        'cifar10': ('CIFAR10', 10, 3),
        'cifar100': ('CIFAR100', 100, 3),
    }
    if name not in catalog:
        raise ValueError(f"Unknown dataset: {name}")

    class_name, num_classes, channels = catalog[name]
    dataset_cls = getattr(datasets, class_name)

    if channels == 1:
        # Resize to 32x32 for consistency with the CIFAR inputs
        train_tf = test_tf = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
        ])
    else:
        train_tf = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        test_tf = transforms.Compose([transforms.ToTensor()])

    train_ds = dataset_cls(data_dir, train=True, download=True, transform=train_tf)
    test_ds = dataset_cls(data_dir, train=False, download=True, transform=test_tf)
    input_shape = (channels, 32, 32)

    shared = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_ds, shuffle=True, **shared)
    test_loader = DataLoader(test_ds, shuffle=False, **shared)

    return train_loader, test_loader, num_classes, input_shape
+
+
+# =============================================================================
+# Training
+# =============================================================================
+
@dataclass
class TrainingMetrics:
    """Record of one training epoch, as stored in the run history and results.json."""

    epoch: int
    train_loss: float
    train_acc: float
    test_loss: float
    test_acc: float
    lyapunov: Optional[float]        # None when no Lyapunov estimate was collected
    grad_norm: float
    grad_max_sv: Optional[float]     # largest singular value seen among gradient matrices
    grad_min_sv: Optional[float]     # smallest singular value seen among gradient matrices
    grad_condition: Optional[float]  # ratio max/min of the two values above
    lr: float
    time_sec: float
+
+
def compute_gradient_svs(model):
    """Summarize the singular values of the gradients of all 2-D parameters.

    Only parameters with ``ndim == 2`` that currently hold a gradient are
    inspected; convolution kernels (4-D) and biases (1-D) are skipped.

    Args:
        model: Module whose ``.grad`` fields were populated by a backward pass.

    Returns:
        (max_sv, min_sv, condition) where ``max_sv`` is the largest leading
        singular value over all inspected gradients, ``min_sv`` the smallest
        trailing one and ``condition = max_sv / (min_sv + 1e-12)``.
        Returns (None, None, None) when no gradient matrix could be analysed.
    """
    max_svs = []
    min_svs = []

    for param in model.parameters():
        grad = param.grad
        if grad is None or param.ndim != 2:
            continue

        try:
            with torch.no_grad():
                # svdvals returns the singular values in descending order
                sv = torch.linalg.svdvals(grad.detach())
        except RuntimeError:
            # Best-effort diagnostic: skip matrices whose SVD fails (for
            # example non-convergence, reported as torch.linalg.LinAlgError,
            # a RuntimeError subclass) but let programming errors propagate.
            continue

        if sv.numel() == 0:
            continue
        max_svs.append(sv[0].item())
        min_svs.append(sv[-1].item())

    if not max_svs:
        return None, None, None

    max_sv = max(max_svs)
    min_sv = min(min_svs)
    condition = max_sv / (min_sv + 1e-12)

    return max_sv, min_sv, condition
+
+
def compute_lyap_reg_loss(lyap_est: torch.Tensor, reg_type: str, lambda_target: float,
                          lyap_threshold: float = 2.0) -> torch.Tensor:
    """
    Turn a Lyapunov exponent estimate into a regularization penalty.

    Args:
        lyap_est: Estimated Lyapunov exponent λ (scalar tensor)
        reg_type: Penalty shape, one of:
            - "squared": (λ - target)², penalizes any deviation from the target
            - "hinge": relu(λ)², penalizes only positive λ
            - "asymmetric": relu(λ)² + 0.1·relu(-λ - 1)², strong on chaos, weak on collapse
            - "extreme": relu(λ - threshold)²
            - "adaptive_linear": relu(λ - threshold)³
            - "adaptive_exp": (exp(e) - 1)·e with e = relu(λ - threshold)
            - "adaptive_sigmoid": sigmoid((λ - threshold) / 0.5)·(λ - target)²
            - "mult_linear": relu(λ)²
            - "mult_squared": relu(λ)³
            - "mult_log": log(1 + relu(λ))·relu(λ)
        lambda_target: Target value (used by "squared" and "adaptive_sigmoid")
        lyap_threshold: Threshold used by "extreme" and the "adaptive_*" types

    Returns:
        Regularization loss (scalar tensor), not yet multiplied by λ_reg

    Raises:
        ValueError: if ``reg_type`` is not one of the names above.
    """

    def over(limit):
        # Amount by which λ exceeds ``limit`` (zero when below it)
        return torch.relu(lyap_est - limit)

    def positive_part():
        return torch.relu(lyap_est)

    def squared():
        return (lyap_est - lambda_target) ** 2

    def hinge():
        # Fixed threshold of 0: only a positive exponent (chaos) is penalized
        return over(0.0) ** 2

    def asymmetric():
        # Stable-but-alive: full penalty for λ > 0, a tenth of it for λ < -1
        chaos = positive_part() ** 2
        collapse = 0.1 * torch.relu(-lyap_est - 1.0) ** 2
        return chaos + collapse

    def extreme():
        # Moderate chaos is tolerated; only λ above the threshold costs anything
        return over(lyap_threshold) ** 2

    def adaptive_linear():
        # Cubic in the excess: gentle right above the threshold, steep far above
        return over(lyap_threshold) ** 3

    def adaptive_exp():
        # (exp(e) - 1) vanishes at e = 0 and grows exponentially afterwards
        e = over(lyap_threshold)
        return (torch.exp(e) - 1.0) * e

    def adaptive_sigmoid():
        # Squared deviation from the target, gated smoothly around the
        # threshold; 0.5 is the gate temperature (moderately sharp transition)
        gate = torch.sigmoid((lyap_est - lyap_threshold) / 0.5)
        return gate * ((lyap_est - lambda_target) ** 2)

    # Scaled-multiplier family: loss = g(relu(λ)) · relu(λ), i.e. the effective
    # multiplier itself grows with λ (g = identity, square, log1p respectively).
    def mult_linear():
        p = positive_part()
        return p * p

    def mult_squared():
        p = positive_part()
        return p * p * p

    def mult_log():
        p = positive_part()
        return torch.log1p(p) * p

    penalties = {
        "squared": squared,
        "hinge": hinge,
        "asymmetric": asymmetric,
        "extreme": extreme,
        "adaptive_linear": adaptive_linear,
        "adaptive_exp": adaptive_exp,
        "adaptive_sigmoid": adaptive_sigmoid,
        "mult_linear": mult_linear,
        "mult_squared": mult_squared,
        "mult_log": mult_log,
    }
    if reg_type not in penalties:
        raise ValueError(f"Unknown reg_type: {reg_type}")
    return penalties[reg_type]()
+
+
def train_epoch(
    model, loader, optimizer, criterion, device,
    use_lyapunov, lambda_reg, lambda_target, lyap_eps,
    progress=True, compute_sv_every=10,
    reg_type="squared", current_lambda_reg=None,
    lyap_threshold=2.0
):
    """
    Train one epoch.

    Args:
        model: Called as ``model(x, compute_lyapunov=..., lyap_eps=...)`` and
            expected to return ``(logits, lyap_est, extra)``.
        loader: Iterable of ``(x, y)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Task loss ``criterion(logits, y)``.
        device: Device the batches are moved to.
        use_lyapunov: Add the Lyapunov penalty to the loss.
        lambda_reg: Weight of the penalty when ``current_lambda_reg`` is None.
        lambda_target: Target exponent forwarded to the penalty.
        lyap_eps: Perturbation magnitude forwarded to the model.
        progress: Wrap the loader in a tqdm progress bar.
        compute_sv_every: Collect gradient singular-value statistics every this
            many batches (expensive); a value <= 0 disables the collection.
        reg_type: Any penalty name accepted by ``compute_lyap_reg_loss``.
        current_lambda_reg: Actual λ_reg to use (for warmup). If None, uses lambda_reg.
        lyap_threshold: Threshold for the "extreme" and "adaptive_*" penalties.

    Returns:
        (loss, accuracy, mean_lyapunov, mean_grad_norm, mean_max_sv,
        mean_min_sv, mean_condition). Optional entries are None when nothing
        was collected. If the loss becomes non-finite the epoch is aborted
        BEFORE the backward pass and
        ``(nan, 0.0, None, nan, None, None, None)`` is returned.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    lyap_vals = []
    grad_norms = []
    grad_max_svs = []
    grad_min_svs = []
    grad_conditions = []

    # Use warmup value if provided
    effective_lambda_reg = current_lambda_reg if current_lambda_reg is not None else lambda_reg

    iterator = tqdm(loader, desc="train", leave=False) if progress else loader

    for batch_idx, (x, y) in enumerate(iterator):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        logits, lyap_est, _ = model(x, compute_lyapunov=use_lyapunov, lyap_eps=lyap_eps)

        loss = criterion(logits, y)

        if use_lyapunov and lyap_est is not None:
            reg = compute_lyap_reg_loss(lyap_est, reg_type, lambda_target, lyap_threshold)
            loss = loss + effective_lambda_reg * reg
            lyap_vals.append(lyap_est.item())

        # Treat inf like NaN: back-propagating an infinite loss yields
        # non-finite gradients, and the optimizer step would then write NaN
        # into the weights before the divergence is noticed.
        if not torch.isfinite(loss):
            return float('nan'), 0.0, None, float('nan'), None, None, None

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)

        # NOTE: measured after clipping, so the reported norm is capped at 10.
        grad_norm = sum(
            p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None
        ) ** 0.5
        grad_norms.append(grad_norm)

        # Compute gradient SVs periodically (expensive)
        if compute_sv_every > 0 and batch_idx % compute_sv_every == 0:
            max_sv, min_sv, cond = compute_gradient_svs(model)
            if max_sv is not None:
                grad_max_svs.append(max_sv)
                grad_min_svs.append(min_sv)
                grad_conditions.append(cond)

        optimizer.step()

        total_loss += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        total += x.size(0)

    return (
        total_loss / total,
        correct / total,
        np.mean(lyap_vals) if lyap_vals else None,
        np.mean(grad_norms),
        np.mean(grad_max_svs) if grad_max_svs else None,
        np.mean(grad_min_svs) if grad_min_svs else None,
        np.mean(grad_conditions) if grad_conditions else None,
    )
+
+
@torch.no_grad()
def evaluate(model, loader, criterion, device, progress=True):
    """Measure loss and accuracy of ``model`` on ``loader`` without gradients.

    The model is switched to eval mode and called with
    ``compute_lyapunov=False``. The loss is averaged per sample (each batch
    weighted by its size).

    Returns:
        (mean_loss, accuracy)
    """
    model.eval()
    batches = tqdm(loader, desc="eval", leave=False) if progress else loader

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, targets in batches:
        inputs = inputs.to(device)
        targets = targets.to(device)
        logits, _lyap, _extra = model(inputs, compute_lyapunov=False)

        n_batch = inputs.size(0)
        batch_loss = criterion(logits, targets)
        loss_sum += batch_loss.item() * n_batch
        n_correct += (logits.argmax(1) == targets).sum().item()
        n_seen += n_batch

    return loss_sum / n_seen, n_correct / n_seen
+
+
def run_single_config(
    dataset_name: str,
    depth_config: Tuple[int, int],  # (num_stages, blocks_per_stage)
    use_lyapunov: bool,
    train_loader: DataLoader,
    test_loader: DataLoader,
    num_classes: int,
    in_channels: int,
    T: int,
    epochs: int,
    lr: float,
    lambda_reg: float,
    lambda_target: float,
    device: torch.device,
    seed: int,
    progress: bool = True,
    reg_type: str = "squared",
    warmup_epochs: int = 0,
    stable_init: bool = False,
    lyap_threshold: float = 2.0,
) -> List[TrainingMetrics]:
    """Train one SpikingVGG configuration and return its per-epoch history.

    The model is trained with AdamW (weight decay 1e-4) under a cosine
    learning-rate schedule spanning ``epochs``. During the first
    ``warmup_epochs`` epochs the regularization weight ramps linearly up to
    ``lambda_reg``. Training stops early when an epoch reports a NaN loss.

    Args:
        dataset_name: Not used by this function; kept for call symmetry.
        depth_config: (num_stages, blocks_per_stage) of the model.
        use_lyapunov: Train with the Lyapunov penalty.
        seed: Seed passed to ``torch.manual_seed`` before the model is built.
        (remaining arguments are forwarded to the model / training loop)

    Returns:
        One ``TrainingMetrics`` per completed epoch; ``lr`` is the learning
        rate that was in effect while that epoch was trained.
    """
    torch.manual_seed(seed)

    num_stages, blocks_per_stage = depth_config
    total_depth = num_stages * blocks_per_stage

    model = SpikingVGG(
        in_channels=in_channels,
        num_classes=num_classes,
        base_channels=64,
        num_stages=num_stages,
        blocks_per_stage=blocks_per_stage,
        T=T,
        stable_init=stable_init,
    ).to(device)

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    method = "Lyapunov" if use_lyapunov else "Vanilla"
    print(f" {method}: depth={total_depth}, params={num_params:,}")

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.CrossEntropyLoss()

    history = []
    best_acc = 0.0

    for epoch in range(1, epochs + 1):
        t0 = time.time()

        # Warmup: gradually increase lambda_reg
        if warmup_epochs > 0 and epoch <= warmup_epochs:
            current_lambda_reg = lambda_reg * (epoch / warmup_epochs)
        else:
            current_lambda_reg = lambda_reg

        # Read the learning rate BEFORE scheduler.step(): afterwards the
        # scheduler already reports the value for the next epoch.
        epoch_lr = optimizer.param_groups[0]["lr"]

        train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch(
            model, train_loader, optimizer, criterion, device,
            use_lyapunov, lambda_reg, lambda_target, 1e-4, progress,
            reg_type=reg_type, current_lambda_reg=current_lambda_reg,
            lyap_threshold=lyap_threshold
        )

        test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress)
        scheduler.step()

        dt = time.time() - t0
        best_acc = max(best_acc, test_acc)

        metrics = TrainingMetrics(
            epoch=epoch,
            train_loss=train_loss,
            train_acc=train_acc,
            test_loss=test_loss,
            test_acc=test_acc,
            lyapunov=lyap,
            grad_norm=grad_norm,
            grad_max_sv=grad_max_sv,
            grad_min_sv=grad_min_sv,
            grad_condition=grad_cond,
            lr=epoch_lr,
            time_sec=dt,
        )
        history.append(metrics)

        if epoch % 10 == 0 or epoch == epochs:
            # "is not None" rather than truthiness: 0.0 is a legitimate value
            lyap_str = f"λ={lyap:.3f}" if lyap is not None else ""
            sv_str = f"σ={grad_max_sv:.2e}/{grad_min_sv:.2e}" if grad_max_sv is not None else ""
            print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str} {sv_str}")

        if np.isnan(train_loss):
            print(f" DIVERGED at epoch {epoch}")
            break

    print(f" Best test acc: {best_acc:.3f}")
    return history
+
+
def run_depth_scaling_experiment(
    dataset_name: str,
    depth_configs: List[Tuple[int, int]],
    train_loader: DataLoader,
    test_loader: DataLoader,
    num_classes: int,
    in_channels: int,
    T: int,
    epochs: int,
    lr: float,
    lambda_reg: float,
    lambda_target: float,
    device: torch.device,
    seed: int,
    progress: bool,
    reg_type: str = "squared",
    warmup_epochs: int = 0,
    stable_init: bool = False,
    lyap_threshold: float = 2.0,
) -> Dict:
    """Train every depth configuration twice: vanilla first, then Lyapunov-regularized.

    Returns:
        ``{"vanilla": {depth: history}, "lyapunov": {depth: history}}`` where
        ``depth`` is ``num_stages * blocks_per_stage`` and ``history`` is the
        list returned by ``run_single_config``.
    """
    results = {"vanilla": {}, "lyapunov": {}}

    print(f"Regularization type: {reg_type}")
    print(f"Warmup epochs: {warmup_epochs}")
    print(f"Stable init: {stable_init}")
    print(f"Lyapunov threshold: {lyap_threshold}")

    # Settings shared by every run of the sweep
    shared = dict(
        dataset_name=dataset_name,
        train_loader=train_loader,
        test_loader=test_loader,
        num_classes=num_classes,
        in_channels=in_channels,
        T=T,
        epochs=epochs,
        lr=lr,
        lambda_reg=lambda_reg,
        lambda_target=lambda_target,
        device=device,
        seed=seed,
        progress=progress,
        reg_type=reg_type,
        warmup_epochs=warmup_epochs,
        stable_init=stable_init,
        lyap_threshold=lyap_threshold,
    )

    for config in depth_configs:
        num_stages, blocks_per_stage = config
        total_depth = num_stages * blocks_per_stage

        print(f"\n{'='*60}")
        print(f"Depth = {total_depth} conv layers ({num_stages} stages × {blocks_per_stage} blocks)")
        print(f"{'='*60}")

        for method, with_lyapunov in (("vanilla", False), ("lyapunov", True)):
            results[method][total_depth] = run_single_config(
                depth_config=config,
                use_lyapunov=with_lyapunov,
                **shared,
            )

    return results
+
+
def print_summary(results: Dict, dataset_name: str):
    """Print the final comparison table and a short analysis.

    For every depth the LAST epoch of the vanilla and the Lyapunov run are
    compared. A run whose final training loss is NaN is shown as "FAILED" and
    counted with accuracy 0 in the difference column. Missing optional metrics
    are shown as "N/A"; a metric that is exactly 0 is printed as a number.

    Args:
        results: ``{"vanilla": {depth: history}, "lyapunov": {depth: history}}``
        dataset_name: Used in the table title only.
    """

    def fmt(value, spec):
        # Explicit None test so that a legitimate 0.0 is not reported as missing
        return format(value, spec) if value is not None else "N/A"

    print("\n" + "=" * 100)
    print(f"DEPTH SCALING RESULTS: {dataset_name.upper()}")
    print("=" * 100)
    print(f"{'Depth':<8} {'Vanilla Acc':<12} {'Lyapunov Acc':<12} {'Δ Acc':<8} {'Lyap λ':<10} {'Van ∇norm':<12} {'Lyap ∇norm':<12} {'Van κ':<10}")
    print("-" * 100)

    depths = sorted(results["vanilla"].keys())

    for depth in depths:
        van = results["vanilla"][depth][-1]
        lyap = results["lyapunov"][depth][-1]

        # A NaN training loss is the divergence marker set by the training loop
        van_failed = bool(np.isnan(van.train_loss))
        lyap_failed = bool(np.isnan(lyap.train_loss))
        van_acc = 0.0 if van_failed else van.test_acc
        lyap_acc = 0.0 if lyap_failed else lyap.test_acc

        diff = lyap_acc - van_acc
        diff_str = f"+{diff:.3f}" if diff >= 0 else f"{diff:.3f}"

        van_str = "FAILED" if van_failed else f"{van_acc:.3f}"
        lyap_str = "FAILED" if lyap_failed else f"{lyap_acc:.3f}"
        lyap_val = fmt(lyap.lyapunov, ".3f")

        van_grad = fmt(van.grad_norm, ".2e")
        lyap_grad = fmt(lyap.grad_norm, ".2e")
        van_cond = fmt(van.grad_condition, ".1e")

        print(f"{depth:<8} {van_str:<12} {lyap_str:<12} {diff_str:<8} {lyap_val:<10} {van_grad:<12} {lyap_grad:<12} {van_cond:<10}")

    print("=" * 100)

    # Gradient health analysis
    print("\nGRADIENT HEALTH ANALYSIS:")
    for depth in depths:
        van = results["vanilla"][depth][-1]
        van_cond = van.grad_condition if van.grad_condition is not None else 0

        status = ""
        if van_cond > 1e6:
            status = "⚠️ Vanilla has ill-conditioned gradients (κ > 1e6)"
        elif van_cond > 1e4:
            status = "~ Vanilla has moderately ill-conditioned gradients"

        if status:
            print(f" Depth {depth}: {status}")

    print("")

    # Nothing to compare when no depth was run; min()/max() below would raise.
    if not depths:
        return

    # Analysis
    print("\nKEY OBSERVATIONS:")
    shallow = min(depths)
    deep = max(depths)

    van_shallow = results["vanilla"][shallow][-1].test_acc
    van_deep = results["vanilla"][deep][-1].test_acc
    lyap_shallow = results["lyapunov"][shallow][-1].test_acc
    lyap_deep = results["lyapunov"][deep][-1].test_acc

    van_gain = van_deep - van_shallow
    lyap_gain = lyap_deep - lyap_shallow

    print(f" Vanilla {shallow}→{deep} layers: {van_gain:+.3f} accuracy change")
    print(f" Lyapunov {shallow}→{deep} layers: {lyap_gain:+.3f} accuracy change")

    if lyap_gain > van_gain + 0.02:
        print(f" ✓ Lyapunov regularization enables better depth scaling!")
    elif lyap_gain > van_gain:
        print(f" ~ Lyapunov shows slight improvement in depth scaling")
    else:
        print(f" ✗ No clear benefit from Lyapunov on this dataset/depth range")
+
+
def save_results(results: Dict, output_dir: str, config: Dict):
    """Write ``results.json`` and ``config.json`` into ``output_dir``.

    The directory is created if needed. Histories are converted to plain dicts
    with ``asdict`` and depth keys are stringified for JSON.
    """
    os.makedirs(output_dir, exist_ok=True)

    payload = {
        method: {
            str(depth): [asdict(entry) for entry in history]
            for depth, history in per_depth.items()
        }
        for method, per_depth in results.items()
    }

    for filename, content in (("results.json", payload), ("config.json", config)):
        with open(os.path.join(output_dir, filename), "w") as f:
            json.dump(content, f, indent=2)

    print(f"\nResults saved to {output_dir}")
+
+
def parse_args(argv=None):
    """Parse the benchmark's command-line options.

    Args:
        argv: Optional list of argument strings; when None, argparse reads
            ``sys.argv[1:]`` as before.

    Returns:
        The populated ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser(description="Depth Scaling Benchmark for Lyapunov-Regularized SNNs")

    p.add_argument("--dataset", type=str, default="cifar100",
                   choices=["mnist", "fashion_mnist", "cifar10", "cifar100"])
    p.add_argument("--depths", type=int, nargs="+", default=[4, 8, 12, 16],
                   help="Total conv layer depths to test")
    p.add_argument("--T", type=int, default=4, help="Timesteps")
    p.add_argument("--epochs", type=int, default=100)
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--lambda_reg", type=float, default=0.3)
    p.add_argument("--lambda_target", type=float, default=-0.1)
    p.add_argument("--data_dir", type=str, default="./data")
    p.add_argument("--out_dir", type=str, default="runs/depth_scaling")
    p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--no-progress", action="store_true")
    # Every penalty implemented by compute_lyap_reg_loss must be selectable here.
    p.add_argument("--reg_type", type=str, default="squared",
                   choices=[
                       "squared", "hinge", "asymmetric", "extreme",
                       "adaptive_linear", "adaptive_exp", "adaptive_sigmoid",
                       "mult_linear", "mult_squared", "mult_log",
                   ],
                   help="Lyapunov regularization type")
    p.add_argument("--warmup_epochs", type=int, default=0,
                   help="Epochs to warmup lambda_reg (0 = no warmup)")
    p.add_argument("--stable_init", action="store_true",
                   help="Use stability-aware weight initialization")
    p.add_argument("--lyap_threshold", type=float, default=2.0,
                   help="Threshold for extreme/adaptive_* reg_types (only penalize λ > threshold)")

    return p.parse_args(argv)
+
+
def _depth_to_config(depth):
    """Map a requested conv depth to a (num_stages, blocks_per_stage) pair.

    Depths up to 4 use one block per stage; larger depths use 4 stages with
    ``depth // 4`` blocks each, so only multiples of 4 are represented exactly.

    Raises:
        ValueError: if ``depth`` is not positive.
    """
    if depth < 1:
        raise ValueError(f"Depth must be a positive integer, got {depth}")
    if depth <= 4:
        return depth, 1  # depth stages, 1 block each
    return 4, depth // 4  # 4 stages


def main():
    """Entry point: parse options, run the depth sweep, print and save results."""
    args = parse_args()
    device = torch.device(args.device)

    print("=" * 80)
    print("DEPTH SCALING BENCHMARK")
    print("=" * 80)
    print(f"Dataset: {args.dataset}")
    print(f"Depths: {args.depths}")
    print(f"Timesteps: {args.T}")
    print(f"Epochs: {args.epochs}")
    print(f"λ_reg: {args.lambda_reg}, λ_target: {args.lambda_target}")
    print(f"Reg type: {args.reg_type}, Warmup epochs: {args.warmup_epochs}")
    print(f"Device: {device}")
    print("=" * 80)

    # Load data
    print(f"\nLoading {args.dataset}...")
    train_loader, test_loader, num_classes, input_shape = get_dataset(
        args.dataset, args.data_dir, args.batch_size
    )
    in_channels = input_shape[0]
    print(f"Classes: {num_classes}, Input: {input_shape}")
    print(f"Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}")

    # Convert depths to (num_stages, blocks_per_stage) configs.
    # Results are keyed by the depth actually built, so two requests that map
    # to the same architecture would overwrite each other: report the rounding
    # and train each distinct architecture only once.
    depth_configs = []
    built_depths = set()
    for d in args.depths:
        stages, blocks = _depth_to_config(d)
        actual = stages * blocks
        if actual != d:
            print(f"WARNING: depth {d} is not representable; using {actual} ({stages}×{blocks})")
        if actual in built_depths:
            print(f"WARNING: depth {actual} already scheduled; skipping duplicate request {d}")
            continue
        built_depths.add(actual)
        depth_configs.append((stages, blocks))

    print(f"\nDepth configurations: {[(s * b, f'{s}×{b}') for s, b in depth_configs]}")

    # Run experiment
    results = run_depth_scaling_experiment(
        dataset_name=args.dataset,
        depth_configs=depth_configs,
        train_loader=train_loader,
        test_loader=test_loader,
        num_classes=num_classes,
        in_channels=in_channels,
        T=args.T,
        epochs=args.epochs,
        lr=args.lr,
        lambda_reg=args.lambda_reg,
        lambda_target=args.lambda_target,
        device=device,
        seed=args.seed,
        progress=not args.no_progress,
        reg_type=args.reg_type,
        warmup_epochs=args.warmup_epochs,
        stable_init=args.stable_init,
        lyap_threshold=args.lyap_threshold,
    )

    # Summary
    print_summary(results, args.dataset)

    # Save
    ts = time.strftime("%Y%m%d-%H%M%S")
    output_dir = os.path.join(args.out_dir, f"{args.dataset}_{ts}")
    save_results(results, output_dir, vars(args))


if __name__ == "__main__":
    main()