""" snnTorch-based SNN with Lyapunov exponent regularization. This module provides deep SNN architectures using snnTorch with proper finite-time Lyapunov exponent computation for training stabilization. """ from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn as nn import snntorch as snn from snntorch import surrogate class LyapunovSNN(nn.Module): """ Multi-layer SNN using snnTorch with Lyapunov exponent computation. Architecture: Input (B, T, D) -> [LIF layers] -> time-summed spikes -> Linear -> logits Args: input_dim: Input feature dimension hidden_dims: List of hidden layer sizes (e.g., [256, 128] for 2 layers) num_classes: Number of output classes beta: Membrane potential decay factor (0 < beta < 1) threshold: Firing threshold spike_grad: Surrogate gradient function (default: fast_sigmoid) dropout: Dropout probability between layers (0 = no dropout) """ def __init__( self, input_dim: int, hidden_dims: List[int], num_classes: int, beta: float = 0.9, threshold: float = 1.0, spike_grad: Optional[Any] = None, dropout: float = 0.0, ): super().__init__() if spike_grad is None: spike_grad = surrogate.fast_sigmoid(slope=25) self.hidden_dims = hidden_dims self.num_layers = len(hidden_dims) self.beta = beta self.threshold = threshold # Build layers self.linears = nn.ModuleList() self.lifs = nn.ModuleList() self.dropouts = nn.ModuleList() if dropout > 0 else None dims = [input_dim] + hidden_dims for i in range(self.num_layers): self.linears.append(nn.Linear(dims[i], dims[i + 1])) self.lifs.append( snn.Leaky( beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False, reset_mechanism="subtract", ) ) if dropout > 0: self.dropouts.append(nn.Dropout(p=dropout)) # Readout layer self.readout = nn.Linear(hidden_dims[-1], num_classes) # Initialize weights self._init_weights() def _init_weights(self): for lin in self.linears: nn.init.xavier_uniform_(lin.weight) nn.init.zeros_(lin.bias) nn.init.xavier_uniform_(self.readout.weight) nn.init.zeros_(self.readout.bias) def _init_states(self, batch_size: int, device, dtype) -> List[torch.Tensor]: """Initialize membrane potentials for all layers.""" mems = [] for dim in self.hidden_dims: mems.append(torch.zeros(batch_size, dim, device=device, dtype=dtype)) return mems def _step( self, x_t: torch.Tensor, mems: List[torch.Tensor], training: bool = True, ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: """ Single timestep forward pass. Returns: spike_out: Output spikes from last layer (B, H_last) new_mems: Updated membrane potentials all_mems: Membrane potentials from all layers (for Lyapunov) """ new_mems = [] all_mems = [] h = x_t for i in range(self.num_layers): h = self.linears[i](h) spk, mem = self.lifs[i](h, mems[i]) new_mems.append(mem) all_mems.append(mem) h = spk if self.dropouts is not None and training: h = self.dropouts[i](h) return h, new_mems, all_mems def forward( self, x: torch.Tensor, compute_lyapunov: bool = False, lyap_eps: float = 1e-4, lyap_layers: Optional[List[int]] = None, record_states: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict[str, torch.Tensor]]]: """ Forward pass with optional Lyapunov exponent computation. Args: x: Input tensor (B, T, D) compute_lyapunov: Whether to compute Lyapunov exponent lyap_eps: Perturbation magnitude for Lyapunov computation lyap_layers: Which layers to measure (default: all). 
    def forward(
        self,
        x: torch.Tensor,
        compute_lyapunov: bool = False,
        lyap_eps: float = 1e-4,
        lyap_layers: Optional[List[int]] = None,
        record_states: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict[str, torch.Tensor]]]:
        """
        Forward pass with optional Lyapunov exponent computation.

        Args:
            x: Input tensor (B, T, D)
            compute_lyapunov: Whether to compute the Lyapunov exponent
            lyap_eps: Perturbation magnitude for Lyapunov computation
            lyap_layers: Which layers to measure (default: all layers),
                e.g. [0] for the first layer only, [-1] for the last layer
            record_states: Whether to record spikes and membrane potentials

        Returns:
            logits: Classification logits (B, num_classes)
            lyap_est: Estimated Lyapunov exponent (scalar) or None
            recordings: Dict with 'spikes' (B, T, H) and 'membrane' (B, T, H), or None
        """
        B, T, D = x.shape
        device, dtype = x.device, x.dtype

        # Initialize states
        mems = self._init_states(B, device, dtype)
        spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)

        # Recording setup
        if record_states:
            spike_rec = []
            mem_rec = []

        # Lyapunov setup
        if compute_lyapunov:
            if lyap_layers is None:
                lyap_layers = list(range(self.num_layers))
            # Perturbed trajectory: perturb all membrane potentials
            mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)

        # Time loop
        for t in range(T):
            x_t = x[:, t, :]

            # Nominal trajectory
            spk, mems, all_mems = self._step(x_t, mems, training=self.training)
            spike_sum = spike_sum + spk

            if record_states:
                spike_rec.append(spk.detach())
                mem_rec.append(all_mems[-1].detach())  # Last-layer membrane

            if compute_lyapunov:
                # Perturbed trajectory
                _, mems_p, all_mems_p = self._step(x_t, mems_p, training=False)

                # Compute divergence across the selected layers
                delta_sq = torch.zeros(B, device=device, dtype=dtype)
                for layer_idx in lyap_layers:
                    diff = all_mems_p[layer_idx] - all_mems[layer_idx]
                    delta_sq += (diff ** 2).sum(dim=1)
                delta = torch.sqrt(delta_sq + 1e-12)

                # Renormalization step (key for numerical stability):
                # rescale the perturbation back to a fixed magnitude, layer by layer
                for layer_idx in lyap_layers:
                    diff = mems_p[layer_idx] - mems[layer_idx]
                    # Normalize to maintain a fixed perturbation magnitude
                    norm = torch.norm(diff.reshape(B, -1), dim=1, keepdim=True) + 1e-12
                    diff_normalized = diff / norm.unsqueeze(-1) if diff.ndim > 2 else diff / norm
                    mems_p[layer_idx] = mems[layer_idx] + lyap_eps * diff_normalized

                # Accumulate the log-divergence
                lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)

        logits = self.readout(spike_sum)

        if compute_lyapunov:
            # Average over time and batch
            lyap_est = (lyap_accum / T).mean()
        else:
            lyap_est = None

        if record_states:
            recordings = {
                "spikes": torch.stack(spike_rec, dim=1),    # (B, T, H)
                "membrane": torch.stack(mem_rec, dim=1),    # (B, T, H)
            }
        else:
            recordings = None

        return logits, lyap_est, recordings

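# Example usage of LyapunovSNN (illustrative sketch only; the shapes and layer
# sizes below are assumptions, not values fixed by this module):
#
#     model = LyapunovSNN(input_dim=40, hidden_dims=[128, 64], num_classes=10)
#     x = torch.randn(16, 50, 40)  # (batch, timesteps, features)
#     logits, lyap_est, rec = model(x, compute_lyapunov=True, record_states=True)
#     # rec["spikes"] and rec["membrane"] have shape (batch, timesteps, hidden_dims[-1])
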
class RecurrentLyapunovSNN(nn.Module):
    """
    Recurrent SNN with Lyapunov exponent computation.

    Uses snnTorch's RSynaptic (recurrent synaptic) neurons for richer
    temporal dynamics.

    Args:
        input_dim: Input feature dimension
        hidden_dims: List of hidden layer sizes
        num_classes: Number of output classes
        alpha: Synaptic current decay rate
        beta: Membrane potential decay rate
        threshold: Firing threshold
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        num_classes: int,
        alpha: float = 0.9,
        beta: float = 0.85,
        threshold: float = 1.0,
        spike_grad: Optional[Any] = None,
    ):
        super().__init__()

        if spike_grad is None:
            spike_grad = surrogate.fast_sigmoid(slope=25)

        self.hidden_dims = hidden_dims
        self.num_layers = len(hidden_dims)
        self.alpha = alpha
        self.beta = beta

        # Build layers with recurrent synaptic neurons
        self.linears = nn.ModuleList()
        self.neurons = nn.ModuleList()

        dims = [input_dim] + hidden_dims
        for i in range(self.num_layers):
            self.linears.append(nn.Linear(dims[i], dims[i + 1]))
            self.neurons.append(
                snn.RSynaptic(
                    alpha=alpha,
                    beta=beta,
                    threshold=threshold,
                    spike_grad=spike_grad,
                    init_hidden=False,
                    reset_mechanism="subtract",
                    all_to_all=True,
                    linear_features=dims[i + 1],
                )
            )

        self.readout = nn.Linear(hidden_dims[-1], num_classes)
        self._init_weights()

    def _init_weights(self):
        for lin in self.linears:
            nn.init.xavier_uniform_(lin.weight)
            nn.init.zeros_(lin.bias)
        nn.init.xavier_uniform_(self.readout.weight)
        nn.init.zeros_(self.readout.bias)

    def _init_states(self, batch_size: int, device, dtype):
        """Initialize recurrent spikes, synaptic currents, and membrane potentials."""
        spks, syns, mems = [], [], []
        for dim in self.hidden_dims:
            spks.append(torch.zeros(batch_size, dim, device=device, dtype=dtype))
            syns.append(torch.zeros(batch_size, dim, device=device, dtype=dtype))
            mems.append(torch.zeros(batch_size, dim, device=device, dtype=dtype))
        return spks, syns, mems

    def forward(
        self,
        x: torch.Tensor,
        compute_lyapunov: bool = False,
        lyap_eps: float = 1e-4,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward pass with optional Lyapunov computation."""
        B, T, D = x.shape
        device, dtype = x.device, x.dtype

        spks, syns, mems = self._init_states(B, device, dtype)
        spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)

        if compute_lyapunov:
            # Perturb the synaptic currents and membrane potentials. Spikes are
            # binary, so the perturbed trajectory starts from the same zero spike
            # state and evolves its own spikes.
            spks_p = [torch.zeros_like(s) for s in spks]
            syns_p = [s + lyap_eps * torch.randn_like(s) for s in syns]
            mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
            lyap_accum = torch.zeros(B, device=device, dtype=dtype)

        for t in range(T):
            x_t = x[:, t, :]

            # Nominal trajectory
            h = x_t
            new_spks, new_syns, new_mems = [], [], []
            for i in range(self.num_layers):
                h = self.linears[i](h)
                # RSynaptic takes the previous output spike for its recurrent connection
                spk, syn, mem = self.neurons[i](h, spks[i], syns[i], mems[i])
                new_spks.append(spk)
                new_syns.append(syn)
                new_mems.append(mem)
                h = spk
            spks, syns, mems = new_spks, new_syns, new_mems
            spike_sum = spike_sum + h

            if compute_lyapunov:
                # Perturbed trajectory
                h_p = x_t
                new_spks_p, new_syns_p, new_mems_p = [], [], []
                for i in range(self.num_layers):
                    h_p = self.linears[i](h_p)
                    spk_p, syn_p, mem_p = self.neurons[i](h_p, spks_p[i], syns_p[i], mems_p[i])
                    new_spks_p.append(spk_p)
                    new_syns_p.append(syn_p)
                    new_mems_p.append(mem_p)
                    h_p = spk_p
                spks_p = new_spks_p

                # Compute divergence on synaptic currents and membrane potentials
                delta_sq = torch.zeros(B, device=device, dtype=dtype)
                for i in range(self.num_layers):
                    diff_m = new_mems_p[i] - new_mems[i]
                    diff_s = new_syns_p[i] - new_syns[i]
                    delta_sq += (diff_m ** 2).sum(dim=1) + (diff_s ** 2).sum(dim=1)
                delta = torch.sqrt(delta_sq + 1e-12)
                lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)

                # Renormalize the perturbation back to magnitude lyap_eps
                scale = lyap_eps / (delta.unsqueeze(-1) + 1e-12)
                syns_p = [new_syns[i] + scale * (new_syns_p[i] - new_syns[i])
                          for i in range(self.num_layers)]
                mems_p = [new_mems[i] + scale * (new_mems_p[i] - new_mems[i])
                          for i in range(self.num_layers)]

        logits = self.readout(spike_sum)

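        # Finite-time Lyapunov estimate: average the accumulated log-divergence
        # over the T timesteps and then over the batch. Positive values indicate
        # exponentially diverging (unstable) trajectories.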
        if compute_lyapunov:
            lyap_est = (lyap_accum / T).mean()
        else:
            lyap_est = None

        return logits, lyap_est


def create_snn(
    model_type: str,
    input_dim: int,
    hidden_dims: List[int],
    num_classes: int,
    **kwargs,
) -> nn.Module:
    """
    Factory function to create SNN models.

    Args:
        model_type: "feedforward" or "recurrent"
        input_dim: Input feature dimension
        hidden_dims: List of hidden layer sizes
        num_classes: Number of output classes
        **kwargs: Additional arguments passed to the model constructor

    Returns:
        SNN model instance
    """
    if model_type == "feedforward":
        return LyapunovSNN(input_dim, hidden_dims, num_classes, **kwargs)
    elif model_type == "recurrent":
        return RecurrentLyapunovSNN(input_dim, hidden_dims, num_classes, **kwargs)
    else:
        raise ValueError(
            f"Unknown model_type: {model_type}. Use 'feedforward' or 'recurrent'."
        )

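if __name__ == "__main__":
    # Minimal smoke test / usage sketch. The shapes, hyperparameters, and the
    # 0.1 regularization weight below are illustrative assumptions only.
    torch.manual_seed(0)

    model = create_snn(
        model_type="feedforward",
        input_dim=40,
        hidden_dims=[64, 32],
        num_classes=10,
    )
    x = torch.randn(8, 25, 40)            # (batch, timesteps, features)
    targets = torch.randint(0, 10, (8,))

    logits, lyap_est, _ = model(x, compute_lyapunov=True)

    # One possible regularized objective: penalize positive (unstable) exponents.
    loss = nn.functional.cross_entropy(logits, targets) + 0.1 * torch.relu(lyap_est)
    loss.backward()
    print(f"loss={loss.item():.4f}  lyapunov_estimate={lyap_est.item():.4f}")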