""" Depth Scaling Benchmark: Demonstrate the value of Lyapunov regularization for deep SNNs. Goal: Show that on complex tasks, shallow SNNs plateau while regulated deep SNNs improve. Key hypothesis (from literature): - Shallow SNNs saturate on complex tasks (CIFAR-100, TinyImageNet) - Deep SNNs without regularization fail to train (gradient issues) - Deep SNNs WITH Lyapunov regularization achieve higher accuracy Reference results: - Spiking VGG on CIFAR-10: 7 layers ~88%, 13 layers ~91.6% (MDPI) - SEW-ResNet-152 on ImageNet: ~69.3% top-1 (NeurIPS) - Spikformer on ImageNet: ~74.8% top-1 (arXiv) Usage: python files/experiments/depth_scaling_benchmark.py --dataset cifar100 --depths 4 8 12 16 """ import os import sys import json import time from dataclasses import dataclass, asdict from typing import Dict, List, Optional, Tuple _HERE = os.path.dirname(__file__) _ROOT = os.path.dirname(os.path.dirname(_HERE)) if _ROOT not in sys.path: sys.path.insert(0, _ROOT) import argparse import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision import datasets, transforms from tqdm.auto import tqdm import snntorch as snn from snntorch import surrogate # ============================================================================= # VGG-style Spiking Network (scalable depth) # ============================================================================= class SpikingVGGBlock(nn.Module): """Conv-BN-LIF block for VGG-style architecture.""" def __init__(self, in_ch, out_ch, beta=0.9, threshold=1.0, spike_grad=None): super().__init__() if spike_grad is None: spike_grad = surrogate.fast_sigmoid(slope=25) self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False) self.bn = nn.BatchNorm2d(out_ch) self.lif = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False) def forward(self, x, mem): h = self.bn(self.conv(x)) spk, mem = self.lif(h, mem) return spk, mem class SpikingVGG(nn.Module): """ Scalable VGG-style Spiking Neural Network. 
    Architecture follows VGG pattern:
    - Multiple conv blocks between pooling layers
    - Depth controlled by blocks_per_stage

    Args:
        in_channels: Input channels (3 for RGB)
        num_classes: Output classes
        base_channels: Starting channel count (doubled each stage)
        num_stages: Number of pooling stages (3-4 typical)
        blocks_per_stage: Conv blocks per stage (controls depth)
        T: Number of timesteps
        beta: LIF membrane decay
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 10,
        base_channels: int = 64,
        num_stages: int = 3,
        blocks_per_stage: int = 2,
        T: int = 4,
        beta: float = 0.9,
        threshold: float = 1.0,
        dropout: float = 0.25,
        stable_init: bool = False,
    ):
        super().__init__()
        self.T = T
        self.num_stages = num_stages
        self.blocks_per_stage = blocks_per_stage
        self.total_conv_layers = num_stages * blocks_per_stage
        self.stable_init = stable_init

        spike_grad = surrogate.fast_sigmoid(slope=25)

        # Build stages
        self.stages = nn.ModuleList()
        self.pools = nn.ModuleList()

        ch_in = in_channels
        ch_out = base_channels
        for stage in range(num_stages):
            stage_blocks = nn.ModuleList()
            for b in range(blocks_per_stage):
                block_in = ch_in if b == 0 else ch_out
                stage_blocks.append(
                    SpikingVGGBlock(block_in, ch_out, beta, threshold, spike_grad)
                )
            self.stages.append(stage_blocks)
            self.pools.append(nn.AvgPool2d(2))
            ch_in = ch_out
            ch_out = min(ch_out * 2, 512)  # Cap at 512

        # Calculate spatial size after pooling
        # Assuming 32x32 input: 32 -> 16 -> 8 -> 4 (for 3 stages)
        final_spatial = 32 // (2 ** num_stages)
        final_channels = min(base_channels * (2 ** (num_stages - 1)), 512)
        fc_input = final_channels * final_spatial * final_spatial

        # Classifier
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(fc_input, num_classes)

        if stable_init:
            self._init_weights_stable()
        else:
            self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _init_weights_stable(self):
        """
        Stability-aware initialization for SNNs.

        Uses smaller weight magnitudes to produce less chaotic initial dynamics.
        The key insight: the Lyapunov exponent depends on weight magnitudes.
        Smaller weights → smaller Jacobian norms → more stable dynamics.

        Strategy:
        - Scaled Kaiming init for conv layers; orthogonal init for linear layers
          (orthogonal init preserves gradient magnitude across layers)
        - Scale down by a factor of 0.5 to reduce initial chaos
        - This should produce λ closer to 0 from the start
        """
        scale_factor = 0.5  # Reduce weight magnitudes

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Kaiming init for conv, then scaled down for stability
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                with torch.no_grad():
                    m.weight.mul_(scale_factor)
            elif isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=scale_factor)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _init_mems(self, batch_size, device, dtype, P=1):
        """Initialize membrane potentials for all LIF layers.
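        Note: shapes are hard-coded for 32x32 inputs with base_channels=64
        (channels start at 64 and double per stage, capped at 512; spatial
        dims halve after each stage, mirroring the pooling in forward()).
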
Args: batch_size: Batch size B device: torch device dtype: torch dtype P: Number of trajectories (1=normal, 2=with perturbed for Lyapunov) Returns: List of membrane tensors with shape (P, B, C, H, W) """ mems = [] H, W = 32, 32 ch = 64 for stage in range(self.num_stages): for _ in range(self.blocks_per_stage): mems.append(torch.zeros(P, batch_size, ch, H, W, device=device, dtype=dtype)) H, W = H // 2, W // 2 ch = min(ch * 2, 512) return mems def forward( self, x: torch.Tensor, compute_lyapunov: bool = False, lyap_eps: float = 1e-4, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]: """ Forward pass with optimized Lyapunov computation (Approach A: trajectory batching). When compute_lyapunov=True, both original and perturbed trajectories are processed together by batching them along a new dimension P=2. This avoids redundant computation, especially for the first conv layer where inputs are identical. Args: x: (B, C, H, W) static image (will be repeated for T steps) compute_lyapunov: Whether to compute Lyapunov exponent lyap_eps: Perturbation magnitude Returns: logits, lyap_est, recordings """ B = x.size(0) device, dtype = x.device, x.dtype # P = number of trajectories: 1 for normal, 2 for Lyapunov (original + perturbed) P = 2 if compute_lyapunov else 1 # Initialize membrane potentials with shape (P, B, C, H, W) mems = self._init_mems(B, device, dtype, P=P) # Initialize perturbed trajectory if compute_lyapunov: for i in range(len(mems)): mems[i][1] = mems[i][0] + lyap_eps * torch.randn_like(mems[i][0]) lyap_accum = torch.zeros(B, device=device, dtype=dtype) spike_sum = None # Time loop - repeat static input for t in range(self.T): mem_idx = 0 new_mems = [] is_first_block = True # Process through stages for stage_idx, (stage_blocks, pool) in enumerate(zip(self.stages, self.pools)): for block in stage_blocks: if is_first_block: # First block: input x is identical for both trajectories # Compute conv+bn ONCE, then expand to (P, B, C, H, W) h_conv = block.bn(block.conv(x)) # (B, C, H, W) h = h_conv.unsqueeze(0).expand(P, -1, -1, -1, -1) # (P, B, C, H, W) zero-copy # LIF with batched membrane states # Reshape for LIF: (P, B, C, H, W) -> (P*B, C, H, W) h_flat = h.reshape(P * B, *h.shape[2:]) mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:]) spk_flat, mem_new_flat = block.lif(h_flat, mem_flat) # Reshape back: (P*B, C, H, W) -> (P, B, C, H, W) spk = spk_flat.view(P, B, *spk_flat.shape[1:]) mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:]) h = spk new_mems.append(mem_new) is_first_block = False else: # Subsequent blocks: inputs differ between trajectories # Batch both trajectories: (P, B, C, H, W) -> (P*B, C, H, W) h_flat = h.reshape(P * B, *h.shape[2:]) mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:]) # Full block forward (conv+bn+lif) h_conv = block.bn(block.conv(h_flat)) spk_flat, mem_new_flat = block.lif(h_conv, mem_flat) # Reshape back spk = spk_flat.view(P, B, *spk_flat.shape[1:]) mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:]) h = spk new_mems.append(mem_new) mem_idx += 1 # Pool: apply to batched tensor h_flat = h.reshape(P * B, *h.shape[2:]) h_pooled = pool(h_flat) h = h_pooled.view(P, B, *h_pooled.shape[1:]) mems = new_mems # Accumulate final spikes from ORIGINAL trajectory only (index 0) h_orig = h[0].view(B, -1) # (B, C*H*W) if spike_sum is None: spike_sum = h_orig else: spike_sum = spike_sum + h_orig # Lyapunov divergence and renormalization (Option 1: global delta + global renorm) # This is the textbook Benettin-style 
Lyapunov exponent estimator where # the perturbation is treated as one vector in the concatenated state space. if compute_lyapunov: # Compute GLOBAL divergence across all layers delta_sq = torch.zeros(B, device=device, dtype=dtype) for i in range(len(new_mems)): diff = new_mems[i][1] - new_mems[i][0] # (B, C, H, W) delta_sq = delta_sq + (diff ** 2).sum(dim=(1, 2, 3)) delta = torch.sqrt(delta_sq + 1e-12) lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12) # GLOBAL renormalization: same scale factor for all layers # This ensures ||perturbation||_global = eps after renorm scale = (lyap_eps / delta).view(B, 1, 1, 1) # (B, 1, 1, 1) for broadcasting for i in range(len(new_mems)): diff = new_mems[i][1] - new_mems[i][0] # Update perturbed trajectory: scale the diff to have global norm = eps mems[i] = torch.stack([ new_mems[i][0], new_mems[i][0] + diff * scale ], dim=0) # Readout out = self.dropout(spike_sum) logits = self.fc(out) lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None return logits, lyap_est, None @property def depth(self): return self.total_conv_layers # ============================================================================= # Dataset Loading # ============================================================================= def get_dataset( name: str, data_dir: str = './data', batch_size: int = 128, num_workers: int = 4, ) -> Tuple[DataLoader, DataLoader, int, Tuple[int, int, int]]: """ Get train/test loaders for various datasets. Returns: train_loader, test_loader, num_classes, input_shape """ if name == 'mnist': transform = transforms.Compose([ transforms.Resize(32), # Resize to 32x32 for consistency transforms.ToTensor(), ]) train_ds = datasets.MNIST(data_dir, train=True, download=True, transform=transform) test_ds = datasets.MNIST(data_dir, train=False, download=True, transform=transform) num_classes = 10 input_shape = (1, 32, 32) elif name == 'fashion_mnist': transform = transforms.Compose([ transforms.Resize(32), transforms.ToTensor(), ]) train_ds = datasets.FashionMNIST(data_dir, train=True, download=True, transform=transform) test_ds = datasets.FashionMNIST(data_dir, train=False, download=True, transform=transform) num_classes = 10 input_shape = (1, 32, 32) elif name == 'cifar10': transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), ]) transform_test = transforms.Compose([transforms.ToTensor()]) train_ds = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform_train) test_ds = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform_test) num_classes = 10 input_shape = (3, 32, 32) elif name == 'cifar100': transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), ]) transform_test = transforms.Compose([transforms.ToTensor()]) train_ds = datasets.CIFAR100(data_dir, train=True, download=True, transform=transform_train) test_ds = datasets.CIFAR100(data_dir, train=False, download=True, transform=transform_test) num_classes = 100 input_shape = (3, 32, 32) else: raise ValueError(f"Unknown dataset: {name}") train_loader = DataLoader( train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True ) test_loader = DataLoader( test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True ) return train_loader, test_loader, num_classes, input_shape # 
============================================================================= # Training # ============================================================================= @dataclass class TrainingMetrics: epoch: int train_loss: float train_acc: float test_loss: float test_acc: float lyapunov: Optional[float] grad_norm: float grad_max_sv: Optional[float] # Max singular value of gradients grad_min_sv: Optional[float] # Min singular value of gradients grad_condition: Optional[float] # Condition number lr: float time_sec: float def compute_gradient_svs(model): """Compute gradient singular value statistics for all weight matrices.""" max_svs = [] min_svs = [] for name, param in model.named_parameters(): if param.grad is not None and param.ndim == 2: with torch.no_grad(): G = param.grad.detach() try: sv = torch.linalg.svdvals(G) if len(sv) > 0: max_svs.append(sv[0].item()) min_svs.append(sv[-1].item()) except Exception: pass if not max_svs: return None, None, None max_sv = max(max_svs) min_sv = min(min_svs) condition = max_sv / (min_sv + 1e-12) return max_sv, min_sv, condition def compute_lyap_reg_loss(lyap_est: torch.Tensor, reg_type: str, lambda_target: float, lyap_threshold: float = 2.0) -> torch.Tensor: """ Compute Lyapunov regularization loss with different penalty types. Args: lyap_est: Estimated Lyapunov exponent (scalar tensor) reg_type: Type of regularization: - "squared": (λ - target)² - original, aggressive - "hinge": max(0, λ - threshold)² - only penalize chaos - "asymmetric": strong penalty for chaos, weak for collapse - "extreme": only penalize when λ > lyap_threshold (configurable) - "adaptive_linear": penalty scales linearly with excess over threshold - "adaptive_exp": penalty grows exponentially for severe chaos - "adaptive_sigmoid": smooth sigmoid transition around threshold lambda_target: Target value (used for squared reg_type) lyap_threshold: Threshold for adaptive/extreme reg_types (default 2.0) Returns: Regularization loss (scalar tensor) """ if reg_type == "squared": # Original: penalize any deviation from target return (lyap_est - lambda_target) ** 2 elif reg_type == "hinge": # Only penalize when λ > threshold (too chaotic) # threshold = 0 means: only penalize positive Lyapunov (chaos) threshold = 0.0 excess = torch.relu(lyap_est - threshold) return excess ** 2 elif reg_type == "asymmetric": # Strong penalty for chaos (λ > 0), weak penalty for collapse (λ < -1) # This allows the network to be stable without being dead chaos_penalty = torch.relu(lyap_est) ** 2 # Penalize λ > 0 collapse_penalty = 0.1 * torch.relu(-lyap_est - 1.0) ** 2 # Weakly penalize λ < -1 return chaos_penalty + collapse_penalty elif reg_type == "extreme": # Only penalize when λ > threshold (VERY chaotic) # This allows moderate chaos while preventing extreme instability # Threshold is now configurable via lyap_threshold argument excess = torch.relu(lyap_est - lyap_threshold) return excess ** 2 elif reg_type == "adaptive_linear": # Penalty scales linearly with how far above threshold we are # loss = excess * excess² = excess³ # This naturally makes the penalty weaker for small excesses # and much stronger for large excesses excess = torch.relu(lyap_est - lyap_threshold) return excess ** 3 # Cubic scaling: gentle near threshold, strong when chaotic elif reg_type == "adaptive_exp": # Exponential penalty for severe chaos # loss = (exp(excess) - 1) * excess² for excess > 0 # This gives very weak penalty near threshold, explosive penalty for chaos excess = torch.relu(lyap_est - lyap_threshold) # Use exp(excess) 
- 1 to get 0 when excess=0, exponential growth after exp_scale = torch.exp(excess) - 1.0 return exp_scale * excess # exp(excess) * excess - excess elif reg_type == "adaptive_sigmoid": # Smooth sigmoid transition around threshold # The "sharpness" of transition is controlled by a temperature parameter # weight(λ) = sigmoid((λ - threshold) / T) where T controls smoothness # Using T=0.5 for moderately sharp transition temperature = 0.5 weight = torch.sigmoid((lyap_est - lyap_threshold) / temperature) # Penalize deviation from target, weighted by how far past threshold deviation = lyap_est - lambda_target return weight * (deviation ** 2) # ========================================================================= # SCALED MULTIPLIER REGULARIZATION # loss = (λ_reg × g(relu(λ))) × relu(λ) # └─────────────────┘ └──────┘ # scaled multiplier penalty toward target=0 # # The multiplier itself scales with λ, making it mild when λ is small # and aggressive when λ is large. # ========================================================================= elif reg_type == "mult_linear": # Multiplier scales linearly: g(x) = x # loss = (λ_reg × relu(λ)) × relu(λ) = λ_reg × relu(λ)² # λ=0.5 → 0.25, λ=1.0 → 1.0, λ=2.0 → 4.0, λ=3.0 → 9.0 pos_lyap = torch.relu(lyap_est) return pos_lyap * pos_lyap # relu(λ)² elif reg_type == "mult_squared": # Multiplier scales quadratically: g(x) = x² # loss = (λ_reg × relu(λ)²) × relu(λ) = λ_reg × relu(λ)³ # λ=0.5 → 0.125, λ=1.0 → 1.0, λ=2.0 → 8.0, λ=3.0 → 27.0 pos_lyap = torch.relu(lyap_est) return pos_lyap * pos_lyap * pos_lyap # relu(λ)³ elif reg_type == "mult_log": # Multiplier scales logarithmically: g(x) = log(1+x) # loss = (λ_reg × log(1+relu(λ))) × relu(λ) # λ=0.5 → 0.20, λ=1.0 → 0.69, λ=2.0 → 2.20, λ=3.0 → 4.16 pos_lyap = torch.relu(lyap_est) return torch.log1p(pos_lyap) * pos_lyap # log(1+λ) × λ else: raise ValueError(f"Unknown reg_type: {reg_type}") def train_epoch( model, loader, optimizer, criterion, device, use_lyapunov, lambda_reg, lambda_target, lyap_eps, progress=True, compute_sv_every=10, reg_type="squared", current_lambda_reg=None, lyap_threshold=2.0 ): """ Train one epoch. Args: current_lambda_reg: Actual λ_reg to use (for warmup). If None, uses lambda_reg. 
reg_type: "squared", "hinge", "asymmetric", or "extreme" lyap_threshold: Threshold for extreme reg_type """ model.train() total_loss = 0.0 correct = 0 total = 0 lyap_vals = [] grad_norms = [] grad_max_svs = [] grad_min_svs = [] grad_conditions = [] # Use warmup value if provided effective_lambda_reg = current_lambda_reg if current_lambda_reg is not None else lambda_reg iterator = tqdm(loader, desc="train", leave=False) if progress else loader for batch_idx, (x, y) in enumerate(iterator): x, y = x.to(device), y.to(device) optimizer.zero_grad() logits, lyap_est, _ = model(x, compute_lyapunov=use_lyapunov, lyap_eps=lyap_eps) loss = criterion(logits, y) if use_lyapunov and lyap_est is not None: reg = compute_lyap_reg_loss(lyap_est, reg_type, lambda_target, lyap_threshold) loss = loss + effective_lambda_reg * reg lyap_vals.append(lyap_est.item()) if torch.isnan(loss): return float('nan'), 0.0, None, float('nan'), None, None, None loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) grad_norm = sum(p.grad.norm().item()**2 for p in model.parameters() if p.grad is not None)**0.5 grad_norms.append(grad_norm) # Compute gradient SVs periodically (expensive) if batch_idx % compute_sv_every == 0: max_sv, min_sv, cond = compute_gradient_svs(model) if max_sv is not None: grad_max_svs.append(max_sv) grad_min_svs.append(min_sv) grad_conditions.append(cond) optimizer.step() total_loss += loss.item() * x.size(0) correct += (logits.argmax(1) == y).sum().item() total += x.size(0) return ( total_loss / total, correct / total, np.mean(lyap_vals) if lyap_vals else None, np.mean(grad_norms), np.mean(grad_max_svs) if grad_max_svs else None, np.mean(grad_min_svs) if grad_min_svs else None, np.mean(grad_conditions) if grad_conditions else None, ) @torch.no_grad() def evaluate(model, loader, criterion, device, progress=True): model.eval() total_loss = 0.0 correct = 0 total = 0 iterator = tqdm(loader, desc="eval", leave=False) if progress else loader for x, y in iterator: x, y = x.to(device), y.to(device) logits, _, _ = model(x, compute_lyapunov=False) loss = criterion(logits, y) total_loss += loss.item() * x.size(0) correct += (logits.argmax(1) == y).sum().item() total += x.size(0) return total_loss / total, correct / total def run_single_config( dataset_name: str, depth_config: Tuple[int, int], # (num_stages, blocks_per_stage) use_lyapunov: bool, train_loader: DataLoader, test_loader: DataLoader, num_classes: int, in_channels: int, T: int, epochs: int, lr: float, lambda_reg: float, lambda_target: float, device: torch.device, seed: int, progress: bool = True, reg_type: str = "squared", warmup_epochs: int = 0, stable_init: bool = False, lyap_threshold: float = 2.0, ) -> List[TrainingMetrics]: """Run training for a single configuration.""" torch.manual_seed(seed) num_stages, blocks_per_stage = depth_config total_depth = num_stages * blocks_per_stage model = SpikingVGG( in_channels=in_channels, num_classes=num_classes, base_channels=64, num_stages=num_stages, blocks_per_stage=blocks_per_stage, T=T, stable_init=stable_init, ).to(device) num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) method = "Lyapunov" if use_lyapunov else "Vanilla" print(f" {method}: depth={total_depth}, params={num_params:,}") optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) criterion = nn.CrossEntropyLoss() history = [] best_acc = 0.0 for epoch in range(1, epochs + 1): t0 = time.time() # Warmup: 
gradually increase lambda_reg if warmup_epochs > 0 and epoch <= warmup_epochs: current_lambda_reg = lambda_reg * (epoch / warmup_epochs) else: current_lambda_reg = lambda_reg train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch( model, train_loader, optimizer, criterion, device, use_lyapunov, lambda_reg, lambda_target, 1e-4, progress, reg_type=reg_type, current_lambda_reg=current_lambda_reg, lyap_threshold=lyap_threshold ) test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress) scheduler.step() dt = time.time() - t0 best_acc = max(best_acc, test_acc) metrics = TrainingMetrics( epoch=epoch, train_loss=train_loss, train_acc=train_acc, test_loss=test_loss, test_acc=test_acc, lyapunov=lyap, grad_norm=grad_norm, grad_max_sv=grad_max_sv, grad_min_sv=grad_min_sv, grad_condition=grad_cond, lr=scheduler.get_last_lr()[0], time_sec=dt, ) history.append(metrics) if epoch % 10 == 0 or epoch == epochs: lyap_str = f"λ={lyap:.3f}" if lyap else "" sv_str = f"σ={grad_max_sv:.2e}/{grad_min_sv:.2e}" if grad_max_sv else "" print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str} {sv_str}") if np.isnan(train_loss): print(f" DIVERGED at epoch {epoch}") break print(f" Best test acc: {best_acc:.3f}") return history def run_depth_scaling_experiment( dataset_name: str, depth_configs: List[Tuple[int, int]], train_loader: DataLoader, test_loader: DataLoader, num_classes: int, in_channels: int, T: int, epochs: int, lr: float, lambda_reg: float, lambda_target: float, device: torch.device, seed: int, progress: bool, reg_type: str = "squared", warmup_epochs: int = 0, stable_init: bool = False, lyap_threshold: float = 2.0, ) -> Dict: """Run full depth scaling experiment.""" results = {"vanilla": {}, "lyapunov": {}} print(f"Regularization type: {reg_type}") print(f"Warmup epochs: {warmup_epochs}") print(f"Stable init: {stable_init}") print(f"Lyapunov threshold: {lyap_threshold}") for depth_config in depth_configs: num_stages, blocks_per_stage = depth_config total_depth = num_stages * blocks_per_stage print(f"\n{'='*60}") print(f"Depth = {total_depth} conv layers ({num_stages} stages × {blocks_per_stage} blocks)") print(f"{'='*60}") for use_lyap in [False, True]: method = "lyapunov" if use_lyap else "vanilla" history = run_single_config( dataset_name=dataset_name, depth_config=depth_config, use_lyapunov=use_lyap, train_loader=train_loader, test_loader=test_loader, num_classes=num_classes, in_channels=in_channels, T=T, epochs=epochs, lr=lr, lambda_reg=lambda_reg, lambda_target=lambda_target, device=device, seed=seed, progress=progress, reg_type=reg_type, warmup_epochs=warmup_epochs, stable_init=stable_init, lyap_threshold=lyap_threshold, ) results[method][total_depth] = history return results def print_summary(results: Dict, dataset_name: str): """Print final summary table.""" print("\n" + "=" * 100) print(f"DEPTH SCALING RESULTS: {dataset_name.upper()}") print("=" * 100) print(f"{'Depth':<8} {'Vanilla Acc':<12} {'Lyapunov Acc':<12} {'Δ Acc':<8} {'Lyap λ':<10} {'Van ∇norm':<12} {'Lyap ∇norm':<12} {'Van κ':<10}") print("-" * 100) depths = sorted(results["vanilla"].keys()) for depth in depths: van = results["vanilla"][depth][-1] lyap = results["lyapunov"][depth][-1] van_acc = van.test_acc if not np.isnan(van.train_loss) else 0.0 lyap_acc = lyap.test_acc if not np.isnan(lyap.train_loss) else 0.0 diff = lyap_acc - van_acc diff_str = f"+{diff:.3f}" if diff >= 0 else f"{diff:.3f}" van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED" lyap_str = 
f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED" lyap_val = f"{lyap.lyapunov:.3f}" if lyap.lyapunov else "N/A" van_grad = f"{van.grad_norm:.2e}" if van.grad_norm else "N/A" lyap_grad = f"{lyap.grad_norm:.2e}" if lyap.grad_norm else "N/A" van_cond = f"{van.grad_condition:.1e}" if van.grad_condition else "N/A" print(f"{depth:<8} {van_str:<12} {lyap_str:<12} {diff_str:<8} {lyap_val:<10} {van_grad:<12} {lyap_grad:<12} {van_cond:<10}") print("=" * 100) # Gradient health analysis print("\nGRADIENT HEALTH ANALYSIS:") for depth in depths: van = results["vanilla"][depth][-1] lyap = results["lyapunov"][depth][-1] van_cond = van.grad_condition if van.grad_condition else 0 lyap_cond = lyap.grad_condition if lyap.grad_condition else 0 status = "" if van_cond > 1e6: status = "⚠️ Vanilla has ill-conditioned gradients (κ > 1e6)" elif van_cond > 1e4: status = "~ Vanilla has moderately ill-conditioned gradients" if status: print(f" Depth {depth}: {status}") print("") # Analysis print("\nKEY OBSERVATIONS:") shallow = min(depths) deep = max(depths) van_shallow = results["vanilla"][shallow][-1].test_acc van_deep = results["vanilla"][deep][-1].test_acc lyap_shallow = results["lyapunov"][shallow][-1].test_acc lyap_deep = results["lyapunov"][deep][-1].test_acc van_gain = van_deep - van_shallow lyap_gain = lyap_deep - lyap_shallow print(f" Vanilla {shallow}→{deep} layers: {van_gain:+.3f} accuracy change") print(f" Lyapunov {shallow}→{deep} layers: {lyap_gain:+.3f} accuracy change") if lyap_gain > van_gain + 0.02: print(f" ✓ Lyapunov regularization enables better depth scaling!") elif lyap_gain > van_gain: print(f" ~ Lyapunov shows slight improvement in depth scaling") else: print(f" ✗ No clear benefit from Lyapunov on this dataset/depth range") def save_results(results: Dict, output_dir: str, config: Dict): os.makedirs(output_dir, exist_ok=True) serializable = {} for method, depth_results in results.items(): serializable[method] = {} for depth, history in depth_results.items(): serializable[method][str(depth)] = [asdict(m) for m in history] with open(os.path.join(output_dir, "results.json"), "w") as f: json.dump(serializable, f, indent=2) with open(os.path.join(output_dir, "config.json"), "w") as f: json.dump(config, f, indent=2) print(f"\nResults saved to {output_dir}") def parse_args(): p = argparse.ArgumentParser(description="Depth Scaling Benchmark for Lyapunov-Regularized SNNs") p.add_argument("--dataset", type=str, default="cifar100", choices=["mnist", "fashion_mnist", "cifar10", "cifar100"]) p.add_argument("--depths", type=int, nargs="+", default=[4, 8, 12, 16], help="Total conv layer depths to test") p.add_argument("--T", type=int, default=4, help="Timesteps") p.add_argument("--epochs", type=int, default=100) p.add_argument("--batch_size", type=int, default=128) p.add_argument("--lr", type=float, default=1e-3) p.add_argument("--lambda_reg", type=float, default=0.3) p.add_argument("--lambda_target", type=float, default=-0.1) p.add_argument("--data_dir", type=str, default="./data") p.add_argument("--out_dir", type=str, default="runs/depth_scaling") p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") p.add_argument("--seed", type=int, default=42) p.add_argument("--no-progress", action="store_true") p.add_argument("--reg_type", type=str, default="squared", choices=["squared", "hinge", "asymmetric", "extreme"], help="Lyapunov regularization type") p.add_argument("--warmup_epochs", type=int, default=0, help="Epochs to warmup lambda_reg (0 = no warmup)") 
p.add_argument("--stable_init", action="store_true", help="Use stability-aware weight initialization") p.add_argument("--lyap_threshold", type=float, default=2.0, help="Threshold for extreme reg_type (only penalize λ > threshold)") return p.parse_args() def main(): args = parse_args() device = torch.device(args.device) print("=" * 80) print("DEPTH SCALING BENCHMARK") print("=" * 80) print(f"Dataset: {args.dataset}") print(f"Depths: {args.depths}") print(f"Timesteps: {args.T}") print(f"Epochs: {args.epochs}") print(f"λ_reg: {args.lambda_reg}, λ_target: {args.lambda_target}") print(f"Reg type: {args.reg_type}, Warmup epochs: {args.warmup_epochs}") print(f"Device: {device}") print("=" * 80) # Load data print(f"\nLoading {args.dataset}...") train_loader, test_loader, num_classes, input_shape = get_dataset( args.dataset, args.data_dir, args.batch_size ) in_channels = input_shape[0] print(f"Classes: {num_classes}, Input: {input_shape}") print(f"Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}") # Convert depths to (num_stages, blocks_per_stage) configs # We use 4 stages (3 for smaller nets), adjust blocks_per_stage depth_configs = [] for d in args.depths: if d <= 4: depth_configs.append((d, 1)) # d stages, 1 block each elif d <= 8: depth_configs.append((4, d // 4)) # 4 stages else: depth_configs.append((4, d // 4)) # 4 stages, more blocks print(f"\nDepth configurations: {[(d, f'{s}×{b}') for d, (s, b) in zip(args.depths, depth_configs)]}") # Run experiment results = run_depth_scaling_experiment( dataset_name=args.dataset, depth_configs=depth_configs, train_loader=train_loader, test_loader=test_loader, num_classes=num_classes, in_channels=in_channels, T=args.T, epochs=args.epochs, lr=args.lr, lambda_reg=args.lambda_reg, lambda_target=args.lambda_target, device=device, seed=args.seed, progress=not args.no_progress, reg_type=args.reg_type, warmup_epochs=args.warmup_epochs, stable_init=args.stable_init, lyap_threshold=args.lyap_threshold, ) # Summary print_summary(results, args.dataset) # Save ts = time.strftime("%Y%m%d-%H%M%S") output_dir = os.path.join(args.out_dir, f"{args.dataset}_{ts}") save_results(results, output_dir, vars(args)) if __name__ == "__main__": main()