author    YurenHao0426 <blackhao0426@gmail.com>    2026-01-13 23:49:05 -0600
committer YurenHao0426 <blackhao0426@gmail.com>    2026-01-13 23:49:05 -0600
commit    cd99d6b874d9d09b3bb87b8485cc787885af71f1 (patch)
tree      59a233959932ca0e4f12f196275e07fcf443b33f /files/models/snn_snntorch.py
init commit
Diffstat (limited to 'files/models/snn_snntorch.py')
-rw-r--r--    files/models/snn_snntorch.py    398
1 file changed, 398 insertions, 0 deletions
diff --git a/files/models/snn_snntorch.py b/files/models/snn_snntorch.py
new file mode 100644
index 0000000..71c1c18
--- /dev/null
+++ b/files/models/snn_snntorch.py
@@ -0,0 +1,398 @@
+"""
+snnTorch-based SNN with Lyapunov exponent regularization.
+
+This module provides deep SNN architectures using snnTorch with proper
+finite-time Lyapunov exponent computation for training stabilization.
+"""
+
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import snntorch as snn
+from snntorch import surrogate
+
+
+class LyapunovSNN(nn.Module):
+ """
+ Multi-layer SNN using snnTorch with Lyapunov exponent computation.
+
+ Architecture:
+ Input (B, T, D) -> [LIF layers] -> time-summed spikes -> Linear -> logits
+
+ Args:
+ input_dim: Input feature dimension
+ hidden_dims: List of hidden layer sizes (e.g., [256, 128] for 2 layers)
+ num_classes: Number of output classes
+ beta: Membrane potential decay factor (0 < beta < 1)
+ threshold: Firing threshold
+ spike_grad: Surrogate gradient function (default: fast_sigmoid)
+ dropout: Dropout probability between layers (0 = no dropout)
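+
+    Example (illustrative; the sizes below are arbitrary choices, not values
+    fixed by this module):
+
+        >>> model = LyapunovSNN(input_dim=700, hidden_dims=[256, 128], num_classes=20)
+        >>> x = torch.randn(32, 100, 700)  # (B, T, D)
+        >>> logits, lyap_est, _ = model(x, compute_lyapunov=True)
+        >>> logits.shape
+        torch.Size([32, 20])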
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dims: List[int],
+ num_classes: int,
+ beta: float = 0.9,
+ threshold: float = 1.0,
+ spike_grad: Optional[Any] = None,
+ dropout: float = 0.0,
+ ):
+ super().__init__()
+
+ if spike_grad is None:
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.hidden_dims = hidden_dims
+ self.num_layers = len(hidden_dims)
+ self.beta = beta
+ self.threshold = threshold
+
+ # Build layers
+ self.linears = nn.ModuleList()
+ self.lifs = nn.ModuleList()
+ self.dropouts = nn.ModuleList() if dropout > 0 else None
+
+ dims = [input_dim] + hidden_dims
+ for i in range(self.num_layers):
+ self.linears.append(nn.Linear(dims[i], dims[i + 1]))
+ self.lifs.append(
+ snn.Leaky(
+ beta=beta,
+ threshold=threshold,
+ spike_grad=spike_grad,
+ init_hidden=False,
+ reset_mechanism="subtract",
+ )
+ )
+ if dropout > 0:
+ self.dropouts.append(nn.Dropout(p=dropout))
+
+ # Readout layer
+ self.readout = nn.Linear(hidden_dims[-1], num_classes)
+
+ # Initialize weights
+ self._init_weights()
+
+ def _init_weights(self):
+ for lin in self.linears:
+ nn.init.xavier_uniform_(lin.weight)
+ nn.init.zeros_(lin.bias)
+ nn.init.xavier_uniform_(self.readout.weight)
+ nn.init.zeros_(self.readout.bias)
+
+ def _init_states(self, batch_size: int, device, dtype) -> List[torch.Tensor]:
+ """Initialize membrane potentials for all layers."""
+ mems = []
+ for dim in self.hidden_dims:
+ mems.append(torch.zeros(batch_size, dim, device=device, dtype=dtype))
+ return mems
+
+ def _step(
+ self,
+ x_t: torch.Tensor,
+ mems: List[torch.Tensor],
+ training: bool = True,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
+ """
+ Single timestep forward pass.
+
+ Returns:
+ spike_out: Output spikes from last layer (B, H_last)
+ new_mems: Updated membrane potentials
+ all_mems: Membrane potentials from all layers (for Lyapunov)
+ """
+ new_mems = []
+ all_mems = []
+
+ h = x_t
+ for i in range(self.num_layers):
+ h = self.linears[i](h)
+ spk, mem = self.lifs[i](h, mems[i])
+ new_mems.append(mem)
+ all_mems.append(mem)
+ h = spk
+ if self.dropouts is not None and training:
+ h = self.dropouts[i](h)
+
+ return h, new_mems, all_mems
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ compute_lyapunov: bool = False,
+ lyap_eps: float = 1e-4,
+ lyap_layers: Optional[List[int]] = None,
+ record_states: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict[str, torch.Tensor]]]:
+ """
+ Forward pass with optional Lyapunov exponent computation.
+
+ Args:
+ x: Input tensor (B, T, D)
+ compute_lyapunov: Whether to compute Lyapunov exponent
+ lyap_eps: Perturbation magnitude for Lyapunov computation
+ lyap_layers: Which layers to measure (default: all).
+ e.g., [0] for first layer only, [-1] for last layer
+ record_states: Whether to record spikes and membrane potentials
+
+ Returns:
+ logits: Classification logits (B, num_classes)
+ lyap_est: Estimated Lyapunov exponent (scalar) or None
+            recordings: Dict with last-layer 'spikes' (B, T, H_last) and
+                'membrane' (B, T, H_last), or None
+ """
+ B, T, D = x.shape
+ device, dtype = x.device, x.dtype
+
+ # Initialize states
+ mems = self._init_states(B, device, dtype)
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+ # Recording setup
+ if record_states:
+ spike_rec = []
+ mem_rec = []
+
+ # Lyapunov setup
+ if compute_lyapunov:
+ if lyap_layers is None:
+ lyap_layers = list(range(self.num_layers))
+
+ # Perturbed trajectory - perturb all membrane potentials
+ mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
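+        # Finite-time Lyapunov estimate via Benettin-style renormalization:
+        # a perturbed copy of the membrane state evolves alongside the nominal
+        # trajectory; each step accumulates log(||delta_t|| / eps) and the
+        # perturbation is rescaled back to magnitude eps. The returned value
+        # is the per-step average:
+        #     lambda ~ (1/T) * sum_t log(||delta_t|| / eps)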
+ # Time loop
+ for t in range(T):
+ x_t = x[:, t, :]
+
+ # Nominal trajectory
+ spk, mems, all_mems = self._step(x_t, mems, training=self.training)
+ spike_sum = spike_sum + spk
+
+ if record_states:
+ spike_rec.append(spk.detach())
+ mem_rec.append(all_mems[-1].detach()) # Last layer membrane
+
+ if compute_lyapunov:
+ # Perturbed trajectory
+ _, mems_p, all_mems_p = self._step(x_t, mems_p, training=False)
+
+ # Compute divergence across selected layers
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+
+ for layer_idx in lyap_layers:
+ diff = all_mems_p[layer_idx] - all_mems[layer_idx]
+ delta_sq += (diff ** 2).sum(dim=1)
+
+ delta = torch.sqrt(delta_sq + 1e-12)
+
+                # Renormalization step (key for numerical stability):
+                # rescale the joint perturbation across the selected layers
+                # back to magnitude lyap_eps so that, at the next step,
+                # log(delta / lyap_eps) again measures a single step of growth
+                scale = (lyap_eps / delta).unsqueeze(-1)
+                for layer_idx in lyap_layers:
+                    diff = mems_p[layer_idx] - mems[layer_idx]
+                    mems_p[layer_idx] = mems[layer_idx] + scale * diff
+
+ # Accumulate log-divergence
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+ logits = self.readout(spike_sum)
+
+ if compute_lyapunov:
+ # Average over time and batch
+ lyap_est = (lyap_accum / T).mean()
+ else:
+ lyap_est = None
+
+ if record_states:
+ recordings = {
+ "spikes": torch.stack(spike_rec, dim=1), # (B, T, H)
+ "membrane": torch.stack(mem_rec, dim=1), # (B, T, H)
+ }
+ else:
+ recordings = None
+
+ return logits, lyap_est, recordings
+
+
+class RecurrentLyapunovSNN(nn.Module):
+ """
+ Recurrent SNN with Lyapunov exponent computation.
+
+ Uses snnTorch's RSynaptic (recurrent synaptic) neurons for
+ richer temporal dynamics.
+
+ Args:
+ input_dim: Input feature dimension
+ hidden_dims: List of hidden layer sizes
+ num_classes: Number of output classes
+ alpha: Synaptic current decay rate
+ beta: Membrane potential decay rate
+ threshold: Firing threshold
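+        spike_grad: Surrogate gradient function (default: fast_sigmoid)
+
+    Example (illustrative; the sizes below are arbitrary choices, not values
+    fixed by this module):
+
+        >>> model = RecurrentLyapunovSNN(input_dim=700, hidden_dims=[256], num_classes=20)
+        >>> x = torch.randn(32, 100, 700)  # (B, T, D)
+        >>> logits, lyap_est = model(x, compute_lyapunov=True)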
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dims: List[int],
+ num_classes: int,
+ alpha: float = 0.9,
+ beta: float = 0.85,
+ threshold: float = 1.0,
+ spike_grad: Optional[Any] = None,
+ ):
+ super().__init__()
+
+ if spike_grad is None:
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.hidden_dims = hidden_dims
+ self.num_layers = len(hidden_dims)
+ self.alpha = alpha
+ self.beta = beta
+
+ # Build layers with recurrent synaptic neurons
+ self.linears = nn.ModuleList()
+ self.neurons = nn.ModuleList()
+
+ dims = [input_dim] + hidden_dims
+ for i in range(self.num_layers):
+ self.linears.append(nn.Linear(dims[i], dims[i + 1]))
+ self.neurons.append(
+ snn.RSynaptic(
+ alpha=alpha,
+ beta=beta,
+ threshold=threshold,
+ spike_grad=spike_grad,
+ init_hidden=False,
+ reset_mechanism="subtract",
+ all_to_all=True,
+ linear_features=dims[i + 1],
+ )
+ )
+
+ self.readout = nn.Linear(hidden_dims[-1], num_classes)
+ self._init_weights()
+
+ def _init_weights(self):
+ for lin in self.linears:
+ nn.init.xavier_uniform_(lin.weight)
+ nn.init.zeros_(lin.bias)
+ nn.init.xavier_uniform_(self.readout.weight)
+ nn.init.zeros_(self.readout.bias)
+
+    def _init_states(self, batch_size: int, device, dtype):
+        """Initialize recurrent spikes, synaptic currents, and membrane potentials."""
+        spks, syns, mems = [], [], []
+        for dim in self.hidden_dims:
+            # RSynaptic feeds the previous spikes back through its recurrent
+            # weights, so the spike state must be carried across timesteps.
+            spks.append(torch.zeros(batch_size, dim, device=device, dtype=dtype))
+            syns.append(torch.zeros(batch_size, dim, device=device, dtype=dtype))
+            mems.append(torch.zeros(batch_size, dim, device=device, dtype=dtype))
+        return spks, syns, mems
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ compute_lyapunov: bool = False,
+ lyap_eps: float = 1e-4,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Forward pass with optional Lyapunov computation."""
+ B, T, D = x.shape
+ device, dtype = x.device, x.dtype
+
+        spks, syns, mems = self._init_states(B, device, dtype)
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+        if compute_lyapunov:
+            # Perturb synaptic currents and membrane potentials; the spike
+            # state starts identical and diverges through the dynamics
+            spks_p = [s.clone() for s in spks]
+            syns_p = [s + lyap_eps * torch.randn_like(s) for s in syns]
+            mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+            lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ for t in range(T):
+ x_t = x[:, t, :]
+
+            # Nominal trajectory
+            h = x_t
+            new_spks, new_syns, new_mems = [], [], []
+            for i in range(self.num_layers):
+                h = self.linears[i](h)
+                # RSynaptic expects (input, spk, syn, mem) and returns the same triple
+                spk, syn, mem = self.neurons[i](h, spks[i], syns[i], mems[i])
+                new_spks.append(spk)
+                new_syns.append(syn)
+                new_mems.append(mem)
+                h = spk
+            spks, syns, mems = new_spks, new_syns, new_mems
+            spike_sum = spike_sum + h
+
+            if compute_lyapunov:
+                # Perturbed trajectory
+                h_p = x_t
+                new_spks_p, new_syns_p, new_mems_p = [], [], []
+                for i in range(self.num_layers):
+                    h_p = self.linears[i](h_p)
+                    spk_p, syn_p, mem_p = self.neurons[i](h_p, spks_p[i], syns_p[i], mems_p[i])
+                    new_spks_p.append(spk_p)
+                    new_syns_p.append(syn_p)
+                    new_mems_p.append(mem_p)
+                    h_p = spk_p
+                spks_p = new_spks_p
+
+                # Compute divergence over synaptic currents and membrane potentials
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(self.num_layers):
+ diff_m = new_mems_p[i] - new_mems[i]
+ diff_s = new_syns_p[i] - new_syns[i]
+ delta_sq += (diff_m ** 2).sum(dim=1) + (diff_s ** 2).sum(dim=1)
+
+ delta = torch.sqrt(delta_sq + 1e-12)
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+                # Renormalize the joint perturbation back to magnitude lyap_eps
+                scale = lyap_eps / (delta.unsqueeze(-1) + 1e-12)
+
+ syns_p = [new_syns[i] + scale * (new_syns_p[i] - new_syns[i])
+ for i in range(self.num_layers)]
+ mems_p = [new_mems[i] + scale * (new_mems_p[i] - new_mems[i])
+ for i in range(self.num_layers)]
+
+ logits = self.readout(spike_sum)
+
+ if compute_lyapunov:
+ lyap_est = (lyap_accum / T).mean()
+ else:
+ lyap_est = None
+
+ return logits, lyap_est
+
+
+def create_snn(
+ model_type: str,
+ input_dim: int,
+ hidden_dims: List[int],
+ num_classes: int,
+ **kwargs,
+) -> nn.Module:
+ """
+ Factory function to create SNN models.
+
+ Args:
+ model_type: "feedforward" or "recurrent"
+ input_dim: Input feature dimension
+ hidden_dims: List of hidden layer sizes
+ num_classes: Number of output classes
+ **kwargs: Additional arguments passed to model constructor
+
+ Returns:
+ SNN model instance
+ """
+ if model_type == "feedforward":
+ return LyapunovSNN(input_dim, hidden_dims, num_classes, **kwargs)
+ elif model_type == "recurrent":
+ return RecurrentLyapunovSNN(input_dim, hidden_dims, num_classes, **kwargs)
+ else:
+ raise ValueError(f"Unknown model_type: {model_type}. Use 'feedforward' or 'recurrent'")
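+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative; the shapes and hyperparameters below
+    # are arbitrary choices, not values prescribed by this module)
+    B, T, D, C = 4, 20, 64, 10
+    x = torch.randn(B, T, D)
+    for kind in ("feedforward", "recurrent"):
+        model = create_snn(kind, input_dim=D, hidden_dims=[128, 64], num_classes=C)
+        out = model(x, compute_lyapunov=True)
+        logits, lyap_est = out[0], out[1]
+        print(kind, tuple(logits.shape), float(lyap_est))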