From cd99d6b874d9d09b3bb87b8485cc787885af71f1 Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Tue, 13 Jan 2026 23:49:05 -0600
Subject: init commit

---
 files/models/snn_snntorch.py | 398 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 398 insertions(+)
 create mode 100644 files/models/snn_snntorch.py

diff --git a/files/models/snn_snntorch.py b/files/models/snn_snntorch.py
new file mode 100644
index 0000000..71c1c18
--- /dev/null
+++ b/files/models/snn_snntorch.py
@@ -0,0 +1,398 @@
+"""
+snnTorch-based SNN with Lyapunov exponent regularization.
+
+This module provides deep SNN architectures using snnTorch with proper
+finite-time Lyapunov exponent computation for training stabilization.
+"""
+
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import snntorch as snn
+from snntorch import surrogate
+
+
+class LyapunovSNN(nn.Module):
+    """
+    Multi-layer SNN using snnTorch with Lyapunov exponent computation.
+
+    Architecture:
+        Input (B, T, D) -> [LIF layers] -> time-summed spikes -> Linear -> logits
+
+    Args:
+        input_dim: Input feature dimension
+        hidden_dims: List of hidden layer sizes (e.g., [256, 128] for 2 layers)
+        num_classes: Number of output classes
+        beta: Membrane potential decay factor (0 < beta < 1)
+        threshold: Firing threshold
+        spike_grad: Surrogate gradient function (default: fast_sigmoid)
+        dropout: Dropout probability between layers (0 = no dropout)
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dims: List[int],
+        num_classes: int,
+        beta: float = 0.9,
+        threshold: float = 1.0,
+        spike_grad: Optional[Any] = None,
+        dropout: float = 0.0,
+    ):
+        super().__init__()
+
+        if spike_grad is None:
+            spike_grad = surrogate.fast_sigmoid(slope=25)
+
+        self.hidden_dims = hidden_dims
+        self.num_layers = len(hidden_dims)
+        self.beta = beta
+        self.threshold = threshold
+
+        # Build layers
+        self.linears = nn.ModuleList()
+        self.lifs = nn.ModuleList()
+        self.dropouts = nn.ModuleList() if dropout > 0 else None
+
+        dims = [input_dim] + hidden_dims
+        for i in range(self.num_layers):
+            self.linears.append(nn.Linear(dims[i], dims[i + 1]))
+            self.lifs.append(
+                snn.Leaky(
+                    beta=beta,
+                    threshold=threshold,
+                    spike_grad=spike_grad,
+                    init_hidden=False,
+                    reset_mechanism="subtract",
+                )
+            )
+            if dropout > 0:
+                self.dropouts.append(nn.Dropout(p=dropout))
+
+        # Readout layer
+        self.readout = nn.Linear(hidden_dims[-1], num_classes)
+
+        # Initialize weights
+        self._init_weights()
+
+    def _init_weights(self):
+        for lin in self.linears:
+            nn.init.xavier_uniform_(lin.weight)
+            nn.init.zeros_(lin.bias)
+        nn.init.xavier_uniform_(self.readout.weight)
+        nn.init.zeros_(self.readout.bias)
+
+    def _init_states(self, batch_size: int, device, dtype) -> List[torch.Tensor]:
+        """Initialize membrane potentials for all layers."""
+        mems = []
+        for dim in self.hidden_dims:
+            mems.append(torch.zeros(batch_size, dim, device=device, dtype=dtype))
+        return mems
+
+    def _step(
+        self,
+        x_t: torch.Tensor,
+        mems: List[torch.Tensor],
+        training: bool = True,
+    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
+        """
+        Single timestep forward pass.
+
+        Returns:
+            spike_out: Output spikes from the last layer (B, H_last)
+            new_mems: Updated membrane potentials
+            all_mems: Membrane potentials from all layers (for Lyapunov)
+        """
+        new_mems = []
+        all_mems = []
+
+        h = x_t
+        for i in range(self.num_layers):
+            h = self.linears[i](h)
+            spk, mem = self.lifs[i](h, mems[i])
+            new_mems.append(mem)
+            all_mems.append(mem)
+            h = spk
+            if self.dropouts is not None and training:
+                h = self.dropouts[i](h)
+
+        return h, new_mems, all_mems
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        compute_lyapunov: bool = False,
+        lyap_eps: float = 1e-4,
+        lyap_layers: Optional[List[int]] = None,
+        record_states: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict[str, torch.Tensor]]]:
+        """
+        Forward pass with optional Lyapunov exponent computation.
+
+        Args:
+            x: Input tensor (B, T, D)
+            compute_lyapunov: Whether to compute the Lyapunov exponent
+            lyap_eps: Perturbation magnitude for Lyapunov computation
+            lyap_layers: Which layers to measure (default: all),
+                e.g. [0] for the first layer only, [-1] for the last layer
+            record_states: Whether to record spikes and membrane potentials
+
+        Returns:
+            logits: Classification logits (B, num_classes)
+            lyap_est: Estimated Lyapunov exponent (scalar) or None
+            recordings: Dict with 'spikes' (B, T, H) and 'membrane' (B, T, H), or None
+        """
+        B, T, D = x.shape
+        device, dtype = x.device, x.dtype
+
+        # Initialize states
+        mems = self._init_states(B, device, dtype)
+        spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+        # Recording setup
+        if record_states:
+            spike_rec = []
+            mem_rec = []
+
+        # Lyapunov setup
+        if compute_lyapunov:
+            if lyap_layers is None:
+                lyap_layers = list(range(self.num_layers))
+
+            # Perturbed trajectory - perturb all membrane potentials
+            mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+            lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+        # Time loop
+        for t in range(T):
+            x_t = x[:, t, :]
+
+            # Nominal trajectory
+            spk, mems, all_mems = self._step(x_t, mems, training=self.training)
+            spike_sum = spike_sum + spk
+
+            if record_states:
+                spike_rec.append(spk.detach())
+                mem_rec.append(all_mems[-1].detach())  # Last layer membrane
+
+            if compute_lyapunov:
+                # Perturbed trajectory (dropout disabled; during training the two
+                # trajectories therefore also differ by the nominal dropout mask)
+                _, mems_p, all_mems_p = self._step(x_t, mems_p, training=False)
+
+                # Compute divergence across the selected layers
+                delta_sq = torch.zeros(B, device=device, dtype=dtype)
+                for layer_idx in lyap_layers:
+                    diff = all_mems_p[layer_idx] - all_mems[layer_idx]
+                    delta_sq += (diff ** 2).sum(dim=1)
+
+                delta = torch.sqrt(delta_sq + 1e-12)
+
+                # Accumulate log-divergence
+                lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+                # Renormalization step (key for numerical stability): rescale the
+                # joint perturbation over the measured layers back to magnitude
+                # lyap_eps, so the next step's divergence is again measured
+                # relative to lyap_eps.
+                scale = (lyap_eps / delta).unsqueeze(-1)  # (B, 1)
+                for layer_idx in lyap_layers:
+                    diff = all_mems_p[layer_idx] - all_mems[layer_idx]
+                    mems_p[layer_idx] = all_mems[layer_idx] + scale * diff
+
+        logits = self.readout(spike_sum)
+
+        if compute_lyapunov:
+            # Average over time, then over the batch
+            lyap_est = (lyap_accum / T).mean()
+        else:
+            lyap_est = None
+
+        if record_states:
+            recordings = {
+                "spikes": torch.stack(spike_rec, dim=1),    # (B, T, H)
+                "membrane": torch.stack(mem_rec, dim=1),    # (B, T, H)
+            }
+        else:
+            recordings = None
+
+        return logits, lyap_est, recordings
+
+
+class RecurrentLyapunovSNN(nn.Module):
+    """
+    Recurrent SNN with Lyapunov exponent computation.
+
+    Uses snnTorch's RSynaptic (recurrent synaptic) neurons for
+    richer temporal dynamics.
+
+    Args:
+        input_dim: Input feature dimension
+        hidden_dims: List of hidden layer sizes
+        num_classes: Number of output classes
+        alpha: Synaptic current decay rate
+        beta: Membrane potential decay rate
+        threshold: Firing threshold
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dims: List[int],
+        num_classes: int,
+        alpha: float = 0.9,
+        beta: float = 0.85,
+        threshold: float = 1.0,
+        spike_grad: Optional[Any] = None,
+    ):
+        super().__init__()
+
+        if spike_grad is None:
+            spike_grad = surrogate.fast_sigmoid(slope=25)
+
+        self.hidden_dims = hidden_dims
+        self.num_layers = len(hidden_dims)
+        self.alpha = alpha
+        self.beta = beta
+
+        # Build layers with recurrent synaptic neurons
+        self.linears = nn.ModuleList()
+        self.neurons = nn.ModuleList()
+
+        dims = [input_dim] + hidden_dims
+        for i in range(self.num_layers):
+            self.linears.append(nn.Linear(dims[i], dims[i + 1]))
+            self.neurons.append(
+                snn.RSynaptic(
+                    alpha=alpha,
+                    beta=beta,
+                    threshold=threshold,
+                    spike_grad=spike_grad,
+                    init_hidden=False,
+                    reset_mechanism="subtract",
+                    all_to_all=True,
+                    linear_features=dims[i + 1],
+                )
+            )
+
+        self.readout = nn.Linear(hidden_dims[-1], num_classes)
+        self._init_weights()
+
+    def _init_weights(self):
+        for lin in self.linears:
+            nn.init.xavier_uniform_(lin.weight)
+            nn.init.zeros_(lin.bias)
+        nn.init.xavier_uniform_(self.readout.weight)
+        nn.init.zeros_(self.readout.bias)
+
+    def _init_states(self, batch_size: int, device, dtype):
+        """Initialize spikes, synaptic currents and membrane potentials."""
+        spks, syns, mems = [], [], []
+        for dim in self.hidden_dims:
+            spks.append(torch.zeros(batch_size, dim, device=device, dtype=dtype))
+            syns.append(torch.zeros(batch_size, dim, device=device, dtype=dtype))
+            mems.append(torch.zeros(batch_size, dim, device=device, dtype=dtype))
+        return spks, syns, mems
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        compute_lyapunov: bool = False,
+        lyap_eps: float = 1e-4,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Forward pass with optional Lyapunov computation."""
+        B, T, D = x.shape
+        device, dtype = x.device, x.dtype
+
+        spks, syns, mems = self._init_states(B, device, dtype)
+        spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+        if compute_lyapunov:
+            # Perturb synaptic currents and membrane potentials; the recurrent
+            # spike state starts at zero for both trajectories (spikes are binary)
+            spks_p = [torch.zeros_like(s) for s in spks]
+            syns_p = [s + lyap_eps * torch.randn_like(s) for s in syns]
+            mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+            lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+        for t in range(T):
+            x_t = x[:, t, :]
+
+            # Nominal trajectory
+            h = x_t
+            new_spks, new_syns, new_mems = [], [], []
+            for i in range(self.num_layers):
+                h = self.linears[i](h)
+                # RSynaptic takes (input, spk, syn, mem) and returns (spk, syn, mem);
+                # the previous spike is part of its recurrent state
+                spk, syn, mem = self.neurons[i](h, spks[i], syns[i], mems[i])
+                new_spks.append(spk)
+                new_syns.append(syn)
+                new_mems.append(mem)
+                h = spk
+            spks, syns, mems = new_spks, new_syns, new_mems
+            spike_sum = spike_sum + h
+
+            if compute_lyapunov:
+                # Perturbed trajectory
+                h_p = x_t
+                new_spks_p, new_syns_p, new_mems_p = [], [], []
+                for i in range(self.num_layers):
+                    h_p = self.linears[i](h_p)
+                    spk_p, syn_p, mem_p = self.neurons[i](h_p, spks_p[i], syns_p[i], mems_p[i])
+                    new_spks_p.append(spk_p)
+                    new_syns_p.append(syn_p)
+                    new_mems_p.append(mem_p)
+                    h_p = spk_p
+                spks_p = new_spks_p
+
+                # Compute divergence on synaptic currents and membrane potentials
+                delta_sq = torch.zeros(B, device=device, dtype=dtype)
+                for i in range(self.num_layers):
+                    diff_m = new_mems_p[i] - new_mems[i]
+                    diff_s = new_syns_p[i] - new_syns[i]
+                    delta_sq += (diff_m ** 2).sum(dim=1) + (diff_s ** 2).sum(dim=1)
+
+                delta = torch.sqrt(delta_sq + 1e-12)
+                lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+                # Renormalize the perturbation back to magnitude lyap_eps
+                scale = lyap_eps / (delta.unsqueeze(-1) + 1e-12)
+                syns_p = [new_syns[i] + scale * (new_syns_p[i] - new_syns[i])
+                          for i in range(self.num_layers)]
+                mems_p = [new_mems[i] + scale * (new_mems_p[i] - new_mems[i])
+                          for i in range(self.num_layers)]
+
+        logits = self.readout(spike_sum)
+
+        if compute_lyapunov:
+            lyap_est = (lyap_accum / T).mean()
+        else:
+            lyap_est = None
+
+        return logits, lyap_est
+
+
+def create_snn(
+    model_type: str,
+    input_dim: int,
+    hidden_dims: List[int],
+    num_classes: int,
+    **kwargs,
+) -> nn.Module:
+    """
+    Factory function to create SNN models.
+
+    Args:
+        model_type: "feedforward" or "recurrent"
+        input_dim: Input feature dimension
+        hidden_dims: List of hidden layer sizes
+        num_classes: Number of output classes
+        **kwargs: Additional arguments passed to the model constructor
+
+    Returns:
+        SNN model instance
+    """
+    if model_type == "feedforward":
+        return LyapunovSNN(input_dim, hidden_dims, num_classes, **kwargs)
+    elif model_type == "recurrent":
+        return RecurrentLyapunovSNN(input_dim, hidden_dims, num_classes, **kwargs)
+    else:
+        raise ValueError(f"Unknown model_type: {model_type}. Use 'feedforward' or 'recurrent'.")
--
cgit v1.2.3
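
Not part of the patch: a minimal usage sketch of the module above. The import path
(files.models.snn_snntorch), the batch/time/feature sizes, the 0.1 weight, and the hinge
on the exponent are illustrative assumptions; the model only returns lyap_est and does
not prescribe how to fold it into the loss.

    import torch
    import torch.nn.functional as F

    from files.models.snn_snntorch import create_snn

    B, T, D, C = 32, 50, 700, 20           # hypothetical batch/time/feature/class sizes
    model = create_snn("feedforward", input_dim=D, hidden_dims=[256, 128], num_classes=C)

    x = torch.randn(B, T, D)                # (batch, time, features)
    y = torch.randint(0, C, (B,))

    logits, lyap_est, _ = model(x, compute_lyapunov=True)
    loss = F.cross_entropy(logits, y)
    if lyap_est is not None:
        # One possible regularizer: penalize positive (expanding) exponents
        loss = loss + 0.1 * torch.relu(lyap_est)
    loss.backward()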
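
Also not part of the patch: the Lyapunov bookkeeping in both forward passes is essentially
a Benettin-style scheme (evolve a nominal and a perturbed trajectory, accumulate the
log-divergence, renormalize the perturbation back to eps each step). A self-contained
sketch of the same scheme on the logistic map, illustrative only, may make the per-step
logic easier to follow; for r = 4 the estimate should approach ln 2 ≈ 0.693.

    import math

    def ftle_logistic(x0: float, eps: float = 1e-6, steps: int = 2000, r: float = 4.0) -> float:
        """Benettin-style finite-time Lyapunov estimate for the logistic map."""
        x, x_p = x0, x0 + eps                  # nominal and perturbed trajectories
        acc = 0.0
        for _ in range(steps):
            x = r * x * (1.0 - x)              # advance the nominal state
            x_p = r * x_p * (1.0 - x_p)        # advance the perturbed state
            delta = abs(x_p - x) + 1e-18
            acc += math.log(delta / eps)       # accumulate log-divergence
            x_p = x + eps * (x_p - x) / delta  # renormalize perturbation back to eps
        return acc / steps

    print(ftle_logistic(0.2))                  # ~0.69 for r = 4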