diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
| commit | 00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch) | |
| tree | 77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files/models | |
| parent | c53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff) | |
| parent | cd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff) | |
Merge master into main
Diffstat (limited to 'files/models')
| -rw-r--r-- | files/models/conv_snn.py | 483 | ||||
| -rw-r--r-- | files/models/snn.py | 141 | ||||
| -rw-r--r-- | files/models/snn_snntorch.py | 398 |
3 files changed, 1022 insertions, 0 deletions
diff --git a/files/models/conv_snn.py b/files/models/conv_snn.py new file mode 100644 index 0000000..69f77d6 --- /dev/null +++ b/files/models/conv_snn.py @@ -0,0 +1,483 @@ +""" +Convolutional SNN with Lyapunov regularization for image classification. + +Properly handles spatial structure: +- Input: (B, C, H, W) static image OR (B, T, C, H, W) spike tensor +- Uses Conv-LIF layers to preserve spatial hierarchy +- Rate encoding converts images to spike trains + +Based on standard SNN vision practices: +- Rate/Poisson encoding for input +- Conv → BatchNorm → LIF → Pool architecture +- Time comes from encoding + LIF dynamics, not flattening +""" + +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import snntorch as snn +from snntorch import surrogate + + +class RateEncoder(nn.Module): + """ + Rate (Poisson/Bernoulli) encoder for static images. + + Converts intensity x ∈ [0,1] to spike probability per timestep. + Each pixel independently fires with P(spike) = x * gain. + + Args: + T: Number of timesteps + gain: Scaling factor for firing probability (default 1.0) + + Input: (B, C, H, W) normalized image in [0, 1] + Output: (B, T, C, H, W) binary spike tensor + """ + + def __init__(self, T: int = 25, gain: float = 1.0): + super().__init__() + self.T = T + self.gain = gain + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: (B, C, H, W) image tensor, values in [0, 1] + Returns: + spikes: (B, T, C, H, W) binary spike tensor + """ + # Clamp to valid probability range + prob = (x * self.gain).clamp(0, 1) + + # Expand for T timesteps: (B, C, H, W) -> (B, T, C, H, W) + prob = prob.unsqueeze(1).expand(-1, self.T, -1, -1, -1) + + # Sample spikes + spikes = torch.bernoulli(prob) + + return spikes + + +class DirectEncoder(nn.Module): + """ + Direct encoding - feed static image as constant current. + + Common in surrogate gradient papers: no spike encoding at input, + let spiking emerge from the network dynamics. + + Input: (B, C, H, W) image + Output: (B, T, C, H, W) repeated image (as analog current) + """ + + def __init__(self, T: int = 25): + super().__init__() + self.T = T + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Simply repeat across time + return x.unsqueeze(1).expand(-1, self.T, -1, -1, -1) + + +class ConvLIFBlock(nn.Module): + """ + Conv → BatchNorm → LIF block. + + Maintains spatial structure while adding spiking dynamics. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: int = 1, + beta: float = 0.9, + threshold: float = 1.0, + spike_grad=None, + ): + super().__init__() + + if spike_grad is None: + spike_grad = surrogate.fast_sigmoid(slope=25) + + self.conv = nn.Conv2d( + in_channels, out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=False, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.lif = snn.Leaky( + beta=beta, + threshold=threshold, + spike_grad=spike_grad, + init_hidden=False, + ) + + def forward( + self, + x: torch.Tensor, + mem: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: (B, C_in, H, W) input (spikes or current) + mem: (B, C_out, H', W') membrane potential + Returns: + spk: (B, C_out, H', W') output spikes + mem: (B, C_out, H', W') updated membrane + """ + cur = self.bn(self.conv(x)) + spk, mem = self.lif(cur, mem) + return spk, mem + + +class ConvLyapunovSNN(nn.Module): + """ + Convolutional SNN with Lyapunov exponent regularization. + + Architecture for CIFAR-10 (32x32x3): + Input → Encoder → [Conv-LIF-Pool] × N → FC → Output + + Properly preserves spatial structure for hierarchical feature learning. + + Args: + in_channels: Input channels (3 for RGB) + num_classes: Output classes + channels: List of channel sizes for conv layers + T: Number of timesteps + beta: LIF membrane decay + threshold: LIF firing threshold + encoding: 'rate', 'direct', or 'none' (pre-encoded input) + encoding_gain: Gain for rate encoding + """ + + def __init__( + self, + in_channels: int = 3, + num_classes: int = 10, + channels: List[int] = [64, 128, 256], + T: int = 25, + beta: float = 0.9, + threshold: float = 1.0, + encoding: str = 'rate', + encoding_gain: float = 1.0, + dropout: float = 0.2, + ): + super().__init__() + + self.T = T + self.encoding_type = encoding + self.channels = channels + self.num_layers = len(channels) + + # Input encoder + if encoding == 'rate': + self.encoder = RateEncoder(T=T, gain=encoding_gain) + elif encoding == 'direct': + self.encoder = DirectEncoder(T=T) + else: + self.encoder = None # Expect pre-encoded (B, T, C, H, W) input + + # Build conv-LIF layers + self.blocks = nn.ModuleList() + self.pools = nn.ModuleList() + + ch_in = in_channels + for ch_out in channels: + self.blocks.append( + ConvLIFBlock(ch_in, ch_out, beta=beta, threshold=threshold) + ) + self.pools.append(nn.AvgPool2d(2)) + ch_in = ch_out + + # Calculate output spatial size after pooling + # CIFAR: 32 -> 16 -> 8 -> 4 (for 3 layers) + spatial_size = 32 // (2 ** len(channels)) + fc_input = channels[-1] * spatial_size * spatial_size + + # Fully connected readout + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(fc_input, num_classes) + + self._init_weights() + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def _init_mem(self, batch_size: int, device, dtype) -> List[torch.Tensor]: + """Initialize membrane potentials for all layers.""" + mems = [] + H, W = 32, 32 + for i, ch in enumerate(self.channels): + H, W = H // 2, W // 2 # After pooling + # Actually we need size BEFORE pooling for LIF + H_pre, W_pre = H * 2, W * 2 + mems.append(torch.zeros(batch_size, ch, H_pre, W_pre, device=device, dtype=dtype)) + return mems + + def forward( + self, + x: torch.Tensor, + compute_lyapunov: bool = False, + lyap_eps: float = 1e-4, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]: + """ + Forward pass with optional Lyapunov computation. + + Args: + x: Input tensor + - If encoder: (B, C, H, W) static image + - If no encoder: (B, T, C, H, W) pre-encoded spikes + compute_lyapunov: Whether to compute Lyapunov exponent + lyap_eps: Perturbation magnitude + + Returns: + logits: (B, num_classes) + lyap_est: Scalar Lyapunov estimate or None + recordings: Optional dict with spike recordings + """ + # Encode input if needed + if self.encoder is not None: + x = self.encoder(x) # (B, C, H, W) -> (B, T, C, H, W) + + B, T, C, H, W = x.shape + device, dtype = x.device, x.dtype + + # Initialize membrane potentials + mems = self._init_mem(B, device, dtype) + + # For accumulating output spikes + spike_sum = None + + # Lyapunov setup + if compute_lyapunov: + mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems] + lyap_accum = torch.zeros(B, device=device, dtype=dtype) + + # Time loop + for t in range(T): + x_t = x[:, t] # (B, C, H, W) + + # Forward through conv-LIF blocks + h = x_t + new_mems = [] + for i, (block, pool) in enumerate(zip(self.blocks, self.pools)): + h, mem = block(h, mems[i]) + new_mems.append(mem) + h = pool(h) # Spatial downsampling + + mems = new_mems + + # Accumulate final layer spikes + if spike_sum is None: + spike_sum = h.view(B, -1) + else: + spike_sum = spike_sum + h.view(B, -1) + + # Lyapunov computation + if compute_lyapunov: + h_p = x_t + new_mems_p = [] + for i, (block, pool) in enumerate(zip(self.blocks, self.pools)): + h_p, mem_p = block(h_p, mems_p[i]) + new_mems_p.append(mem_p) + h_p = pool(h_p) + + # Compute divergence + delta_sq = torch.zeros(B, device=device, dtype=dtype) + for i in range(self.num_layers): + diff = new_mems_p[i] - new_mems[i] + delta_sq += (diff ** 2).sum(dim=(1, 2, 3)) + + delta = torch.sqrt(delta_sq + 1e-12) + lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12) + + # Renormalize perturbation + for i in range(self.num_layers): + diff = new_mems_p[i] - new_mems[i] + norm = torch.sqrt((diff ** 2).sum(dim=(1, 2, 3), keepdim=True) + 1e-12) + # Broadcast norm to spatial dimensions + norm = norm.view(B, 1, 1, 1) + new_mems_p[i] = new_mems[i] + lyap_eps * diff / norm + + mems_p = new_mems_p + + # Readout + out = self.dropout(spike_sum) + logits = self.fc(out) + + if compute_lyapunov: + lyap_est = (lyap_accum / T).mean() + else: + lyap_est = None + + return logits, lyap_est, None + + +class VGGLyapunovSNN(nn.Module): + """ + VGG-style deep Conv-SNN with Lyapunov regularization. + + Deeper architecture for more challenging benchmarks. + Uses multiple conv layers between pooling to increase depth. + + Architecture (VGG-9 style): + [Conv-LIF × 2, Pool] → [Conv-LIF × 2, Pool] → [Conv-LIF × 2, Pool] → FC + """ + + def __init__( + self, + in_channels: int = 3, + num_classes: int = 10, + T: int = 25, + beta: float = 0.9, + threshold: float = 1.0, + encoding: str = 'rate', + dropout: float = 0.3, + ): + super().__init__() + + self.T = T + self.encoding_type = encoding + + spike_grad = surrogate.fast_sigmoid(slope=25) + + if encoding == 'rate': + self.encoder = RateEncoder(T=T) + elif encoding == 'direct': + self.encoder = DirectEncoder(T=T) + else: + self.encoder = None + + # VGG-style blocks: (in_ch, out_ch, num_convs) + block_configs = [ + (in_channels, 64, 2), # 32x32 -> 16x16 + (64, 128, 2), # 16x16 -> 8x8 + (128, 256, 2), # 8x8 -> 4x4 + ] + + self.blocks = nn.ModuleList() + for in_ch, out_ch, n_convs in block_configs: + layers = [] + for i in range(n_convs): + ch_in = in_ch if i == 0 else out_ch + layers.append(nn.Conv2d(ch_in, out_ch, 3, padding=1, bias=False)) + layers.append(nn.BatchNorm2d(out_ch)) + self.blocks.append(nn.ModuleList(layers)) + + # LIF neurons for each conv layer + self.lifs = nn.ModuleList([ + snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False) + for _ in range(6) # 2 convs × 3 blocks + ]) + + self.pools = nn.ModuleList([nn.AvgPool2d(2) for _ in range(3)]) + + # FC layers + self.fc1 = nn.Linear(256 * 4 * 4, 512) + self.lif_fc = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False) + self.dropout = nn.Dropout(dropout) + self.fc2 = nn.Linear(512, num_classes) + + self._init_weights() + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward( + self, + x: torch.Tensor, + compute_lyapunov: bool = False, + lyap_eps: float = 1e-4, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]: + + if self.encoder is not None: + x = self.encoder(x) + + B, T, C, H, W = x.shape + device, dtype = x.device, x.dtype + + # Initialize all membrane potentials + # For each conv layer output + mem_shapes = [ + (B, 64, 32, 32), (B, 64, 32, 32), # Block 1 + (B, 128, 16, 16), (B, 128, 16, 16), # Block 2 + (B, 256, 8, 8), (B, 256, 8, 8), # Block 3 + (B, 512), # FC + ] + mems = [torch.zeros(s, device=device, dtype=dtype) for s in mem_shapes] + + spike_sum = torch.zeros(B, 512, device=device, dtype=dtype) + + if compute_lyapunov: + mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems] + lyap_accum = torch.zeros(B, device=device, dtype=dtype) + + for t in range(T): + h = x[:, t] + + lif_idx = 0 + for block_idx, (block_layers, pool) in enumerate(zip(self.blocks, self.pools)): + for i in range(0, len(block_layers), 2): # Conv, BN pairs + conv, bn = block_layers[i], block_layers[i + 1] + h = bn(conv(h)) + h, mems[lif_idx] = self.lifs[lif_idx](h, mems[lif_idx]) + lif_idx += 1 + h = pool(h) + + # FC layers + h = h.view(B, -1) + h = self.fc1(h) + h, mems[6] = self.lif_fc(h, mems[6]) + spike_sum = spike_sum + h + + # Lyapunov (simplified - just on last layer) + if compute_lyapunov: + diff = mems[6] - mems_p[6] if t > 0 else torch.zeros_like(mems[6]) + delta = torch.norm(diff.view(B, -1), dim=1) + 1e-12 + if t > 0: + lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12) + mems_p[6] = mems[6] + lyap_eps * torch.randn_like(mems[6]) + + out = self.dropout(spike_sum) + logits = self.fc2(out) + + lyap_est = (lyap_accum / T).mean() if compute_lyapunov else None + + return logits, lyap_est, None + + +def create_conv_snn( + model_type: str = 'simple', + **kwargs, +) -> nn.Module: + """ + Factory function for Conv-SNN models. + + Args: + model_type: 'simple' (3-layer) or 'vgg' (6-layer VGG-style) + **kwargs: Arguments passed to model constructor + """ + if model_type == 'simple': + return ConvLyapunovSNN(**kwargs) + elif model_type == 'vgg': + return VGGLyapunovSNN(**kwargs) + else: + raise ValueError(f"Unknown model_type: {model_type}") diff --git a/files/models/snn.py b/files/models/snn.py new file mode 100644 index 0000000..b1cf633 --- /dev/null +++ b/files/models/snn.py @@ -0,0 +1,141 @@ +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SurrogateStep(torch.autograd.Function): + """ + Heaviside with a smooth surrogate gradient (fast sigmoid). + """ + @staticmethod + def forward(ctx, x: torch.Tensor, alpha: float): + ctx.save_for_backward(x) + ctx.alpha = alpha + return (x > 0).to(x.dtype) + + @staticmethod + def backward(ctx, grad_output): + (x,) = ctx.saved_tensors + alpha = ctx.alpha + # d/dx sigmoid(alpha*x) ~ alpha * sigmoid * (1 - sigmoid) + # Use fast sigmoid: s = 1 / (1 + |alpha*x|) + s = 1.0 / (1.0 + (alpha * x).abs()) + grad = grad_output * s * s + return grad, None + + +def surrogate_heaviside(x: torch.Tensor, alpha: float = 5.0) -> torch.Tensor: + return SurrogateStep.apply(x, alpha) + + +class LIFLayer(nn.Module): + """ + Single LIF layer without recurrent synapses between neurons. + Dynamics per neuron i: + v_t = decay * v_{t-1} + W x_t - v_th * s_{t-1} + s_t = H( v_t - v_th ) with surrogate gradient + """ + def __init__(self, input_dim: int, hidden_dim: int, v_threshold: float = 1.0, decay: float = 0.95, spike_alpha: float = 5.0, rec_strength: float = 0.0, rec_init_scale: float = 1.0): + super().__init__() + self.linear = nn.Linear(input_dim, hidden_dim, bias=True) + self.v_threshold = float(v_threshold) + self.decay = float(decay) + self.spike_alpha = float(spike_alpha) + self.rec_strength = float(rec_strength) + self.rec = None + if self.rec_strength != 0.0: + self.rec = nn.Linear(hidden_dim, hidden_dim, bias=False) + + nn.init.xavier_uniform_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + if self.rec is not None: + nn.init.xavier_uniform_(self.rec.weight, gain=rec_init_scale) + + def forward(self, x_t: torch.Tensor, v_prev: torch.Tensor, s_prev: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # x_t: (B, D_in), v_prev: (B, H), s_prev: (B, H) + I_t = self.linear(x_t) # (B, H) + R_t = 0.0 + if self.rec is not None: + R_t = self.rec_strength * self.rec(s_prev) + v_t = self.decay * v_prev + I_t + R_t - self.v_threshold * s_prev + s_t = surrogate_heaviside(v_t - self.v_threshold, alpha=self.spike_alpha) + return v_t, s_t + + +class SimpleSNN(nn.Module): + """ + Minimal SNN for SHD-like input (B,T,D): + - One LIF hidden layer + - Readout linear on time-summed spikes + """ + def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, v_threshold: float = 1.0, decay: float = 0.95, spike_alpha: float = 5.0, rec_strength: float = 0.0, rec_init_scale: float = 1.0): + super().__init__() + self.lif = LIFLayer(input_dim, hidden_dim, v_threshold=v_threshold, decay=decay, spike_alpha=spike_alpha, rec_strength=rec_strength, rec_init_scale=rec_init_scale) + self.readout = nn.Linear(hidden_dim, num_classes) + nn.init.xavier_uniform_(self.readout.weight) + nn.init.zeros_(self.readout.bias) + + @torch.no_grad() + def _init_states(self, batch_size: int, hidden_dim: int, device, dtype): + v0 = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype) + s0 = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype) + return v0, s0 + + def forward( + self, + x: torch.Tensor, + compute_lyapunov: bool = False, + lyap_eps: float = 1e-3, + lyap_safe_eps: float = 1e-8, + lyap_measure: str = "v", # "v" or "s" + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + x: (B, T, D) + Returns: + logits: (B, C) + lyap_est: scalar tensor if compute_lyapunov else None + """ + assert x.ndim == 3, f"Expected (B,T,D), got {x.shape}" + B, T, D = x.shape + device, dtype = x.device, x.dtype + H = self.readout.in_features + + v, s = self._init_states(B, H, device, dtype) + spike_sum = torch.zeros(B, H, device=device, dtype=dtype) + + if compute_lyapunov: + v_p = v + lyap_eps * torch.randn_like(v) + s_p = s.clone() + delta_prev = torch.norm((v_p - v).reshape(B, -1), dim=1) + lyap_safe_eps + lyap_terms = [] + + for t in range(T): + x_t = x[:, t, :] + v, s = self.lif(x_t, v, s) + spike_sum = spike_sum + s + + if compute_lyapunov: + # run perturbed trajectory through same ops + v_p, s_p = self.lif(x_t, v_p, s_p) + if lyap_measure == "s": + delta_t = torch.norm((s_p - s).reshape(B, -1), dim=1) + lyap_safe_eps + else: + delta_t = torch.norm((v_p - v).reshape(B, -1), dim=1) + lyap_safe_eps + ratio = delta_t / delta_prev + lyap_terms.append(torch.log(ratio + lyap_safe_eps)) + delta_prev = delta_t + + logits = self.readout(spike_sum) # (B, C) + + if compute_lyapunov: + lyap_batch = torch.stack(lyap_terms, dim=0).mean(dim=0) # (B,) + lyap_est = lyap_batch.mean() # scalar + else: + lyap_est = None + + return logits, lyap_est + + diff --git a/files/models/snn_snntorch.py b/files/models/snn_snntorch.py new file mode 100644 index 0000000..71c1c18 --- /dev/null +++ b/files/models/snn_snntorch.py @@ -0,0 +1,398 @@ +""" +snnTorch-based SNN with Lyapunov exponent regularization. + +This module provides deep SNN architectures using snnTorch with proper +finite-time Lyapunov exponent computation for training stabilization. +""" + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import snntorch as snn +from snntorch import surrogate + + +class LyapunovSNN(nn.Module): + """ + Multi-layer SNN using snnTorch with Lyapunov exponent computation. + + Architecture: + Input (B, T, D) -> [LIF layers] -> time-summed spikes -> Linear -> logits + + Args: + input_dim: Input feature dimension + hidden_dims: List of hidden layer sizes (e.g., [256, 128] for 2 layers) + num_classes: Number of output classes + beta: Membrane potential decay factor (0 < beta < 1) + threshold: Firing threshold + spike_grad: Surrogate gradient function (default: fast_sigmoid) + dropout: Dropout probability between layers (0 = no dropout) + """ + + def __init__( + self, + input_dim: int, + hidden_dims: List[int], + num_classes: int, + beta: float = 0.9, + threshold: float = 1.0, + spike_grad: Optional[Any] = None, + dropout: float = 0.0, + ): + super().__init__() + + if spike_grad is None: + spike_grad = surrogate.fast_sigmoid(slope=25) + + self.hidden_dims = hidden_dims + self.num_layers = len(hidden_dims) + self.beta = beta + self.threshold = threshold + + # Build layers + self.linears = nn.ModuleList() + self.lifs = nn.ModuleList() + self.dropouts = nn.ModuleList() if dropout > 0 else None + + dims = [input_dim] + hidden_dims + for i in range(self.num_layers): + self.linears.append(nn.Linear(dims[i], dims[i + 1])) + self.lifs.append( + snn.Leaky( + beta=beta, + threshold=threshold, + spike_grad=spike_grad, + init_hidden=False, + reset_mechanism="subtract", + ) + ) + if dropout > 0: + self.dropouts.append(nn.Dropout(p=dropout)) + + # Readout layer + self.readout = nn.Linear(hidden_dims[-1], num_classes) + + # Initialize weights + self._init_weights() + + def _init_weights(self): + for lin in self.linears: + nn.init.xavier_uniform_(lin.weight) + nn.init.zeros_(lin.bias) + nn.init.xavier_uniform_(self.readout.weight) + nn.init.zeros_(self.readout.bias) + + def _init_states(self, batch_size: int, device, dtype) -> List[torch.Tensor]: + """Initialize membrane potentials for all layers.""" + mems = [] + for dim in self.hidden_dims: + mems.append(torch.zeros(batch_size, dim, device=device, dtype=dtype)) + return mems + + def _step( + self, + x_t: torch.Tensor, + mems: List[torch.Tensor], + training: bool = True, + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + """ + Single timestep forward pass. + + Returns: + spike_out: Output spikes from last layer (B, H_last) + new_mems: Updated membrane potentials + all_mems: Membrane potentials from all layers (for Lyapunov) + """ + new_mems = [] + all_mems = [] + + h = x_t + for i in range(self.num_layers): + h = self.linears[i](h) + spk, mem = self.lifs[i](h, mems[i]) + new_mems.append(mem) + all_mems.append(mem) + h = spk + if self.dropouts is not None and training: + h = self.dropouts[i](h) + + return h, new_mems, all_mems + + def forward( + self, + x: torch.Tensor, + compute_lyapunov: bool = False, + lyap_eps: float = 1e-4, + lyap_layers: Optional[List[int]] = None, + record_states: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict[str, torch.Tensor]]]: + """ + Forward pass with optional Lyapunov exponent computation. + + Args: + x: Input tensor (B, T, D) + compute_lyapunov: Whether to compute Lyapunov exponent + lyap_eps: Perturbation magnitude for Lyapunov computation + lyap_layers: Which layers to measure (default: all). + e.g., [0] for first layer only, [-1] for last layer + record_states: Whether to record spikes and membrane potentials + + Returns: + logits: Classification logits (B, num_classes) + lyap_est: Estimated Lyapunov exponent (scalar) or None + recordings: Dict with 'spikes' (B,T,H) and 'membrane' (B,T,H) or None + """ + B, T, D = x.shape + device, dtype = x.device, x.dtype + + # Initialize states + mems = self._init_states(B, device, dtype) + spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype) + + # Recording setup + if record_states: + spike_rec = [] + mem_rec = [] + + # Lyapunov setup + if compute_lyapunov: + if lyap_layers is None: + lyap_layers = list(range(self.num_layers)) + + # Perturbed trajectory - perturb all membrane potentials + mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems] + lyap_accum = torch.zeros(B, device=device, dtype=dtype) + + # Time loop + for t in range(T): + x_t = x[:, t, :] + + # Nominal trajectory + spk, mems, all_mems = self._step(x_t, mems, training=self.training) + spike_sum = spike_sum + spk + + if record_states: + spike_rec.append(spk.detach()) + mem_rec.append(all_mems[-1].detach()) # Last layer membrane + + if compute_lyapunov: + # Perturbed trajectory + _, mems_p, all_mems_p = self._step(x_t, mems_p, training=False) + + # Compute divergence across selected layers + delta_sq = torch.zeros(B, device=device, dtype=dtype) + delta_p_sq = torch.zeros(B, device=device, dtype=dtype) + + for layer_idx in lyap_layers: + diff = all_mems_p[layer_idx] - all_mems[layer_idx] + delta_sq += (diff ** 2).sum(dim=1) + + delta = torch.sqrt(delta_sq + 1e-12) + + # Renormalization step (key for numerical stability) + # Rescale perturbation back to fixed magnitude + for layer_idx in lyap_layers: + diff = mems_p[layer_idx] - mems[layer_idx] + # Normalize to maintain fixed perturbation magnitude + norm = torch.norm(diff.reshape(B, -1), dim=1, keepdim=True) + 1e-12 + diff_normalized = diff / norm.unsqueeze(-1) if diff.ndim > 2 else diff / norm + mems_p[layer_idx] = mems[layer_idx] + lyap_eps * diff_normalized + + # Accumulate log-divergence + lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12) + + logits = self.readout(spike_sum) + + if compute_lyapunov: + # Average over time and batch + lyap_est = (lyap_accum / T).mean() + else: + lyap_est = None + + if record_states: + recordings = { + "spikes": torch.stack(spike_rec, dim=1), # (B, T, H) + "membrane": torch.stack(mem_rec, dim=1), # (B, T, H) + } + else: + recordings = None + + return logits, lyap_est, recordings + + +class RecurrentLyapunovSNN(nn.Module): + """ + Recurrent SNN with Lyapunov exponent computation. + + Uses snnTorch's RSynaptic (recurrent synaptic) neurons for + richer temporal dynamics. + + Args: + input_dim: Input feature dimension + hidden_dims: List of hidden layer sizes + num_classes: Number of output classes + alpha: Synaptic current decay rate + beta: Membrane potential decay rate + threshold: Firing threshold + """ + + def __init__( + self, + input_dim: int, + hidden_dims: List[int], + num_classes: int, + alpha: float = 0.9, + beta: float = 0.85, + threshold: float = 1.0, + spike_grad: Optional[Any] = None, + ): + super().__init__() + + if spike_grad is None: + spike_grad = surrogate.fast_sigmoid(slope=25) + + self.hidden_dims = hidden_dims + self.num_layers = len(hidden_dims) + self.alpha = alpha + self.beta = beta + + # Build layers with recurrent synaptic neurons + self.linears = nn.ModuleList() + self.neurons = nn.ModuleList() + + dims = [input_dim] + hidden_dims + for i in range(self.num_layers): + self.linears.append(nn.Linear(dims[i], dims[i + 1])) + self.neurons.append( + snn.RSynaptic( + alpha=alpha, + beta=beta, + threshold=threshold, + spike_grad=spike_grad, + init_hidden=False, + reset_mechanism="subtract", + all_to_all=True, + linear_features=dims[i + 1], + ) + ) + + self.readout = nn.Linear(hidden_dims[-1], num_classes) + self._init_weights() + + def _init_weights(self): + for lin in self.linears: + nn.init.xavier_uniform_(lin.weight) + nn.init.zeros_(lin.bias) + nn.init.xavier_uniform_(self.readout.weight) + nn.init.zeros_(self.readout.bias) + + def _init_states(self, batch_size: int, device, dtype): + """Initialize synaptic currents and membrane potentials.""" + syns = [] + mems = [] + for dim in self.hidden_dims: + syns.append(torch.zeros(batch_size, dim, device=device, dtype=dtype)) + mems.append(torch.zeros(batch_size, dim, device=device, dtype=dtype)) + return syns, mems + + def forward( + self, + x: torch.Tensor, + compute_lyapunov: bool = False, + lyap_eps: float = 1e-4, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Forward pass with optional Lyapunov computation.""" + B, T, D = x.shape + device, dtype = x.device, x.dtype + + syns, mems = self._init_states(B, device, dtype) + spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype) + + if compute_lyapunov: + # Perturb both synaptic currents and membrane potentials + syns_p = [s + lyap_eps * torch.randn_like(s) for s in syns] + mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems] + lyap_accum = torch.zeros(B, device=device, dtype=dtype) + + for t in range(T): + x_t = x[:, t, :] + + # Nominal trajectory + h = x_t + new_syns, new_mems = [], [] + for i in range(self.num_layers): + h = self.linears[i](h) + spk, syn, mem = self.neurons[i](h, syns[i], mems[i]) + new_syns.append(syn) + new_mems.append(mem) + h = spk + syns, mems = new_syns, new_mems + spike_sum = spike_sum + h + + if compute_lyapunov: + # Perturbed trajectory + h_p = x_t + new_syns_p, new_mems_p = [], [] + for i in range(self.num_layers): + h_p = self.linears[i](h_p) + spk_p, syn_p, mem_p = self.neurons[i](h_p, syns_p[i], mems_p[i]) + new_syns_p.append(syn_p) + new_mems_p.append(mem_p) + h_p = spk_p + + # Compute divergence (on membrane potentials) + delta_sq = torch.zeros(B, device=device, dtype=dtype) + for i in range(self.num_layers): + diff_m = new_mems_p[i] - new_mems[i] + diff_s = new_syns_p[i] - new_syns[i] + delta_sq += (diff_m ** 2).sum(dim=1) + (diff_s ** 2).sum(dim=1) + + delta = torch.sqrt(delta_sq + 1e-12) + lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12) + + # Renormalize perturbation + total_dim = sum(2 * d for d in self.hidden_dims) # syn + mem + scale = lyap_eps / (delta.unsqueeze(-1) + 1e-12) + + syns_p = [new_syns[i] + scale * (new_syns_p[i] - new_syns[i]) + for i in range(self.num_layers)] + mems_p = [new_mems[i] + scale * (new_mems_p[i] - new_mems[i]) + for i in range(self.num_layers)] + + logits = self.readout(spike_sum) + + if compute_lyapunov: + lyap_est = (lyap_accum / T).mean() + else: + lyap_est = None + + return logits, lyap_est + + +def create_snn( + model_type: str, + input_dim: int, + hidden_dims: List[int], + num_classes: int, + **kwargs, +) -> nn.Module: + """ + Factory function to create SNN models. + + Args: + model_type: "feedforward" or "recurrent" + input_dim: Input feature dimension + hidden_dims: List of hidden layer sizes + num_classes: Number of output classes + **kwargs: Additional arguments passed to model constructor + + Returns: + SNN model instance + """ + if model_type == "feedforward": + return LyapunovSNN(input_dim, hidden_dims, num_classes, **kwargs) + elif model_type == "recurrent": + return RecurrentLyapunovSNN(input_dim, hidden_dims, num_classes, **kwargs) + else: + raise ValueError(f"Unknown model_type: {model_type}. Use 'feedforward' or 'recurrent'") |
