""" Convolutional SNN with Lyapunov regularization for image classification. Properly handles spatial structure: - Input: (B, C, H, W) static image OR (B, T, C, H, W) spike tensor - Uses Conv-LIF layers to preserve spatial hierarchy - Rate encoding converts images to spike trains Based on standard SNN vision practices: - Rate/Poisson encoding for input - Conv → BatchNorm → LIF → Pool architecture - Time comes from encoding + LIF dynamics, not flattening """ from typing import Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F import snntorch as snn from snntorch import surrogate class RateEncoder(nn.Module): """ Rate (Poisson/Bernoulli) encoder for static images. Converts intensity x ∈ [0,1] to spike probability per timestep. Each pixel independently fires with P(spike) = x * gain. Args: T: Number of timesteps gain: Scaling factor for firing probability (default 1.0) Input: (B, C, H, W) normalized image in [0, 1] Output: (B, T, C, H, W) binary spike tensor """ def __init__(self, T: int = 25, gain: float = 1.0): super().__init__() self.T = T self.gain = gain def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x: (B, C, H, W) image tensor, values in [0, 1] Returns: spikes: (B, T, C, H, W) binary spike tensor """ # Clamp to valid probability range prob = (x * self.gain).clamp(0, 1) # Expand for T timesteps: (B, C, H, W) -> (B, T, C, H, W) prob = prob.unsqueeze(1).expand(-1, self.T, -1, -1, -1) # Sample spikes spikes = torch.bernoulli(prob) return spikes class DirectEncoder(nn.Module): """ Direct encoding - feed static image as constant current. Common in surrogate gradient papers: no spike encoding at input, let spiking emerge from the network dynamics. Input: (B, C, H, W) image Output: (B, T, C, H, W) repeated image (as analog current) """ def __init__(self, T: int = 25): super().__init__() self.T = T def forward(self, x: torch.Tensor) -> torch.Tensor: # Simply repeat across time return x.unsqueeze(1).expand(-1, self.T, -1, -1, -1) class ConvLIFBlock(nn.Module): """ Conv → BatchNorm → LIF block. Maintains spatial structure while adding spiking dynamics. """ def __init__( self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: int = 1, beta: float = 0.9, threshold: float = 1.0, spike_grad=None, ): super().__init__() if spike_grad is None: spike_grad = surrogate.fast_sigmoid(slope=25) self.conv = nn.Conv2d( in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False, ) self.bn = nn.BatchNorm2d(out_channels) self.lif = snn.Leaky( beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False, ) def forward( self, x: torch.Tensor, mem: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: x: (B, C_in, H, W) input (spikes or current) mem: (B, C_out, H', W') membrane potential Returns: spk: (B, C_out, H', W') output spikes mem: (B, C_out, H', W') updated membrane """ cur = self.bn(self.conv(x)) spk, mem = self.lif(cur, mem) return spk, mem class ConvLyapunovSNN(nn.Module): """ Convolutional SNN with Lyapunov exponent regularization. Architecture for CIFAR-10 (32x32x3): Input → Encoder → [Conv-LIF-Pool] × N → FC → Output Properly preserves spatial structure for hierarchical feature learning. 
class ConvLyapunovSNN(nn.Module):
    """
    Convolutional SNN with Lyapunov exponent regularization.

    Architecture for CIFAR-10 (32x32x3):
        Input → Encoder → [Conv-LIF-Pool] × N → FC → Output

    Preserves spatial structure for hierarchical feature learning.

    Args:
        in_channels: Input channels (3 for RGB)
        num_classes: Output classes
        channels: Channel sizes for the conv layers (default [64, 128, 256])
        T: Number of timesteps
        beta: LIF membrane decay
        threshold: LIF firing threshold
        encoding: 'rate', 'direct', or 'none' (pre-encoded input)
        encoding_gain: Gain for rate encoding
        dropout: Dropout probability before the readout
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 10,
        channels: Optional[List[int]] = None,
        T: int = 25,
        beta: float = 0.9,
        threshold: float = 1.0,
        encoding: str = 'rate',
        encoding_gain: float = 1.0,
        dropout: float = 0.2,
    ):
        super().__init__()
        if channels is None:  # avoid a mutable default argument
            channels = [64, 128, 256]
        self.T = T
        self.encoding_type = encoding
        self.channels = channels
        self.num_layers = len(channels)

        # Input encoder
        if encoding == 'rate':
            self.encoder = RateEncoder(T=T, gain=encoding_gain)
        elif encoding == 'direct':
            self.encoder = DirectEncoder(T=T)
        else:
            self.encoder = None  # expect pre-encoded (B, T, C, H, W) input

        # Build conv-LIF layers
        self.blocks = nn.ModuleList()
        self.pools = nn.ModuleList()
        ch_in = in_channels
        for ch_out in channels:
            self.blocks.append(
                ConvLIFBlock(ch_in, ch_out, beta=beta, threshold=threshold)
            )
            self.pools.append(nn.AvgPool2d(2))
            ch_in = ch_out

        # Spatial size after pooling, assuming a 32x32 input.
        # CIFAR: 32 -> 16 -> 8 -> 4 for three layers.
        spatial_size = 32 // (2 ** len(channels))
        fc_input = channels[-1] * spatial_size * spatial_size

        # Fully connected readout
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(fc_input, num_classes)

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _init_mem(self, batch_size: int, device, dtype) -> List[torch.Tensor]:
        """Initialize membrane potentials for all layers (32x32 input assumed)."""
        mems = []
        H, W = 32, 32
        for ch in self.channels:
            # Each LIF state has the PRE-pooling spatial size of its layer.
            mems.append(torch.zeros(batch_size, ch, H, W, device=device, dtype=dtype))
            H, W = H // 2, W // 2  # pooling halves the size for the next layer
        return mems
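
    # A note on the estimator implemented in forward(): it follows a
    # Benettin-style two-trajectory recipe (a standard sketch, not tied to a
    # specific reference). A twin state perturbed by magnitude lyap_eps is
    # integrated alongside the nominal one; each step accumulates the
    # one-step log-expansion and the twin is renormalized back to distance
    # lyap_eps:
    #
    #     lambda_hat = (1 / T) * sum_t log(||delta_t|| / lyap_eps),
    #
    # where delta_t is the joint membrane difference across all layers.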
    def forward(
        self,
        x: torch.Tensor,
        compute_lyapunov: bool = False,
        lyap_eps: float = 1e-4,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]:
        """
        Forward pass with optional Lyapunov computation.

        Args:
            x: Input tensor
                - With an encoder: (B, C, H, W) static image
                - Without an encoder: (B, T, C, H, W) pre-encoded spikes
            compute_lyapunov: Whether to compute the Lyapunov exponent
            lyap_eps: Perturbation magnitude

        Returns:
            logits: (B, num_classes)
            lyap_est: Scalar Lyapunov estimate, or None
            recordings: Optional dict with spike recordings (currently None)
        """
        # Encode input if needed
        if self.encoder is not None:
            x = self.encoder(x)  # (B, C, H, W) -> (B, T, C, H, W)

        B, T, C, H, W = x.shape
        device, dtype = x.device, x.dtype

        # Initialize membrane potentials
        mems = self._init_mem(B, device, dtype)

        # Accumulator for final-layer output spikes
        spike_sum = None

        # Lyapunov setup: perturb the initial state so the JOINT norm of the
        # perturbation across all layers is exactly lyap_eps per sample.
        if compute_lyapunov:
            noise = [torch.randn_like(m) for m in mems]
            noise_sq = sum((n ** 2).sum(dim=(1, 2, 3)) for n in noise)
            noise_norm = torch.sqrt(noise_sq + 1e-12).view(B, 1, 1, 1)
            mems_p = [m + lyap_eps * n / noise_norm for m, n in zip(mems, noise)]
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)

        # Time loop
        for t in range(T):
            x_t = x[:, t]  # (B, C, H, W)

            # Forward through conv-LIF blocks
            h = x_t
            new_mems = []
            for i, (block, pool) in enumerate(zip(self.blocks, self.pools)):
                h, mem = block(h, mems[i])
                new_mems.append(mem)
                h = pool(h)  # spatial downsampling

            # Accumulate final-layer spikes
            if spike_sum is None:
                spike_sum = h.view(B, -1)
            else:
                spike_sum = spike_sum + h.view(B, -1)

            # Lyapunov twin trajectory: same input, perturbed membranes
            if compute_lyapunov:
                h_p = x_t
                new_mems_p = []
                for i, (block, pool) in enumerate(zip(self.blocks, self.pools)):
                    h_p, mem_p = block(h_p, mems_p[i])
                    new_mems_p.append(mem_p)
                    h_p = pool(h_p)

                # Joint divergence across all layers
                delta_sq = torch.zeros(B, device=device, dtype=dtype)
                for i in range(self.num_layers):
                    diff = new_mems_p[i] - new_mems[i]
                    delta_sq = delta_sq + (diff ** 2).sum(dim=(1, 2, 3))
                delta = torch.sqrt(delta_sq + 1e-12)

                lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)

                # Renormalize the JOINT perturbation back to lyap_eps, scaling
                # every layer by the same factor so the direction is preserved
                # (a per-layer renormalization would inflate the joint norm to
                # lyap_eps * sqrt(num_layers) and bias the estimate).
                scale = (lyap_eps / delta).view(B, 1, 1, 1)
                mems_p = [
                    new_mems[i] + scale * (new_mems_p[i] - new_mems[i])
                    for i in range(self.num_layers)
                ]

            mems = new_mems

        # Readout
        out = self.dropout(spike_sum)
        logits = self.fc(out)

        lyap_est = (lyap_accum / T).mean() if compute_lyapunov else None
        return logits, lyap_est, None
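
# Hedged training sketch: one way the Lyapunov estimate might enter the loss.
# Everything here (lambda_reg, the optimizer, the ReLU penalty that punishes
# only positive exponents) is an illustrative assumption, not prescribed by
# the model code above; penalizing lyap_est ** 2 would be an equally plausible
# choice if the goal is to pin the network near the edge of chaos.
def _train_step_sketch(
    model: nn.Module,
    images: torch.Tensor,
    labels: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    lambda_reg: float = 0.1,
) -> torch.Tensor:
    logits, lyap_est, _ = model(images, compute_lyapunov=True)
    # Task loss plus a penalty on expanding (chaotic) dynamics.
    loss = F.cross_entropy(logits, labels) + lambda_reg * F.relu(lyap_est)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()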
class VGGLyapunovSNN(nn.Module):
    """
    VGG-style deep Conv-SNN with Lyapunov regularization.

    Deeper architecture for more challenging benchmarks: multiple conv
    layers between pooling stages increase depth.

    Architecture (VGG-style, 6 conv + 2 FC layers):
        [Conv-LIF × 2, Pool] → [Conv-LIF × 2, Pool] → [Conv-LIF × 2, Pool] → FC
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 10,
        T: int = 25,
        beta: float = 0.9,
        threshold: float = 1.0,
        encoding: str = 'rate',
        dropout: float = 0.3,
    ):
        super().__init__()
        self.T = T
        self.encoding_type = encoding
        spike_grad = surrogate.fast_sigmoid(slope=25)

        if encoding == 'rate':
            self.encoder = RateEncoder(T=T)
        elif encoding == 'direct':
            self.encoder = DirectEncoder(T=T)
        else:
            self.encoder = None

        # VGG-style blocks: (in_ch, out_ch, num_convs)
        block_configs = [
            (in_channels, 64, 2),  # 32x32 -> 16x16
            (64, 128, 2),          # 16x16 -> 8x8
            (128, 256, 2),         # 8x8   -> 4x4
        ]

        self.blocks = nn.ModuleList()
        for in_ch, out_ch, n_convs in block_configs:
            layers = []
            for i in range(n_convs):
                ch_in = in_ch if i == 0 else out_ch
                layers.append(nn.Conv2d(ch_in, out_ch, 3, padding=1, bias=False))
                layers.append(nn.BatchNorm2d(out_ch))
            self.blocks.append(nn.ModuleList(layers))

        # One LIF layer per conv layer (2 convs × 3 blocks = 6)
        self.lifs = nn.ModuleList([
            snn.Leaky(beta=beta, threshold=threshold,
                      spike_grad=spike_grad, init_hidden=False)
            for _ in range(6)
        ])
        self.pools = nn.ModuleList([nn.AvgPool2d(2) for _ in range(3)])

        # FC layers
        self.fc1 = nn.Linear(256 * 4 * 4, 512)
        self.lif_fc = snn.Leaky(beta=beta, threshold=threshold,
                                spike_grad=spike_grad, init_hidden=False)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(512, num_classes)

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        x: torch.Tensor,
        compute_lyapunov: bool = False,
        lyap_eps: float = 1e-4,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]:
        if self.encoder is not None:
            x = self.encoder(x)

        B, T, C, H, W = x.shape
        device, dtype = x.device, x.dtype

        # Membrane potentials: one per conv-layer output plus the FC layer
        # (32x32 input assumed).
        mem_shapes = [
            (B, 64, 32, 32), (B, 64, 32, 32),    # block 1
            (B, 128, 16, 16), (B, 128, 16, 16),  # block 2
            (B, 256, 8, 8), (B, 256, 8, 8),      # block 3
            (B, 512),                            # FC
        ]
        mems = [torch.zeros(s, device=device, dtype=dtype) for s in mem_shapes]
        spike_sum = torch.zeros(B, 512, device=device, dtype=dtype)

        if compute_lyapunov:
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)
            # Only the FC membrane is tracked; see the note in the time loop.
            mem_fc_p = mems[6] + lyap_eps * torch.randn_like(mems[6])

        for t in range(T):
            h = x[:, t]
            lif_idx = 0
            for block_layers, pool in zip(self.blocks, self.pools):
                for i in range(0, len(block_layers), 2):  # (Conv, BN) pairs
                    conv, bn = block_layers[i], block_layers[i + 1]
                    h = bn(conv(h))
                    h, mems[lif_idx] = self.lifs[lif_idx](h, mems[lif_idx])
                    lif_idx += 1
                h = pool(h)

            # FC layers
            h = h.view(B, -1)
            h = self.fc1(h)
            h, mems[6] = self.lif_fc(h, mems[6])
            spike_sum = spike_sum + h

            # Lyapunov proxy (simplified): only the FC membrane is perturbed,
            # and the perturbation is never propagated through the network, so
            # this measures one-step sensitivity of the readout state rather
            # than a full largest-exponent estimate.
            if compute_lyapunov:
                if t > 0:
                    # Divergence from the perturbed copy seeded last step.
                    diff = mems[6] - mem_fc_p
                    delta = torch.norm(diff.view(B, -1), dim=1) + 1e-12
                    lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
                # Re-seed the perturbation around the current membrane.
                mem_fc_p = mems[6] + lyap_eps * torch.randn_like(mems[6])

        out = self.dropout(spike_sum)
        logits = self.fc2(out)

        # Average over the T - 1 steps that actually contributed.
        lyap_est = (lyap_accum / max(T - 1, 1)).mean() if compute_lyapunov else None
        return logits, lyap_est, None
def create_conv_snn(
    model_type: str = 'simple',
    **kwargs,
) -> nn.Module:
    """
    Factory function for Conv-SNN models.

    Args:
        model_type: 'simple' (3 conv layers) or 'vgg' (6 conv layers, VGG-style)
        **kwargs: Arguments passed to the model constructor
    """
    if model_type == 'simple':
        return ConvLyapunovSNN(**kwargs)
    elif model_type == 'vgg':
        return VGGLyapunovSNN(**kwargs)
    else:
        raise ValueError(f"Unknown model_type: {model_type!r}")
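
if __name__ == "__main__":
    # Smoke test (a sketch, using random CIFAR-shaped data): build both model
    # variants via the factory, run one forward pass with the Lyapunov
    # estimate enabled, and print the results. A short T keeps this quick on
    # CPU; the values printed are illustrative only.
    images = torch.rand(2, 3, 32, 32)
    for kind in ('simple', 'vgg'):
        model = create_conv_snn(kind, T=5)
        logits, lyap_est, _ = model(images, compute_lyapunov=True)
        print(f"{kind}: logits {tuple(logits.shape)}, lyapunov {lyap_est.item():.4f}")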