From cd99d6b874d9d09b3bb87b8485cc787885af71f1 Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Tue, 13 Jan 2026 23:49:05 -0600
Subject: init commit

---
 files/experiments/depth_scaling_benchmark.py | 1035 ++++++++++++++++++++++++++
 1 file changed, 1035 insertions(+)
 create mode 100644 files/experiments/depth_scaling_benchmark.py

diff --git a/files/experiments/depth_scaling_benchmark.py b/files/experiments/depth_scaling_benchmark.py
new file mode 100644
index 0000000..efab140
--- /dev/null
+++ b/files/experiments/depth_scaling_benchmark.py
@@ -0,0 +1,1035 @@
+"""
+Depth Scaling Benchmark: Demonstrate the value of Lyapunov regularization for deep SNNs.
+
+Goal: Show that on complex tasks, shallow SNNs plateau while regularized deep SNNs improve.
+
+Key hypothesis (from the literature):
+- Shallow SNNs saturate on complex tasks (CIFAR-100, TinyImageNet)
+- Deep SNNs without regularization fail to train (gradient issues)
+- Deep SNNs WITH Lyapunov regularization achieve higher accuracy
+
+Reference results:
+- Spiking VGG on CIFAR-10: 7 layers ~88%, 13 layers ~91.6% (MDPI)
+- SEW-ResNet-152 on ImageNet: ~69.3% top-1 (NeurIPS)
+- Spikformer on ImageNet: ~74.8% top-1 (arXiv)
+
+Usage:
+    python files/experiments/depth_scaling_benchmark.py --dataset cifar100 --depths 4 8 12 16
+"""
+
+import os
+import sys
+import json
+import time
+from dataclasses import dataclass, asdict
+from typing import Dict, List, Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+_ROOT = os.path.dirname(os.path.dirname(_HERE))
+if _ROOT not in sys.path:
+    sys.path.insert(0, _ROOT)
+
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+from tqdm.auto import tqdm
+
+import snntorch as snn
+from snntorch import surrogate
+
+
+# =============================================================================
+# VGG-style Spiking Network (scalable depth)
+# =============================================================================
+
+class SpikingVGGBlock(nn.Module):
+    """Conv-BN-LIF block for VGG-style architecture."""
+
+    def __init__(self, in_ch, out_ch, beta=0.9, threshold=1.0, spike_grad=None):
+        super().__init__()
+        if spike_grad is None:
+            spike_grad = surrogate.fast_sigmoid(slope=25)
+
+        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
+        self.bn = nn.BatchNorm2d(out_ch)
+        self.lif = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False)
+
+    def forward(self, x, mem):
+        h = self.bn(self.conv(x))
+        spk, mem = self.lif(h, mem)
+        return spk, mem
+
+
+class SpikingVGG(nn.Module):
+    """
+    Scalable VGG-style Spiking Neural Network.
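+
+    A minimal usage sketch (illustrative; shapes assume the 32x32 inputs this
+    benchmark targets):
+
+        >>> model = SpikingVGG(in_channels=3, num_classes=10,
+        ...                    num_stages=3, blocks_per_stage=2, T=4)
+        >>> x = torch.randn(8, 3, 32, 32)
+        >>> logits, lyap, _ = model(x, compute_lyapunov=True)
+        >>> logits.shape, lyap.shape
+        (torch.Size([8, 10]), torch.Size([]))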
+ + Architecture follows VGG pattern: + - Multiple conv blocks between pooling layers + - Depth controlled by num_blocks_per_stage + + Args: + in_channels: Input channels (3 for RGB) + num_classes: Output classes + base_channels: Starting channel count (doubled each stage) + num_stages: Number of pooling stages (3-4 typical) + blocks_per_stage: Conv blocks per stage (controls depth) + T: Number of timesteps + beta: LIF membrane decay + """ + + def __init__( + self, + in_channels: int = 3, + num_classes: int = 10, + base_channels: int = 64, + num_stages: int = 3, + blocks_per_stage: int = 2, + T: int = 4, + beta: float = 0.9, + threshold: float = 1.0, + dropout: float = 0.25, + stable_init: bool = False, + ): + super().__init__() + + self.T = T + self.num_stages = num_stages + self.blocks_per_stage = blocks_per_stage + self.total_conv_layers = num_stages * blocks_per_stage + self.stable_init = stable_init + + spike_grad = surrogate.fast_sigmoid(slope=25) + + # Build stages + self.stages = nn.ModuleList() + self.pools = nn.ModuleList() + + ch_in = in_channels + ch_out = base_channels + + for stage in range(num_stages): + stage_blocks = nn.ModuleList() + for b in range(blocks_per_stage): + block_in = ch_in if b == 0 else ch_out + stage_blocks.append( + SpikingVGGBlock(block_in, ch_out, beta, threshold, spike_grad) + ) + self.stages.append(stage_blocks) + self.pools.append(nn.AvgPool2d(2)) + ch_in = ch_out + ch_out = min(ch_out * 2, 512) # Cap at 512 + + # Calculate spatial size after pooling + # Assuming 32x32 input: 32 -> 16 -> 8 -> 4 (for 3 stages) + final_spatial = 32 // (2 ** num_stages) + final_channels = min(base_channels * (2 ** (num_stages - 1)), 512) + fc_input = final_channels * final_spatial * final_spatial + + # Classifier + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(fc_input, num_classes) + + if stable_init: + self._init_weights_stable() + else: + self._init_weights() + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def _init_weights_stable(self): + """ + Stability-aware initialization for SNNs. + + Uses smaller weight magnitudes to produce less chaotic initial dynamics. + The key insight: Lyapunov exponent depends on weight magnitudes. + Smaller weights → smaller gradients → more stable dynamics. + + Strategy: + - Use orthogonal init (preserves gradient magnitude across layers) + - Scale down by factor of 0.5 to reduce initial chaos + - This should produce λ closer to 0 from the start + """ + scale_factor = 0.5 # Reduce weight magnitudes + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # Orthogonal init for conv (reshape to 2D, init, reshape back) + weight_shape = m.weight.shape + fan_out = weight_shape[0] * weight_shape[2] * weight_shape[3] + fan_in = weight_shape[1] * weight_shape[2] * weight_shape[3] + + # Use smaller gain for stability + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + with torch.no_grad(): + m.weight.mul_(scale_factor) + + elif isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight, gain=scale_factor) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def _init_mems(self, batch_size, device, dtype, P=1): + """Initialize membrane potentials for all LIF layers. 
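+
+        Note: the hard-coded 32x32 spatial size and 64-channel start below
+        assume the benchmark defaults (32x32 inputs, base_channels=64); if
+        either changes, this shape bookkeeping must change with it.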
+ + Args: + batch_size: Batch size B + device: torch device + dtype: torch dtype + P: Number of trajectories (1=normal, 2=with perturbed for Lyapunov) + + Returns: + List of membrane tensors with shape (P, B, C, H, W) + """ + mems = [] + H, W = 32, 32 + ch = 64 + + for stage in range(self.num_stages): + for _ in range(self.blocks_per_stage): + mems.append(torch.zeros(P, batch_size, ch, H, W, device=device, dtype=dtype)) + H, W = H // 2, W // 2 + ch = min(ch * 2, 512) + + return mems + + def forward( + self, + x: torch.Tensor, + compute_lyapunov: bool = False, + lyap_eps: float = 1e-4, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]: + """ + Forward pass with optimized Lyapunov computation (Approach A: trajectory batching). + + When compute_lyapunov=True, both original and perturbed trajectories are + processed together by batching them along a new dimension P=2. This avoids + redundant computation, especially for the first conv layer where inputs are identical. + + Args: + x: (B, C, H, W) static image (will be repeated for T steps) + compute_lyapunov: Whether to compute Lyapunov exponent + lyap_eps: Perturbation magnitude + + Returns: + logits, lyap_est, recordings + """ + B = x.size(0) + device, dtype = x.device, x.dtype + + # P = number of trajectories: 1 for normal, 2 for Lyapunov (original + perturbed) + P = 2 if compute_lyapunov else 1 + + # Initialize membrane potentials with shape (P, B, C, H, W) + mems = self._init_mems(B, device, dtype, P=P) + + # Initialize perturbed trajectory + if compute_lyapunov: + for i in range(len(mems)): + mems[i][1] = mems[i][0] + lyap_eps * torch.randn_like(mems[i][0]) + lyap_accum = torch.zeros(B, device=device, dtype=dtype) + + spike_sum = None + + # Time loop - repeat static input + for t in range(self.T): + mem_idx = 0 + new_mems = [] + is_first_block = True + + # Process through stages + for stage_idx, (stage_blocks, pool) in enumerate(zip(self.stages, self.pools)): + for block in stage_blocks: + if is_first_block: + # First block: input x is identical for both trajectories + # Compute conv+bn ONCE, then expand to (P, B, C, H, W) + h_conv = block.bn(block.conv(x)) # (B, C, H, W) + h = h_conv.unsqueeze(0).expand(P, -1, -1, -1, -1) # (P, B, C, H, W) zero-copy + + # LIF with batched membrane states + # Reshape for LIF: (P, B, C, H, W) -> (P*B, C, H, W) + h_flat = h.reshape(P * B, *h.shape[2:]) + mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:]) + spk_flat, mem_new_flat = block.lif(h_flat, mem_flat) + + # Reshape back: (P*B, C, H, W) -> (P, B, C, H, W) + spk = spk_flat.view(P, B, *spk_flat.shape[1:]) + mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:]) + + h = spk + new_mems.append(mem_new) + is_first_block = False + else: + # Subsequent blocks: inputs differ between trajectories + # Batch both trajectories: (P, B, C, H, W) -> (P*B, C, H, W) + h_flat = h.reshape(P * B, *h.shape[2:]) + mem_flat = mems[mem_idx].reshape(P * B, *mems[mem_idx].shape[2:]) + + # Full block forward (conv+bn+lif) + h_conv = block.bn(block.conv(h_flat)) + spk_flat, mem_new_flat = block.lif(h_conv, mem_flat) + + # Reshape back + spk = spk_flat.view(P, B, *spk_flat.shape[1:]) + mem_new = mem_new_flat.view(P, B, *mem_new_flat.shape[1:]) + + h = spk + new_mems.append(mem_new) + + mem_idx += 1 + + # Pool: apply to batched tensor + h_flat = h.reshape(P * B, *h.shape[2:]) + h_pooled = pool(h_flat) + h = h_pooled.view(P, B, *h_pooled.shape[1:]) + + mems = new_mems + + # Accumulate final spikes from ORIGINAL trajectory only (index 0) + h_orig = 
h[0].view(B, -1) # (B, C*H*W) + if spike_sum is None: + spike_sum = h_orig + else: + spike_sum = spike_sum + h_orig + + # Lyapunov divergence and renormalization (Option 1: global delta + global renorm) + # This is the textbook Benettin-style Lyapunov exponent estimator where + # the perturbation is treated as one vector in the concatenated state space. + if compute_lyapunov: + # Compute GLOBAL divergence across all layers + delta_sq = torch.zeros(B, device=device, dtype=dtype) + for i in range(len(new_mems)): + diff = new_mems[i][1] - new_mems[i][0] # (B, C, H, W) + delta_sq = delta_sq + (diff ** 2).sum(dim=(1, 2, 3)) + + delta = torch.sqrt(delta_sq + 1e-12) + lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12) + + # GLOBAL renormalization: same scale factor for all layers + # This ensures ||perturbation||_global = eps after renorm + scale = (lyap_eps / delta).view(B, 1, 1, 1) # (B, 1, 1, 1) for broadcasting + + for i in range(len(new_mems)): + diff = new_mems[i][1] - new_mems[i][0] + # Update perturbed trajectory: scale the diff to have global norm = eps + mems[i] = torch.stack([ + new_mems[i][0], + new_mems[i][0] + diff * scale + ], dim=0) + + # Readout + out = self.dropout(spike_sum) + logits = self.fc(out) + + lyap_est = (lyap_accum / self.T).mean() if compute_lyapunov else None + + return logits, lyap_est, None + + @property + def depth(self): + return self.total_conv_layers + + +# ============================================================================= +# Dataset Loading +# ============================================================================= + +def get_dataset( + name: str, + data_dir: str = './data', + batch_size: int = 128, + num_workers: int = 4, +) -> Tuple[DataLoader, DataLoader, int, Tuple[int, int, int]]: + """ + Get train/test loaders for various datasets. 
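+
+    Example (illustrative; the first call downloads the dataset into data_dir):
+
+        >>> train_loader, test_loader, n_classes, shape = get_dataset(
+        ...     'cifar10', data_dir='./data', batch_size=64)
+        >>> n_classes, shape
+        (10, (3, 32, 32))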
+ + Returns: + train_loader, test_loader, num_classes, input_shape + """ + + if name == 'mnist': + transform = transforms.Compose([ + transforms.Resize(32), # Resize to 32x32 for consistency + transforms.ToTensor(), + ]) + train_ds = datasets.MNIST(data_dir, train=True, download=True, transform=transform) + test_ds = datasets.MNIST(data_dir, train=False, download=True, transform=transform) + num_classes = 10 + input_shape = (1, 32, 32) + + elif name == 'fashion_mnist': + transform = transforms.Compose([ + transforms.Resize(32), + transforms.ToTensor(), + ]) + train_ds = datasets.FashionMNIST(data_dir, train=True, download=True, transform=transform) + test_ds = datasets.FashionMNIST(data_dir, train=False, download=True, transform=transform) + num_classes = 10 + input_shape = (1, 32, 32) + + elif name == 'cifar10': + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + ]) + transform_test = transforms.Compose([transforms.ToTensor()]) + train_ds = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform_train) + test_ds = datasets.CIFAR10(data_dir, train=False, download=True, transform=transform_test) + num_classes = 10 + input_shape = (3, 32, 32) + + elif name == 'cifar100': + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + ]) + transform_test = transforms.Compose([transforms.ToTensor()]) + train_ds = datasets.CIFAR100(data_dir, train=True, download=True, transform=transform_train) + test_ds = datasets.CIFAR100(data_dir, train=False, download=True, transform=transform_test) + num_classes = 100 + input_shape = (3, 32, 32) + + else: + raise ValueError(f"Unknown dataset: {name}") + + train_loader = DataLoader( + train_ds, batch_size=batch_size, shuffle=True, + num_workers=num_workers, pin_memory=True + ) + test_loader = DataLoader( + test_ds, batch_size=batch_size, shuffle=False, + num_workers=num_workers, pin_memory=True + ) + + return train_loader, test_loader, num_classes, input_shape + + +# ============================================================================= +# Training +# ============================================================================= + +@dataclass +class TrainingMetrics: + epoch: int + train_loss: float + train_acc: float + test_loss: float + test_acc: float + lyapunov: Optional[float] + grad_norm: float + grad_max_sv: Optional[float] # Max singular value of gradients + grad_min_sv: Optional[float] # Min singular value of gradients + grad_condition: Optional[float] # Condition number + lr: float + time_sec: float + + +def compute_gradient_svs(model): + """Compute gradient singular value statistics for all weight matrices.""" + max_svs = [] + min_svs = [] + + for name, param in model.named_parameters(): + if param.grad is not None and param.ndim == 2: + with torch.no_grad(): + G = param.grad.detach() + try: + sv = torch.linalg.svdvals(G) + if len(sv) > 0: + max_svs.append(sv[0].item()) + min_svs.append(sv[-1].item()) + except Exception: + pass + + if not max_svs: + return None, None, None + + max_sv = max(max_svs) + min_sv = min(min_svs) + condition = max_sv / (min_sv + 1e-12) + + return max_sv, min_sv, condition + + +def compute_lyap_reg_loss(lyap_est: torch.Tensor, reg_type: str, lambda_target: float, + lyap_threshold: float = 2.0) -> torch.Tensor: + """ + Compute Lyapunov regularization loss with different penalty types. 
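+
+    Worked example (illustrative numbers): with the default lyap_threshold=2.0
+    and an estimate λ = 2.5, "extreme" penalizes only the excess over the
+    threshold, while "hinge" penalizes all positive λ:
+
+        >>> lam = torch.tensor(2.5)
+        >>> compute_lyap_reg_loss(lam, "extreme", lambda_target=-0.1, lyap_threshold=2.0)
+        tensor(0.2500)
+        >>> compute_lyap_reg_loss(lam, "hinge", lambda_target=-0.1)
+        tensor(6.2500)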
+ + Args: + lyap_est: Estimated Lyapunov exponent (scalar tensor) + reg_type: Type of regularization: + - "squared": (λ - target)² - original, aggressive + - "hinge": max(0, λ - threshold)² - only penalize chaos + - "asymmetric": strong penalty for chaos, weak for collapse + - "extreme": only penalize when λ > lyap_threshold (configurable) + - "adaptive_linear": penalty scales linearly with excess over threshold + - "adaptive_exp": penalty grows exponentially for severe chaos + - "adaptive_sigmoid": smooth sigmoid transition around threshold + lambda_target: Target value (used for squared reg_type) + lyap_threshold: Threshold for adaptive/extreme reg_types (default 2.0) + + Returns: + Regularization loss (scalar tensor) + """ + if reg_type == "squared": + # Original: penalize any deviation from target + return (lyap_est - lambda_target) ** 2 + + elif reg_type == "hinge": + # Only penalize when λ > threshold (too chaotic) + # threshold = 0 means: only penalize positive Lyapunov (chaos) + threshold = 0.0 + excess = torch.relu(lyap_est - threshold) + return excess ** 2 + + elif reg_type == "asymmetric": + # Strong penalty for chaos (λ > 0), weak penalty for collapse (λ < -1) + # This allows the network to be stable without being dead + chaos_penalty = torch.relu(lyap_est) ** 2 # Penalize λ > 0 + collapse_penalty = 0.1 * torch.relu(-lyap_est - 1.0) ** 2 # Weakly penalize λ < -1 + return chaos_penalty + collapse_penalty + + elif reg_type == "extreme": + # Only penalize when λ > threshold (VERY chaotic) + # This allows moderate chaos while preventing extreme instability + # Threshold is now configurable via lyap_threshold argument + excess = torch.relu(lyap_est - lyap_threshold) + return excess ** 2 + + elif reg_type == "adaptive_linear": + # Penalty scales linearly with how far above threshold we are + # loss = excess * excess² = excess³ + # This naturally makes the penalty weaker for small excesses + # and much stronger for large excesses + excess = torch.relu(lyap_est - lyap_threshold) + return excess ** 3 # Cubic scaling: gentle near threshold, strong when chaotic + + elif reg_type == "adaptive_exp": + # Exponential penalty for severe chaos + # loss = (exp(excess) - 1) * excess² for excess > 0 + # This gives very weak penalty near threshold, explosive penalty for chaos + excess = torch.relu(lyap_est - lyap_threshold) + # Use exp(excess) - 1 to get 0 when excess=0, exponential growth after + exp_scale = torch.exp(excess) - 1.0 + return exp_scale * excess # exp(excess) * excess - excess + + elif reg_type == "adaptive_sigmoid": + # Smooth sigmoid transition around threshold + # The "sharpness" of transition is controlled by a temperature parameter + # weight(λ) = sigmoid((λ - threshold) / T) where T controls smoothness + # Using T=0.5 for moderately sharp transition + temperature = 0.5 + weight = torch.sigmoid((lyap_est - lyap_threshold) / temperature) + # Penalize deviation from target, weighted by how far past threshold + deviation = lyap_est - lambda_target + return weight * (deviation ** 2) + + # ========================================================================= + # SCALED MULTIPLIER REGULARIZATION + # loss = (λ_reg × g(relu(λ))) × relu(λ) + # └─────────────────┘ └──────┘ + # scaled multiplier penalty toward target=0 + # + # The multiplier itself scales with λ, making it mild when λ is small + # and aggressive when λ is large. 
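+    #
+    # Note: the λ_reg factor in these formulas is applied by the caller
+    # (train_epoch scales the returned penalty by its effective λ_reg);
+    # this function returns only g(relu(λ)) × relu(λ).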
+ # ========================================================================= + + elif reg_type == "mult_linear": + # Multiplier scales linearly: g(x) = x + # loss = (λ_reg × relu(λ)) × relu(λ) = λ_reg × relu(λ)² + # λ=0.5 → 0.25, λ=1.0 → 1.0, λ=2.0 → 4.0, λ=3.0 → 9.0 + pos_lyap = torch.relu(lyap_est) + return pos_lyap * pos_lyap # relu(λ)² + + elif reg_type == "mult_squared": + # Multiplier scales quadratically: g(x) = x² + # loss = (λ_reg × relu(λ)²) × relu(λ) = λ_reg × relu(λ)³ + # λ=0.5 → 0.125, λ=1.0 → 1.0, λ=2.0 → 8.0, λ=3.0 → 27.0 + pos_lyap = torch.relu(lyap_est) + return pos_lyap * pos_lyap * pos_lyap # relu(λ)³ + + elif reg_type == "mult_log": + # Multiplier scales logarithmically: g(x) = log(1+x) + # loss = (λ_reg × log(1+relu(λ))) × relu(λ) + # λ=0.5 → 0.20, λ=1.0 → 0.69, λ=2.0 → 2.20, λ=3.0 → 4.16 + pos_lyap = torch.relu(lyap_est) + return torch.log1p(pos_lyap) * pos_lyap # log(1+λ) × λ + + else: + raise ValueError(f"Unknown reg_type: {reg_type}") + + +def train_epoch( + model, loader, optimizer, criterion, device, + use_lyapunov, lambda_reg, lambda_target, lyap_eps, + progress=True, compute_sv_every=10, + reg_type="squared", current_lambda_reg=None, + lyap_threshold=2.0 +): + """ + Train one epoch. + + Args: + current_lambda_reg: Actual λ_reg to use (for warmup). If None, uses lambda_reg. + reg_type: "squared", "hinge", "asymmetric", or "extreme" + lyap_threshold: Threshold for extreme reg_type + """ + model.train() + total_loss = 0.0 + correct = 0 + total = 0 + lyap_vals = [] + grad_norms = [] + grad_max_svs = [] + grad_min_svs = [] + grad_conditions = [] + + # Use warmup value if provided + effective_lambda_reg = current_lambda_reg if current_lambda_reg is not None else lambda_reg + + iterator = tqdm(loader, desc="train", leave=False) if progress else loader + + for batch_idx, (x, y) in enumerate(iterator): + x, y = x.to(device), y.to(device) + optimizer.zero_grad() + + logits, lyap_est, _ = model(x, compute_lyapunov=use_lyapunov, lyap_eps=lyap_eps) + + loss = criterion(logits, y) + + if use_lyapunov and lyap_est is not None: + reg = compute_lyap_reg_loss(lyap_est, reg_type, lambda_target, lyap_threshold) + loss = loss + effective_lambda_reg * reg + lyap_vals.append(lyap_est.item()) + + if torch.isnan(loss): + return float('nan'), 0.0, None, float('nan'), None, None, None + + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) + + grad_norm = sum(p.grad.norm().item()**2 for p in model.parameters() if p.grad is not None)**0.5 + grad_norms.append(grad_norm) + + # Compute gradient SVs periodically (expensive) + if batch_idx % compute_sv_every == 0: + max_sv, min_sv, cond = compute_gradient_svs(model) + if max_sv is not None: + grad_max_svs.append(max_sv) + grad_min_svs.append(min_sv) + grad_conditions.append(cond) + + optimizer.step() + + total_loss += loss.item() * x.size(0) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + + return ( + total_loss / total, + correct / total, + np.mean(lyap_vals) if lyap_vals else None, + np.mean(grad_norms), + np.mean(grad_max_svs) if grad_max_svs else None, + np.mean(grad_min_svs) if grad_min_svs else None, + np.mean(grad_conditions) if grad_conditions else None, + ) + + +@torch.no_grad() +def evaluate(model, loader, criterion, device, progress=True): + model.eval() + total_loss = 0.0 + correct = 0 + total = 0 + + iterator = tqdm(loader, desc="eval", leave=False) if progress else loader + + for x, y in iterator: + x, y = x.to(device), y.to(device) + logits, _, _ = model(x, 
compute_lyapunov=False) + + loss = criterion(logits, y) + total_loss += loss.item() * x.size(0) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + + return total_loss / total, correct / total + + +def run_single_config( + dataset_name: str, + depth_config: Tuple[int, int], # (num_stages, blocks_per_stage) + use_lyapunov: bool, + train_loader: DataLoader, + test_loader: DataLoader, + num_classes: int, + in_channels: int, + T: int, + epochs: int, + lr: float, + lambda_reg: float, + lambda_target: float, + device: torch.device, + seed: int, + progress: bool = True, + reg_type: str = "squared", + warmup_epochs: int = 0, + stable_init: bool = False, + lyap_threshold: float = 2.0, +) -> List[TrainingMetrics]: + """Run training for a single configuration.""" + torch.manual_seed(seed) + + num_stages, blocks_per_stage = depth_config + total_depth = num_stages * blocks_per_stage + + model = SpikingVGG( + in_channels=in_channels, + num_classes=num_classes, + base_channels=64, + num_stages=num_stages, + blocks_per_stage=blocks_per_stage, + T=T, + stable_init=stable_init, + ).to(device) + + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + method = "Lyapunov" if use_lyapunov else "Vanilla" + print(f" {method}: depth={total_depth}, params={num_params:,}") + + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + criterion = nn.CrossEntropyLoss() + + history = [] + best_acc = 0.0 + + for epoch in range(1, epochs + 1): + t0 = time.time() + + # Warmup: gradually increase lambda_reg + if warmup_epochs > 0 and epoch <= warmup_epochs: + current_lambda_reg = lambda_reg * (epoch / warmup_epochs) + else: + current_lambda_reg = lambda_reg + + train_loss, train_acc, lyap, grad_norm, grad_max_sv, grad_min_sv, grad_cond = train_epoch( + model, train_loader, optimizer, criterion, device, + use_lyapunov, lambda_reg, lambda_target, 1e-4, progress, + reg_type=reg_type, current_lambda_reg=current_lambda_reg, + lyap_threshold=lyap_threshold + ) + + test_loss, test_acc = evaluate(model, test_loader, criterion, device, progress) + scheduler.step() + + dt = time.time() - t0 + best_acc = max(best_acc, test_acc) + + metrics = TrainingMetrics( + epoch=epoch, + train_loss=train_loss, + train_acc=train_acc, + test_loss=test_loss, + test_acc=test_acc, + lyapunov=lyap, + grad_norm=grad_norm, + grad_max_sv=grad_max_sv, + grad_min_sv=grad_min_sv, + grad_condition=grad_cond, + lr=scheduler.get_last_lr()[0], + time_sec=dt, + ) + history.append(metrics) + + if epoch % 10 == 0 or epoch == epochs: + lyap_str = f"λ={lyap:.3f}" if lyap else "" + sv_str = f"σ={grad_max_sv:.2e}/{grad_min_sv:.2e}" if grad_max_sv else "" + print(f" Epoch {epoch:3d}: train={train_acc:.3f} test={test_acc:.3f} {lyap_str} {sv_str}") + + if np.isnan(train_loss): + print(f" DIVERGED at epoch {epoch}") + break + + print(f" Best test acc: {best_acc:.3f}") + return history + + +def run_depth_scaling_experiment( + dataset_name: str, + depth_configs: List[Tuple[int, int]], + train_loader: DataLoader, + test_loader: DataLoader, + num_classes: int, + in_channels: int, + T: int, + epochs: int, + lr: float, + lambda_reg: float, + lambda_target: float, + device: torch.device, + seed: int, + progress: bool, + reg_type: str = "squared", + warmup_epochs: int = 0, + stable_init: bool = False, + lyap_threshold: float = 2.0, +) -> Dict: + """Run full depth scaling experiment.""" + + results = {"vanilla": {}, "lyapunov": {}} + + 
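+    # Both methods below reuse the same seed, loaders, and schedule, so each
+    # vanilla-vs-Lyapunov comparison is paired per depth.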
print(f"Regularization type: {reg_type}") + print(f"Warmup epochs: {warmup_epochs}") + print(f"Stable init: {stable_init}") + print(f"Lyapunov threshold: {lyap_threshold}") + + for depth_config in depth_configs: + num_stages, blocks_per_stage = depth_config + total_depth = num_stages * blocks_per_stage + + print(f"\n{'='*60}") + print(f"Depth = {total_depth} conv layers ({num_stages} stages × {blocks_per_stage} blocks)") + print(f"{'='*60}") + + for use_lyap in [False, True]: + method = "lyapunov" if use_lyap else "vanilla" + + history = run_single_config( + dataset_name=dataset_name, + depth_config=depth_config, + use_lyapunov=use_lyap, + train_loader=train_loader, + test_loader=test_loader, + num_classes=num_classes, + in_channels=in_channels, + T=T, + epochs=epochs, + lr=lr, + lambda_reg=lambda_reg, + lambda_target=lambda_target, + device=device, + seed=seed, + progress=progress, + reg_type=reg_type, + warmup_epochs=warmup_epochs, + stable_init=stable_init, + lyap_threshold=lyap_threshold, + ) + + results[method][total_depth] = history + + return results + + +def print_summary(results: Dict, dataset_name: str): + """Print final summary table.""" + print("\n" + "=" * 100) + print(f"DEPTH SCALING RESULTS: {dataset_name.upper()}") + print("=" * 100) + print(f"{'Depth':<8} {'Vanilla Acc':<12} {'Lyapunov Acc':<12} {'Δ Acc':<8} {'Lyap λ':<10} {'Van ∇norm':<12} {'Lyap ∇norm':<12} {'Van κ':<10}") + print("-" * 100) + + depths = sorted(results["vanilla"].keys()) + + for depth in depths: + van = results["vanilla"][depth][-1] + lyap = results["lyapunov"][depth][-1] + + van_acc = van.test_acc if not np.isnan(van.train_loss) else 0.0 + lyap_acc = lyap.test_acc if not np.isnan(lyap.train_loss) else 0.0 + + diff = lyap_acc - van_acc + diff_str = f"+{diff:.3f}" if diff >= 0 else f"{diff:.3f}" + + van_str = f"{van_acc:.3f}" if van_acc > 0 else "FAILED" + lyap_str = f"{lyap_acc:.3f}" if lyap_acc > 0 else "FAILED" + lyap_val = f"{lyap.lyapunov:.3f}" if lyap.lyapunov else "N/A" + + van_grad = f"{van.grad_norm:.2e}" if van.grad_norm else "N/A" + lyap_grad = f"{lyap.grad_norm:.2e}" if lyap.grad_norm else "N/A" + van_cond = f"{van.grad_condition:.1e}" if van.grad_condition else "N/A" + + print(f"{depth:<8} {van_str:<12} {lyap_str:<12} {diff_str:<8} {lyap_val:<10} {van_grad:<12} {lyap_grad:<12} {van_cond:<10}") + + print("=" * 100) + + # Gradient health analysis + print("\nGRADIENT HEALTH ANALYSIS:") + for depth in depths: + van = results["vanilla"][depth][-1] + lyap = results["lyapunov"][depth][-1] + + van_cond = van.grad_condition if van.grad_condition else 0 + lyap_cond = lyap.grad_condition if lyap.grad_condition else 0 + + status = "" + if van_cond > 1e6: + status = "⚠️ Vanilla has ill-conditioned gradients (κ > 1e6)" + elif van_cond > 1e4: + status = "~ Vanilla has moderately ill-conditioned gradients" + + if status: + print(f" Depth {depth}: {status}") + + print("") + + # Analysis + print("\nKEY OBSERVATIONS:") + shallow = min(depths) + deep = max(depths) + + van_shallow = results["vanilla"][shallow][-1].test_acc + van_deep = results["vanilla"][deep][-1].test_acc + lyap_shallow = results["lyapunov"][shallow][-1].test_acc + lyap_deep = results["lyapunov"][deep][-1].test_acc + + van_gain = van_deep - van_shallow + lyap_gain = lyap_deep - lyap_shallow + + print(f" Vanilla {shallow}→{deep} layers: {van_gain:+.3f} accuracy change") + print(f" Lyapunov {shallow}→{deep} layers: {lyap_gain:+.3f} accuracy change") + + if lyap_gain > van_gain + 0.02: + print(f" ✓ Lyapunov regularization enables better depth 
scaling!") + elif lyap_gain > van_gain: + print(f" ~ Lyapunov shows slight improvement in depth scaling") + else: + print(f" ✗ No clear benefit from Lyapunov on this dataset/depth range") + + +def save_results(results: Dict, output_dir: str, config: Dict): + os.makedirs(output_dir, exist_ok=True) + + serializable = {} + for method, depth_results in results.items(): + serializable[method] = {} + for depth, history in depth_results.items(): + serializable[method][str(depth)] = [asdict(m) for m in history] + + with open(os.path.join(output_dir, "results.json"), "w") as f: + json.dump(serializable, f, indent=2) + + with open(os.path.join(output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + print(f"\nResults saved to {output_dir}") + + +def parse_args(): + p = argparse.ArgumentParser(description="Depth Scaling Benchmark for Lyapunov-Regularized SNNs") + + p.add_argument("--dataset", type=str, default="cifar100", + choices=["mnist", "fashion_mnist", "cifar10", "cifar100"]) + p.add_argument("--depths", type=int, nargs="+", default=[4, 8, 12, 16], + help="Total conv layer depths to test") + p.add_argument("--T", type=int, default=4, help="Timesteps") + p.add_argument("--epochs", type=int, default=100) + p.add_argument("--batch_size", type=int, default=128) + p.add_argument("--lr", type=float, default=1e-3) + p.add_argument("--lambda_reg", type=float, default=0.3) + p.add_argument("--lambda_target", type=float, default=-0.1) + p.add_argument("--data_dir", type=str, default="./data") + p.add_argument("--out_dir", type=str, default="runs/depth_scaling") + p.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + p.add_argument("--seed", type=int, default=42) + p.add_argument("--no-progress", action="store_true") + p.add_argument("--reg_type", type=str, default="squared", + choices=["squared", "hinge", "asymmetric", "extreme"], + help="Lyapunov regularization type") + p.add_argument("--warmup_epochs", type=int, default=0, + help="Epochs to warmup lambda_reg (0 = no warmup)") + p.add_argument("--stable_init", action="store_true", + help="Use stability-aware weight initialization") + p.add_argument("--lyap_threshold", type=float, default=2.0, + help="Threshold for extreme reg_type (only penalize λ > threshold)") + + return p.parse_args() + + +def main(): + args = parse_args() + device = torch.device(args.device) + + print("=" * 80) + print("DEPTH SCALING BENCHMARK") + print("=" * 80) + print(f"Dataset: {args.dataset}") + print(f"Depths: {args.depths}") + print(f"Timesteps: {args.T}") + print(f"Epochs: {args.epochs}") + print(f"λ_reg: {args.lambda_reg}, λ_target: {args.lambda_target}") + print(f"Reg type: {args.reg_type}, Warmup epochs: {args.warmup_epochs}") + print(f"Device: {device}") + print("=" * 80) + + # Load data + print(f"\nLoading {args.dataset}...") + train_loader, test_loader, num_classes, input_shape = get_dataset( + args.dataset, args.data_dir, args.batch_size + ) + in_channels = input_shape[0] + print(f"Classes: {num_classes}, Input: {input_shape}") + print(f"Train: {len(train_loader.dataset)}, Test: {len(test_loader.dataset)}") + + # Convert depths to (num_stages, blocks_per_stage) configs + # We use 4 stages (3 for smaller nets), adjust blocks_per_stage + depth_configs = [] + for d in args.depths: + if d <= 4: + depth_configs.append((d, 1)) # d stages, 1 block each + elif d <= 8: + depth_configs.append((4, d // 4)) # 4 stages + else: + depth_configs.append((4, d // 4)) # 4 stages, more blocks + + print(f"\nDepth configurations: 
+
+    # Run experiment
+    results = run_depth_scaling_experiment(
+        dataset_name=args.dataset,
+        depth_configs=depth_configs,
+        train_loader=train_loader,
+        test_loader=test_loader,
+        num_classes=num_classes,
+        in_channels=in_channels,
+        T=args.T,
+        epochs=args.epochs,
+        lr=args.lr,
+        lambda_reg=args.lambda_reg,
+        lambda_target=args.lambda_target,
+        device=device,
+        seed=args.seed,
+        progress=not args.no_progress,
+        reg_type=args.reg_type,
+        warmup_epochs=args.warmup_epochs,
+        stable_init=args.stable_init,
+        lyap_threshold=args.lyap_threshold,
+    )
+
+    # Summary
+    print_summary(results, args.dataset)
+
+    # Save
+    ts = time.strftime("%Y%m%d-%H%M%S")
+    output_dir = os.path.join(args.out_dir, f"{args.dataset}_{ts}")
+    save_results(results, output_dir, vars(args))
+
+
+if __name__ == "__main__":
+    main()