summaryrefslogtreecommitdiff
path: root/files/models
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-13 23:49:05 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-13 23:49:05 -0600
commitcd99d6b874d9d09b3bb87b8485cc787885af71f1 (patch)
tree59a233959932ca0e4f12f196275e07fcf443b33f /files/models
init commit
Diffstat (limited to 'files/models')
-rw-r--r--files/models/conv_snn.py483
-rw-r--r--files/models/snn.py141
-rw-r--r--files/models/snn_snntorch.py398
3 files changed, 1022 insertions, 0 deletions
diff --git a/files/models/conv_snn.py b/files/models/conv_snn.py
new file mode 100644
index 0000000..69f77d6
--- /dev/null
+++ b/files/models/conv_snn.py
@@ -0,0 +1,483 @@
+"""
+Convolutional SNN with Lyapunov regularization for image classification.
+
+Properly handles spatial structure:
+- Input: (B, C, H, W) static image OR (B, T, C, H, W) spike tensor
+- Uses Conv-LIF layers to preserve spatial hierarchy
+- Rate encoding converts images to spike trains
+
+Based on standard SNN vision practices:
+- Rate/Poisson encoding for input
+- Conv → BatchNorm → LIF → Pool architecture
+- Time comes from encoding + LIF dynamics, not flattening
+"""
+
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import snntorch as snn
+from snntorch import surrogate
+
+
+class RateEncoder(nn.Module):
+ """
+ Rate (Poisson/Bernoulli) encoder for static images.
+
+ Converts intensity x ∈ [0,1] to spike probability per timestep.
+ Each pixel independently fires with P(spike) = x * gain.
+
+ Args:
+ T: Number of timesteps
+ gain: Scaling factor for firing probability (default 1.0)
+
+ Input: (B, C, H, W) normalized image in [0, 1]
+ Output: (B, T, C, H, W) binary spike tensor
+ """
+
+ def __init__(self, T: int = 25, gain: float = 1.0):
+ super().__init__()
+ self.T = T
+ self.gain = gain
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x: (B, C, H, W) image tensor, values in [0, 1]
+ Returns:
+ spikes: (B, T, C, H, W) binary spike tensor
+ """
+ # Clamp to valid probability range
+ prob = (x * self.gain).clamp(0, 1)
+
+ # Expand for T timesteps: (B, C, H, W) -> (B, T, C, H, W)
+ prob = prob.unsqueeze(1).expand(-1, self.T, -1, -1, -1)
+
+ # Sample spikes
+ spikes = torch.bernoulli(prob)
+
+ return spikes
+
+
+class DirectEncoder(nn.Module):
+ """
+ Direct encoding - feed static image as constant current.
+
+ Common in surrogate gradient papers: no spike encoding at input,
+ let spiking emerge from the network dynamics.
+
+ Input: (B, C, H, W) image
+ Output: (B, T, C, H, W) repeated image (as analog current)
+ """
+
+ def __init__(self, T: int = 25):
+ super().__init__()
+ self.T = T
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Simply repeat across time
+ return x.unsqueeze(1).expand(-1, self.T, -1, -1, -1)
+
+
+class ConvLIFBlock(nn.Module):
+ """
+ Conv → BatchNorm → LIF block.
+
+ Maintains spatial structure while adding spiking dynamics.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 3,
+ stride: int = 1,
+ padding: int = 1,
+ beta: float = 0.9,
+ threshold: float = 1.0,
+ spike_grad=None,
+ ):
+ super().__init__()
+
+ if spike_grad is None:
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.conv = nn.Conv2d(
+ in_channels, out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ bias=False,
+ )
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.lif = snn.Leaky(
+ beta=beta,
+ threshold=threshold,
+ spike_grad=spike_grad,
+ init_hidden=False,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mem: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ x: (B, C_in, H, W) input (spikes or current)
+ mem: (B, C_out, H', W') membrane potential
+ Returns:
+ spk: (B, C_out, H', W') output spikes
+ mem: (B, C_out, H', W') updated membrane
+ """
+ cur = self.bn(self.conv(x))
+ spk, mem = self.lif(cur, mem)
+ return spk, mem
+
+
+class ConvLyapunovSNN(nn.Module):
+ """
+ Convolutional SNN with Lyapunov exponent regularization.
+
+ Architecture for CIFAR-10 (32x32x3):
+ Input → Encoder → [Conv-LIF-Pool] × N → FC → Output
+
+ Properly preserves spatial structure for hierarchical feature learning.
+
+ Args:
+ in_channels: Input channels (3 for RGB)
+ num_classes: Output classes
+ channels: List of channel sizes for conv layers
+ T: Number of timesteps
+ beta: LIF membrane decay
+ threshold: LIF firing threshold
+ encoding: 'rate', 'direct', or 'none' (pre-encoded input)
+ encoding_gain: Gain for rate encoding
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ num_classes: int = 10,
+ channels: List[int] = [64, 128, 256],
+ T: int = 25,
+ beta: float = 0.9,
+ threshold: float = 1.0,
+ encoding: str = 'rate',
+ encoding_gain: float = 1.0,
+ dropout: float = 0.2,
+ ):
+ super().__init__()
+
+ self.T = T
+ self.encoding_type = encoding
+ self.channels = channels
+ self.num_layers = len(channels)
+
+ # Input encoder
+ if encoding == 'rate':
+ self.encoder = RateEncoder(T=T, gain=encoding_gain)
+ elif encoding == 'direct':
+ self.encoder = DirectEncoder(T=T)
+ else:
+ self.encoder = None # Expect pre-encoded (B, T, C, H, W) input
+
+ # Build conv-LIF layers
+ self.blocks = nn.ModuleList()
+ self.pools = nn.ModuleList()
+
+ ch_in = in_channels
+ for ch_out in channels:
+ self.blocks.append(
+ ConvLIFBlock(ch_in, ch_out, beta=beta, threshold=threshold)
+ )
+ self.pools.append(nn.AvgPool2d(2))
+ ch_in = ch_out
+
+ # Calculate output spatial size after pooling
+ # CIFAR: 32 -> 16 -> 8 -> 4 (for 3 layers)
+ spatial_size = 32 // (2 ** len(channels))
+ fc_input = channels[-1] * spatial_size * spatial_size
+
+ # Fully connected readout
+ self.dropout = nn.Dropout(dropout)
+ self.fc = nn.Linear(fc_input, num_classes)
+
+ self._init_weights()
+
+ def _init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def _init_mem(self, batch_size: int, device, dtype) -> List[torch.Tensor]:
+ """Initialize membrane potentials for all layers."""
+ mems = []
+ H, W = 32, 32
+ for i, ch in enumerate(self.channels):
+ H, W = H // 2, W // 2 # After pooling
+ # Actually we need size BEFORE pooling for LIF
+ H_pre, W_pre = H * 2, W * 2
+ mems.append(torch.zeros(batch_size, ch, H_pre, W_pre, device=device, dtype=dtype))
+ return mems
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ compute_lyapunov: bool = False,
+ lyap_eps: float = 1e-4,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]:
+ """
+ Forward pass with optional Lyapunov computation.
+
+ Args:
+ x: Input tensor
+ - If encoder: (B, C, H, W) static image
+ - If no encoder: (B, T, C, H, W) pre-encoded spikes
+ compute_lyapunov: Whether to compute Lyapunov exponent
+ lyap_eps: Perturbation magnitude
+
+ Returns:
+ logits: (B, num_classes)
+ lyap_est: Scalar Lyapunov estimate or None
+ recordings: Optional dict with spike recordings
+ """
+ # Encode input if needed
+ if self.encoder is not None:
+ x = self.encoder(x) # (B, C, H, W) -> (B, T, C, H, W)
+
+ B, T, C, H, W = x.shape
+ device, dtype = x.device, x.dtype
+
+ # Initialize membrane potentials
+ mems = self._init_mem(B, device, dtype)
+
+ # For accumulating output spikes
+ spike_sum = None
+
+ # Lyapunov setup
+ if compute_lyapunov:
+ mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ # Time loop
+ for t in range(T):
+ x_t = x[:, t] # (B, C, H, W)
+
+ # Forward through conv-LIF blocks
+ h = x_t
+ new_mems = []
+ for i, (block, pool) in enumerate(zip(self.blocks, self.pools)):
+ h, mem = block(h, mems[i])
+ new_mems.append(mem)
+ h = pool(h) # Spatial downsampling
+
+ mems = new_mems
+
+ # Accumulate final layer spikes
+ if spike_sum is None:
+ spike_sum = h.view(B, -1)
+ else:
+ spike_sum = spike_sum + h.view(B, -1)
+
+ # Lyapunov computation
+ if compute_lyapunov:
+ h_p = x_t
+ new_mems_p = []
+ for i, (block, pool) in enumerate(zip(self.blocks, self.pools)):
+ h_p, mem_p = block(h_p, mems_p[i])
+ new_mems_p.append(mem_p)
+ h_p = pool(h_p)
+
+ # Compute divergence
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(self.num_layers):
+ diff = new_mems_p[i] - new_mems[i]
+ delta_sq += (diff ** 2).sum(dim=(1, 2, 3))
+
+ delta = torch.sqrt(delta_sq + 1e-12)
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+ # Renormalize perturbation
+ for i in range(self.num_layers):
+ diff = new_mems_p[i] - new_mems[i]
+ norm = torch.sqrt((diff ** 2).sum(dim=(1, 2, 3), keepdim=True) + 1e-12)
+ # Broadcast norm to spatial dimensions
+ norm = norm.view(B, 1, 1, 1)
+ new_mems_p[i] = new_mems[i] + lyap_eps * diff / norm
+
+ mems_p = new_mems_p
+
+ # Readout
+ out = self.dropout(spike_sum)
+ logits = self.fc(out)
+
+ if compute_lyapunov:
+ lyap_est = (lyap_accum / T).mean()
+ else:
+ lyap_est = None
+
+ return logits, lyap_est, None
+
+
+class VGGLyapunovSNN(nn.Module):
+ """
+ VGG-style deep Conv-SNN with Lyapunov regularization.
+
+ Deeper architecture for more challenging benchmarks.
+ Uses multiple conv layers between pooling to increase depth.
+
+ Architecture (VGG-9 style):
+ [Conv-LIF × 2, Pool] → [Conv-LIF × 2, Pool] → [Conv-LIF × 2, Pool] → FC
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ num_classes: int = 10,
+ T: int = 25,
+ beta: float = 0.9,
+ threshold: float = 1.0,
+ encoding: str = 'rate',
+ dropout: float = 0.3,
+ ):
+ super().__init__()
+
+ self.T = T
+ self.encoding_type = encoding
+
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ if encoding == 'rate':
+ self.encoder = RateEncoder(T=T)
+ elif encoding == 'direct':
+ self.encoder = DirectEncoder(T=T)
+ else:
+ self.encoder = None
+
+ # VGG-style blocks: (in_ch, out_ch, num_convs)
+ block_configs = [
+ (in_channels, 64, 2), # 32x32 -> 16x16
+ (64, 128, 2), # 16x16 -> 8x8
+ (128, 256, 2), # 8x8 -> 4x4
+ ]
+
+ self.blocks = nn.ModuleList()
+ for in_ch, out_ch, n_convs in block_configs:
+ layers = []
+ for i in range(n_convs):
+ ch_in = in_ch if i == 0 else out_ch
+ layers.append(nn.Conv2d(ch_in, out_ch, 3, padding=1, bias=False))
+ layers.append(nn.BatchNorm2d(out_ch))
+ self.blocks.append(nn.ModuleList(layers))
+
+ # LIF neurons for each conv layer
+ self.lifs = nn.ModuleList([
+ snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False)
+ for _ in range(6) # 2 convs × 3 blocks
+ ])
+
+ self.pools = nn.ModuleList([nn.AvgPool2d(2) for _ in range(3)])
+
+ # FC layers
+ self.fc1 = nn.Linear(256 * 4 * 4, 512)
+ self.lif_fc = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False)
+ self.dropout = nn.Dropout(dropout)
+ self.fc2 = nn.Linear(512, num_classes)
+
+ self._init_weights()
+
+ def _init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ compute_lyapunov: bool = False,
+ lyap_eps: float = 1e-4,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]:
+
+ if self.encoder is not None:
+ x = self.encoder(x)
+
+ B, T, C, H, W = x.shape
+ device, dtype = x.device, x.dtype
+
+ # Initialize all membrane potentials
+ # For each conv layer output
+ mem_shapes = [
+ (B, 64, 32, 32), (B, 64, 32, 32), # Block 1
+ (B, 128, 16, 16), (B, 128, 16, 16), # Block 2
+ (B, 256, 8, 8), (B, 256, 8, 8), # Block 3
+ (B, 512), # FC
+ ]
+ mems = [torch.zeros(s, device=device, dtype=dtype) for s in mem_shapes]
+
+ spike_sum = torch.zeros(B, 512, device=device, dtype=dtype)
+
+ if compute_lyapunov:
+ mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ for t in range(T):
+ h = x[:, t]
+
+ lif_idx = 0
+ for block_idx, (block_layers, pool) in enumerate(zip(self.blocks, self.pools)):
+ for i in range(0, len(block_layers), 2): # Conv, BN pairs
+ conv, bn = block_layers[i], block_layers[i + 1]
+ h = bn(conv(h))
+ h, mems[lif_idx] = self.lifs[lif_idx](h, mems[lif_idx])
+ lif_idx += 1
+ h = pool(h)
+
+ # FC layers
+ h = h.view(B, -1)
+ h = self.fc1(h)
+ h, mems[6] = self.lif_fc(h, mems[6])
+ spike_sum = spike_sum + h
+
+ # Lyapunov (simplified - just on last layer)
+ if compute_lyapunov:
+ diff = mems[6] - mems_p[6] if t > 0 else torch.zeros_like(mems[6])
+ delta = torch.norm(diff.view(B, -1), dim=1) + 1e-12
+ if t > 0:
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+ mems_p[6] = mems[6] + lyap_eps * torch.randn_like(mems[6])
+
+ out = self.dropout(spike_sum)
+ logits = self.fc2(out)
+
+ lyap_est = (lyap_accum / T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est, None
+
+
+def create_conv_snn(
+ model_type: str = 'simple',
+ **kwargs,
+) -> nn.Module:
+ """
+ Factory function for Conv-SNN models.
+
+ Args:
+ model_type: 'simple' (3-layer) or 'vgg' (6-layer VGG-style)
+ **kwargs: Arguments passed to model constructor
+ """
+ if model_type == 'simple':
+ return ConvLyapunovSNN(**kwargs)
+ elif model_type == 'vgg':
+ return VGGLyapunovSNN(**kwargs)
+ else:
+ raise ValueError(f"Unknown model_type: {model_type}")
diff --git a/files/models/snn.py b/files/models/snn.py
new file mode 100644
index 0000000..b1cf633
--- /dev/null
+++ b/files/models/snn.py
@@ -0,0 +1,141 @@
+import math
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SurrogateStep(torch.autograd.Function):
+ """
+ Heaviside with a smooth surrogate gradient (fast sigmoid).
+ """
+ @staticmethod
+ def forward(ctx, x: torch.Tensor, alpha: float):
+ ctx.save_for_backward(x)
+ ctx.alpha = alpha
+ return (x > 0).to(x.dtype)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ (x,) = ctx.saved_tensors
+ alpha = ctx.alpha
+ # d/dx sigmoid(alpha*x) ~ alpha * sigmoid * (1 - sigmoid)
+ # Use fast sigmoid: s = 1 / (1 + |alpha*x|)
+ s = 1.0 / (1.0 + (alpha * x).abs())
+ grad = grad_output * s * s
+ return grad, None
+
+
+def surrogate_heaviside(x: torch.Tensor, alpha: float = 5.0) -> torch.Tensor:
+ return SurrogateStep.apply(x, alpha)
+
+
+class LIFLayer(nn.Module):
+ """
+ Single LIF layer without recurrent synapses between neurons.
+ Dynamics per neuron i:
+ v_t = decay * v_{t-1} + W x_t - v_th * s_{t-1}
+ s_t = H( v_t - v_th ) with surrogate gradient
+ """
+ def __init__(self, input_dim: int, hidden_dim: int, v_threshold: float = 1.0, decay: float = 0.95, spike_alpha: float = 5.0, rec_strength: float = 0.0, rec_init_scale: float = 1.0):
+ super().__init__()
+ self.linear = nn.Linear(input_dim, hidden_dim, bias=True)
+ self.v_threshold = float(v_threshold)
+ self.decay = float(decay)
+ self.spike_alpha = float(spike_alpha)
+ self.rec_strength = float(rec_strength)
+ self.rec = None
+ if self.rec_strength != 0.0:
+ self.rec = nn.Linear(hidden_dim, hidden_dim, bias=False)
+
+ nn.init.xavier_uniform_(self.linear.weight)
+ nn.init.zeros_(self.linear.bias)
+ if self.rec is not None:
+ nn.init.xavier_uniform_(self.rec.weight, gain=rec_init_scale)
+
+ def forward(self, x_t: torch.Tensor, v_prev: torch.Tensor, s_prev: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ # x_t: (B, D_in), v_prev: (B, H), s_prev: (B, H)
+ I_t = self.linear(x_t) # (B, H)
+ R_t = 0.0
+ if self.rec is not None:
+ R_t = self.rec_strength * self.rec(s_prev)
+ v_t = self.decay * v_prev + I_t + R_t - self.v_threshold * s_prev
+ s_t = surrogate_heaviside(v_t - self.v_threshold, alpha=self.spike_alpha)
+ return v_t, s_t
+
+
+class SimpleSNN(nn.Module):
+ """
+ Minimal SNN for SHD-like input (B,T,D):
+ - One LIF hidden layer
+ - Readout linear on time-summed spikes
+ """
+ def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, v_threshold: float = 1.0, decay: float = 0.95, spike_alpha: float = 5.0, rec_strength: float = 0.0, rec_init_scale: float = 1.0):
+ super().__init__()
+ self.lif = LIFLayer(input_dim, hidden_dim, v_threshold=v_threshold, decay=decay, spike_alpha=spike_alpha, rec_strength=rec_strength, rec_init_scale=rec_init_scale)
+ self.readout = nn.Linear(hidden_dim, num_classes)
+ nn.init.xavier_uniform_(self.readout.weight)
+ nn.init.zeros_(self.readout.bias)
+
+ @torch.no_grad()
+ def _init_states(self, batch_size: int, hidden_dim: int, device, dtype):
+ v0 = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype)
+ s0 = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype)
+ return v0, s0
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ compute_lyapunov: bool = False,
+ lyap_eps: float = 1e-3,
+ lyap_safe_eps: float = 1e-8,
+ lyap_measure: str = "v", # "v" or "s"
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """
+ x: (B, T, D)
+ Returns:
+ logits: (B, C)
+ lyap_est: scalar tensor if compute_lyapunov else None
+ """
+ assert x.ndim == 3, f"Expected (B,T,D), got {x.shape}"
+ B, T, D = x.shape
+ device, dtype = x.device, x.dtype
+ H = self.readout.in_features
+
+ v, s = self._init_states(B, H, device, dtype)
+ spike_sum = torch.zeros(B, H, device=device, dtype=dtype)
+
+ if compute_lyapunov:
+ v_p = v + lyap_eps * torch.randn_like(v)
+ s_p = s.clone()
+ delta_prev = torch.norm((v_p - v).reshape(B, -1), dim=1) + lyap_safe_eps
+ lyap_terms = []
+
+ for t in range(T):
+ x_t = x[:, t, :]
+ v, s = self.lif(x_t, v, s)
+ spike_sum = spike_sum + s
+
+ if compute_lyapunov:
+ # run perturbed trajectory through same ops
+ v_p, s_p = self.lif(x_t, v_p, s_p)
+ if lyap_measure == "s":
+ delta_t = torch.norm((s_p - s).reshape(B, -1), dim=1) + lyap_safe_eps
+ else:
+ delta_t = torch.norm((v_p - v).reshape(B, -1), dim=1) + lyap_safe_eps
+ ratio = delta_t / delta_prev
+ lyap_terms.append(torch.log(ratio + lyap_safe_eps))
+ delta_prev = delta_t
+
+ logits = self.readout(spike_sum) # (B, C)
+
+ if compute_lyapunov:
+ lyap_batch = torch.stack(lyap_terms, dim=0).mean(dim=0) # (B,)
+ lyap_est = lyap_batch.mean() # scalar
+ else:
+ lyap_est = None
+
+ return logits, lyap_est
+
+
diff --git a/files/models/snn_snntorch.py b/files/models/snn_snntorch.py
new file mode 100644
index 0000000..71c1c18
--- /dev/null
+++ b/files/models/snn_snntorch.py
@@ -0,0 +1,398 @@
+"""
+snnTorch-based SNN with Lyapunov exponent regularization.
+
+This module provides deep SNN architectures using snnTorch with proper
+finite-time Lyapunov exponent computation for training stabilization.
+"""
+
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import snntorch as snn
+from snntorch import surrogate
+
+
+class LyapunovSNN(nn.Module):
+ """
+ Multi-layer SNN using snnTorch with Lyapunov exponent computation.
+
+ Architecture:
+ Input (B, T, D) -> [LIF layers] -> time-summed spikes -> Linear -> logits
+
+ Args:
+ input_dim: Input feature dimension
+ hidden_dims: List of hidden layer sizes (e.g., [256, 128] for 2 layers)
+ num_classes: Number of output classes
+ beta: Membrane potential decay factor (0 < beta < 1)
+ threshold: Firing threshold
+ spike_grad: Surrogate gradient function (default: fast_sigmoid)
+ dropout: Dropout probability between layers (0 = no dropout)
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dims: List[int],
+ num_classes: int,
+ beta: float = 0.9,
+ threshold: float = 1.0,
+ spike_grad: Optional[Any] = None,
+ dropout: float = 0.0,
+ ):
+ super().__init__()
+
+ if spike_grad is None:
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.hidden_dims = hidden_dims
+ self.num_layers = len(hidden_dims)
+ self.beta = beta
+ self.threshold = threshold
+
+ # Build layers
+ self.linears = nn.ModuleList()
+ self.lifs = nn.ModuleList()
+ self.dropouts = nn.ModuleList() if dropout > 0 else None
+
+ dims = [input_dim] + hidden_dims
+ for i in range(self.num_layers):
+ self.linears.append(nn.Linear(dims[i], dims[i + 1]))
+ self.lifs.append(
+ snn.Leaky(
+ beta=beta,
+ threshold=threshold,
+ spike_grad=spike_grad,
+ init_hidden=False,
+ reset_mechanism="subtract",
+ )
+ )
+ if dropout > 0:
+ self.dropouts.append(nn.Dropout(p=dropout))
+
+ # Readout layer
+ self.readout = nn.Linear(hidden_dims[-1], num_classes)
+
+ # Initialize weights
+ self._init_weights()
+
+ def _init_weights(self):
+ for lin in self.linears:
+ nn.init.xavier_uniform_(lin.weight)
+ nn.init.zeros_(lin.bias)
+ nn.init.xavier_uniform_(self.readout.weight)
+ nn.init.zeros_(self.readout.bias)
+
+ def _init_states(self, batch_size: int, device, dtype) -> List[torch.Tensor]:
+ """Initialize membrane potentials for all layers."""
+ mems = []
+ for dim in self.hidden_dims:
+ mems.append(torch.zeros(batch_size, dim, device=device, dtype=dtype))
+ return mems
+
+ def _step(
+ self,
+ x_t: torch.Tensor,
+ mems: List[torch.Tensor],
+ training: bool = True,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
+ """
+ Single timestep forward pass.
+
+ Returns:
+ spike_out: Output spikes from last layer (B, H_last)
+ new_mems: Updated membrane potentials
+ all_mems: Membrane potentials from all layers (for Lyapunov)
+ """
+ new_mems = []
+ all_mems = []
+
+ h = x_t
+ for i in range(self.num_layers):
+ h = self.linears[i](h)
+ spk, mem = self.lifs[i](h, mems[i])
+ new_mems.append(mem)
+ all_mems.append(mem)
+ h = spk
+ if self.dropouts is not None and training:
+ h = self.dropouts[i](h)
+
+ return h, new_mems, all_mems
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ compute_lyapunov: bool = False,
+ lyap_eps: float = 1e-4,
+ lyap_layers: Optional[List[int]] = None,
+ record_states: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict[str, torch.Tensor]]]:
+ """
+ Forward pass with optional Lyapunov exponent computation.
+
+ Args:
+ x: Input tensor (B, T, D)
+ compute_lyapunov: Whether to compute Lyapunov exponent
+ lyap_eps: Perturbation magnitude for Lyapunov computation
+ lyap_layers: Which layers to measure (default: all).
+ e.g., [0] for first layer only, [-1] for last layer
+ record_states: Whether to record spikes and membrane potentials
+
+ Returns:
+ logits: Classification logits (B, num_classes)
+ lyap_est: Estimated Lyapunov exponent (scalar) or None
+ recordings: Dict with 'spikes' (B,T,H) and 'membrane' (B,T,H) or None
+ """
+ B, T, D = x.shape
+ device, dtype = x.device, x.dtype
+
+ # Initialize states
+ mems = self._init_states(B, device, dtype)
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+ # Recording setup
+ if record_states:
+ spike_rec = []
+ mem_rec = []
+
+ # Lyapunov setup
+ if compute_lyapunov:
+ if lyap_layers is None:
+ lyap_layers = list(range(self.num_layers))
+
+ # Perturbed trajectory - perturb all membrane potentials
+ mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ # Time loop
+ for t in range(T):
+ x_t = x[:, t, :]
+
+ # Nominal trajectory
+ spk, mems, all_mems = self._step(x_t, mems, training=self.training)
+ spike_sum = spike_sum + spk
+
+ if record_states:
+ spike_rec.append(spk.detach())
+ mem_rec.append(all_mems[-1].detach()) # Last layer membrane
+
+ if compute_lyapunov:
+ # Perturbed trajectory
+ _, mems_p, all_mems_p = self._step(x_t, mems_p, training=False)
+
+ # Compute divergence across selected layers
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ delta_p_sq = torch.zeros(B, device=device, dtype=dtype)
+
+ for layer_idx in lyap_layers:
+ diff = all_mems_p[layer_idx] - all_mems[layer_idx]
+ delta_sq += (diff ** 2).sum(dim=1)
+
+ delta = torch.sqrt(delta_sq + 1e-12)
+
+ # Renormalization step (key for numerical stability)
+ # Rescale perturbation back to fixed magnitude
+ for layer_idx in lyap_layers:
+ diff = mems_p[layer_idx] - mems[layer_idx]
+ # Normalize to maintain fixed perturbation magnitude
+ norm = torch.norm(diff.reshape(B, -1), dim=1, keepdim=True) + 1e-12
+ diff_normalized = diff / norm.unsqueeze(-1) if diff.ndim > 2 else diff / norm
+ mems_p[layer_idx] = mems[layer_idx] + lyap_eps * diff_normalized
+
+ # Accumulate log-divergence
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+ logits = self.readout(spike_sum)
+
+ if compute_lyapunov:
+ # Average over time and batch
+ lyap_est = (lyap_accum / T).mean()
+ else:
+ lyap_est = None
+
+ if record_states:
+ recordings = {
+ "spikes": torch.stack(spike_rec, dim=1), # (B, T, H)
+ "membrane": torch.stack(mem_rec, dim=1), # (B, T, H)
+ }
+ else:
+ recordings = None
+
+ return logits, lyap_est, recordings
+
+
+class RecurrentLyapunovSNN(nn.Module):
+ """
+ Recurrent SNN with Lyapunov exponent computation.
+
+ Uses snnTorch's RSynaptic (recurrent synaptic) neurons for
+ richer temporal dynamics.
+
+ Args:
+ input_dim: Input feature dimension
+ hidden_dims: List of hidden layer sizes
+ num_classes: Number of output classes
+ alpha: Synaptic current decay rate
+ beta: Membrane potential decay rate
+ threshold: Firing threshold
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dims: List[int],
+ num_classes: int,
+ alpha: float = 0.9,
+ beta: float = 0.85,
+ threshold: float = 1.0,
+ spike_grad: Optional[Any] = None,
+ ):
+ super().__init__()
+
+ if spike_grad is None:
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.hidden_dims = hidden_dims
+ self.num_layers = len(hidden_dims)
+ self.alpha = alpha
+ self.beta = beta
+
+ # Build layers with recurrent synaptic neurons
+ self.linears = nn.ModuleList()
+ self.neurons = nn.ModuleList()
+
+ dims = [input_dim] + hidden_dims
+ for i in range(self.num_layers):
+ self.linears.append(nn.Linear(dims[i], dims[i + 1]))
+ self.neurons.append(
+ snn.RSynaptic(
+ alpha=alpha,
+ beta=beta,
+ threshold=threshold,
+ spike_grad=spike_grad,
+ init_hidden=False,
+ reset_mechanism="subtract",
+ all_to_all=True,
+ linear_features=dims[i + 1],
+ )
+ )
+
+ self.readout = nn.Linear(hidden_dims[-1], num_classes)
+ self._init_weights()
+
+ def _init_weights(self):
+ for lin in self.linears:
+ nn.init.xavier_uniform_(lin.weight)
+ nn.init.zeros_(lin.bias)
+ nn.init.xavier_uniform_(self.readout.weight)
+ nn.init.zeros_(self.readout.bias)
+
+ def _init_states(self, batch_size: int, device, dtype):
+ """Initialize synaptic currents and membrane potentials."""
+ syns = []
+ mems = []
+ for dim in self.hidden_dims:
+ syns.append(torch.zeros(batch_size, dim, device=device, dtype=dtype))
+ mems.append(torch.zeros(batch_size, dim, device=device, dtype=dtype))
+ return syns, mems
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ compute_lyapunov: bool = False,
+ lyap_eps: float = 1e-4,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Forward pass with optional Lyapunov computation."""
+ B, T, D = x.shape
+ device, dtype = x.device, x.dtype
+
+ syns, mems = self._init_states(B, device, dtype)
+ spike_sum = torch.zeros(B, self.hidden_dims[-1], device=device, dtype=dtype)
+
+ if compute_lyapunov:
+ # Perturb both synaptic currents and membrane potentials
+ syns_p = [s + lyap_eps * torch.randn_like(s) for s in syns]
+ mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems]
+ lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ for t in range(T):
+ x_t = x[:, t, :]
+
+ # Nominal trajectory
+ h = x_t
+ new_syns, new_mems = [], []
+ for i in range(self.num_layers):
+ h = self.linears[i](h)
+ spk, syn, mem = self.neurons[i](h, syns[i], mems[i])
+ new_syns.append(syn)
+ new_mems.append(mem)
+ h = spk
+ syns, mems = new_syns, new_mems
+ spike_sum = spike_sum + h
+
+ if compute_lyapunov:
+ # Perturbed trajectory
+ h_p = x_t
+ new_syns_p, new_mems_p = [], []
+ for i in range(self.num_layers):
+ h_p = self.linears[i](h_p)
+ spk_p, syn_p, mem_p = self.neurons[i](h_p, syns_p[i], mems_p[i])
+ new_syns_p.append(syn_p)
+ new_mems_p.append(mem_p)
+ h_p = spk_p
+
+ # Compute divergence (on membrane potentials)
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(self.num_layers):
+ diff_m = new_mems_p[i] - new_mems[i]
+ diff_s = new_syns_p[i] - new_syns[i]
+ delta_sq += (diff_m ** 2).sum(dim=1) + (diff_s ** 2).sum(dim=1)
+
+ delta = torch.sqrt(delta_sq + 1e-12)
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+ # Renormalize perturbation
+ total_dim = sum(2 * d for d in self.hidden_dims) # syn + mem
+ scale = lyap_eps / (delta.unsqueeze(-1) + 1e-12)
+
+ syns_p = [new_syns[i] + scale * (new_syns_p[i] - new_syns[i])
+ for i in range(self.num_layers)]
+ mems_p = [new_mems[i] + scale * (new_mems_p[i] - new_mems[i])
+ for i in range(self.num_layers)]
+
+ logits = self.readout(spike_sum)
+
+ if compute_lyapunov:
+ lyap_est = (lyap_accum / T).mean()
+ else:
+ lyap_est = None
+
+ return logits, lyap_est
+
+
+def create_snn(
+ model_type: str,
+ input_dim: int,
+ hidden_dims: List[int],
+ num_classes: int,
+ **kwargs,
+) -> nn.Module:
+ """
+ Factory function to create SNN models.
+
+ Args:
+ model_type: "feedforward" or "recurrent"
+ input_dim: Input feature dimension
+ hidden_dims: List of hidden layer sizes
+ num_classes: Number of output classes
+ **kwargs: Additional arguments passed to model constructor
+
+ Returns:
+ SNN model instance
+ """
+ if model_type == "feedforward":
+ return LyapunovSNN(input_dim, hidden_dims, num_classes, **kwargs)
+ elif model_type == "recurrent":
+ return RecurrentLyapunovSNN(input_dim, hidden_dims, num_classes, **kwargs)
+ else:
+ raise ValueError(f"Unknown model_type: {model_type}. Use 'feedforward' or 'recurrent'")