author     YurenHao0426 <blackhao0426@gmail.com>  2026-01-13 23:49:05 -0600
committer  YurenHao0426 <blackhao0426@gmail.com>  2026-01-13 23:49:05 -0600
commit     cd99d6b874d9d09b3bb87b8485cc787885af71f1 (patch)
tree       59a233959932ca0e4f12f196275e07fcf443b33f /files/models/conv_snn.py
init commit
Diffstat (limited to 'files/models/conv_snn.py')
-rw-r--r--  files/models/conv_snn.py  483
1 file changed, 483 insertions(+), 0 deletions(-)
diff --git a/files/models/conv_snn.py b/files/models/conv_snn.py
new file mode 100644
index 0000000..69f77d6
--- /dev/null
+++ b/files/models/conv_snn.py
@@ -0,0 +1,483 @@
+"""
+Convolutional SNN with Lyapunov regularization for image classification.
+
+Properly handles spatial structure:
+- Input: (B, C, H, W) static image OR (B, T, C, H, W) spike tensor
+- Uses Conv-LIF layers to preserve spatial hierarchy
+- Rate encoding converts images to spike trains
+
+Based on standard SNN vision practices:
+- Rate/Poisson encoding for input
+- Conv → BatchNorm → LIF → Pool architecture
+- Time comes from encoding + LIF dynamics, not flattening
+"""
+
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import snntorch as snn
+from snntorch import surrogate
+
+
+class RateEncoder(nn.Module):
+ """
+ Rate (Poisson/Bernoulli) encoder for static images.
+
+ Converts intensity x ∈ [0,1] to spike probability per timestep.
+ Each pixel independently fires with P(spike) = x * gain.
+
+ Args:
+ T: Number of timesteps
+ gain: Scaling factor for firing probability (default 1.0)
+
+ Input: (B, C, H, W) normalized image in [0, 1]
+ Output: (B, T, C, H, W) binary spike tensor
+ """
+
+ def __init__(self, T: int = 25, gain: float = 1.0):
+ super().__init__()
+ self.T = T
+ self.gain = gain
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x: (B, C, H, W) image tensor, values in [0, 1]
+ Returns:
+ spikes: (B, T, C, H, W) binary spike tensor
+ """
+ # Clamp to valid probability range
+ prob = (x * self.gain).clamp(0, 1)
+
+ # Expand for T timesteps: (B, C, H, W) -> (B, T, C, H, W)
+ prob = prob.unsqueeze(1).expand(-1, self.T, -1, -1, -1)
+
+ # Sample spikes
+ spikes = torch.bernoulli(prob)
+
+ return spikes
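+
+# Usage sketch (illustrative only, not used elsewhere in this file): rate-encode
+# a CIFAR-sized batch; the tensor names are hypothetical.
+#
+#   encoder = RateEncoder(T=25, gain=1.0)
+#   images = torch.rand(8, 3, 32, 32)        # values in [0, 1]
+#   spikes = encoder(images)                 # (8, 25, 3, 32, 32), values in {0, 1}
+#   spikes.mean(dim=1)                       # per-pixel firing rate, roughly images * gain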
+
+
+class DirectEncoder(nn.Module):
+ """
+ Direct encoding - feed static image as constant current.
+
+ Common in surrogate gradient papers: no spike encoding at input,
+ let spiking emerge from the network dynamics.
+
+ Input: (B, C, H, W) image
+ Output: (B, T, C, H, W) repeated image (as analog current)
+ """
+
+ def __init__(self, T: int = 25):
+ super().__init__()
+ self.T = T
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # Simply repeat across time
+ return x.unsqueeze(1).expand(-1, self.T, -1, -1, -1)
+
+
+class ConvLIFBlock(nn.Module):
+ """
+ Conv → BatchNorm → LIF block.
+
+ Maintains spatial structure while adding spiking dynamics.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 3,
+ stride: int = 1,
+ padding: int = 1,
+ beta: float = 0.9,
+ threshold: float = 1.0,
+ spike_grad=None,
+ ):
+ super().__init__()
+
+ if spike_grad is None:
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ self.conv = nn.Conv2d(
+ in_channels, out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ bias=False,
+ )
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.lif = snn.Leaky(
+ beta=beta,
+ threshold=threshold,
+ spike_grad=spike_grad,
+ init_hidden=False,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mem: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ x: (B, C_in, H, W) input (spikes or current)
+ mem: (B, C_out, H', W') membrane potential
+ Returns:
+ spk: (B, C_out, H', W') output spikes
+ mem: (B, C_out, H', W') updated membrane
+ """
+ cur = self.bn(self.conv(x))
+ spk, mem = self.lif(cur, mem)
+ return spk, mem
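+
+# Sketch of a single ConvLIFBlock step (illustrative; shapes assume a CIFAR-sized
+# input and the default stride-1, padding-1 convolution):
+#
+#   block = ConvLIFBlock(in_channels=3, out_channels=64)
+#   x = torch.rand(8, 3, 32, 32)              # input spikes or analog current
+#   mem = torch.zeros(8, 64, 32, 32)          # membrane must match the conv output shape
+#   spk, mem = block(x, mem)                  # both (8, 64, 32, 32)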
+
+
+class ConvLyapunovSNN(nn.Module):
+ """
+ Convolutional SNN with Lyapunov exponent regularization.
+
+ Architecture for CIFAR-10 (32x32x3):
+ Input → Encoder → [Conv-LIF-Pool] × N → FC → Output
+
+ Properly preserves spatial structure for hierarchical feature learning.
+
+ Args:
+ in_channels: Input channels (3 for RGB)
+ num_classes: Output classes
+ channels: List of channel sizes for conv layers
+ T: Number of timesteps
+ beta: LIF membrane decay
+ threshold: LIF firing threshold
+ encoding: 'rate', 'direct', or 'none' (pre-encoded input)
+ encoding_gain: Gain for rate encoding
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ num_classes: int = 10,
+ channels: List[int] = [64, 128, 256],
+ T: int = 25,
+ beta: float = 0.9,
+ threshold: float = 1.0,
+ encoding: str = 'rate',
+ encoding_gain: float = 1.0,
+ dropout: float = 0.2,
+ ):
+ super().__init__()
+
+ self.T = T
+ self.encoding_type = encoding
+ self.channels = channels
+ self.num_layers = len(channels)
+
+ # Input encoder
+ if encoding == 'rate':
+ self.encoder = RateEncoder(T=T, gain=encoding_gain)
+ elif encoding == 'direct':
+ self.encoder = DirectEncoder(T=T)
+ else:
+ self.encoder = None # Expect pre-encoded (B, T, C, H, W) input
+
+ # Build conv-LIF layers
+ self.blocks = nn.ModuleList()
+ self.pools = nn.ModuleList()
+
+ ch_in = in_channels
+ for ch_out in channels:
+ self.blocks.append(
+ ConvLIFBlock(ch_in, ch_out, beta=beta, threshold=threshold)
+ )
+ self.pools.append(nn.AvgPool2d(2))
+ ch_in = ch_out
+
+ # Calculate output spatial size after pooling
+ # CIFAR: 32 -> 16 -> 8 -> 4 (for 3 layers)
+ spatial_size = 32 // (2 ** len(channels))
+ fc_input = channels[-1] * spatial_size * spatial_size
+
+ # Fully connected readout
+ self.dropout = nn.Dropout(dropout)
+ self.fc = nn.Linear(fc_input, num_classes)
+
+ self._init_weights()
+
+ def _init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+    def _init_mem(self, batch_size: int, device, dtype) -> List[torch.Tensor]:
+        """Initialize membrane potentials, sized to each layer's pre-pooling feature map."""
+        mems = []
+        H, W = 32, 32  # Assumes 32x32 input; each LIF sees the conv output before pooling
+        for ch in self.channels:
+            mems.append(torch.zeros(batch_size, ch, H, W, device=device, dtype=dtype))
+            H, W = H // 2, W // 2  # AvgPool2d(2) halves the spatial size for the next layer
+        return mems
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ compute_lyapunov: bool = False,
+ lyap_eps: float = 1e-4,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]:
+ """
+ Forward pass with optional Lyapunov computation.
+
+ Args:
+ x: Input tensor
+ - If encoder: (B, C, H, W) static image
+ - If no encoder: (B, T, C, H, W) pre-encoded spikes
+ compute_lyapunov: Whether to compute Lyapunov exponent
+ lyap_eps: Perturbation magnitude
+
+ Returns:
+ logits: (B, num_classes)
+ lyap_est: Scalar Lyapunov estimate or None
+ recordings: Optional dict with spike recordings
+ """
+ # Encode input if needed
+ if self.encoder is not None:
+ x = self.encoder(x) # (B, C, H, W) -> (B, T, C, H, W)
+
+ B, T, C, H, W = x.shape
+ device, dtype = x.device, x.dtype
+
+ # Initialize membrane potentials
+ mems = self._init_mem(B, device, dtype)
+
+ # For accumulating output spikes
+ spike_sum = None
+
+        # Lyapunov setup: random perturbation, rescaled so that its joint L2
+        # norm across all membrane tensors equals lyap_eps for every sample
+        if compute_lyapunov:
+            noise = [torch.randn_like(m) for m in mems]
+            nrm = sum((n ** 2).sum(dim=(1, 2, 3)) for n in noise)
+            nrm = torch.sqrt(nrm + 1e-12).view(B, 1, 1, 1)
+            mems_p = [m + lyap_eps * n / nrm for m, n in zip(mems, noise)]
+            lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
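+        # Finite-time estimate of the largest Lyapunov exponent (Benettin-style):
+        # the perturbed membranes are propagated through the same layers, the
+        # per-step expansion log(||delta_t|| / lyap_eps) is accumulated, and the
+        # perturbation is reset to magnitude lyap_eps, giving
+        #     lyap_est = (1 / T) * sum_t log(||delta_t|| / lyap_eps).
+        # Note: the perturbed pass reuses the shared Conv/BatchNorm modules, so in
+        # training mode the BatchNorm running statistics see each batch twice
+        # whenever compute_lyapunov is enabled.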
+ # Time loop
+ for t in range(T):
+ x_t = x[:, t] # (B, C, H, W)
+
+ # Forward through conv-LIF blocks
+ h = x_t
+ new_mems = []
+ for i, (block, pool) in enumerate(zip(self.blocks, self.pools)):
+ h, mem = block(h, mems[i])
+ new_mems.append(mem)
+ h = pool(h) # Spatial downsampling
+
+ mems = new_mems
+
+ # Accumulate final layer spikes
+ if spike_sum is None:
+ spike_sum = h.view(B, -1)
+ else:
+ spike_sum = spike_sum + h.view(B, -1)
+
+ # Lyapunov computation
+ if compute_lyapunov:
+ h_p = x_t
+ new_mems_p = []
+ for i, (block, pool) in enumerate(zip(self.blocks, self.pools)):
+ h_p, mem_p = block(h_p, mems_p[i])
+ new_mems_p.append(mem_p)
+ h_p = pool(h_p)
+
+ # Compute divergence
+ delta_sq = torch.zeros(B, device=device, dtype=dtype)
+ for i in range(self.num_layers):
+ diff = new_mems_p[i] - new_mems[i]
+ delta_sq += (diff ** 2).sum(dim=(1, 2, 3))
+
+ delta = torch.sqrt(delta_sq + 1e-12)
+ lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12)
+
+                # Renormalize the joint perturbation back to magnitude lyap_eps
+                # (Benettin-style reset), using the same global divergence delta
+                for i in range(self.num_layers):
+                    diff = new_mems_p[i] - new_mems[i]
+                    new_mems_p[i] = new_mems[i] + lyap_eps * diff / delta.view(B, 1, 1, 1)
+
+ mems_p = new_mems_p
+
+ # Readout
+ out = self.dropout(spike_sum)
+ logits = self.fc(out)
+
+ if compute_lyapunov:
+ lyap_est = (lyap_accum / T).mean()
+ else:
+ lyap_est = None
+
+ return logits, lyap_est, None
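+
+# Typical training-loop usage (hypothetical sketch; the actual training code
+# lives outside this file). lyap_reg is an assumed regularization weight that
+# penalizes positive (chaotic) Lyapunov estimates.
+#
+#   logits, lyap_est, _ = model(images, compute_lyapunov=True)
+#   loss = F.cross_entropy(logits, targets) + lyap_reg * lyap_est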
+
+
+class VGGLyapunovSNN(nn.Module):
+ """
+ VGG-style deep Conv-SNN with Lyapunov regularization.
+
+ Deeper architecture for more challenging benchmarks.
+ Uses multiple conv layers between pooling to increase depth.
+
+ Architecture (VGG-9 style):
+ [Conv-LIF × 2, Pool] → [Conv-LIF × 2, Pool] → [Conv-LIF × 2, Pool] → FC
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ num_classes: int = 10,
+ T: int = 25,
+ beta: float = 0.9,
+ threshold: float = 1.0,
+ encoding: str = 'rate',
+ dropout: float = 0.3,
+ ):
+ super().__init__()
+
+ self.T = T
+ self.encoding_type = encoding
+
+ spike_grad = surrogate.fast_sigmoid(slope=25)
+
+ if encoding == 'rate':
+ self.encoder = RateEncoder(T=T)
+ elif encoding == 'direct':
+ self.encoder = DirectEncoder(T=T)
+ else:
+ self.encoder = None
+
+ # VGG-style blocks: (in_ch, out_ch, num_convs)
+ block_configs = [
+ (in_channels, 64, 2), # 32x32 -> 16x16
+ (64, 128, 2), # 16x16 -> 8x8
+ (128, 256, 2), # 8x8 -> 4x4
+ ]
+
+ self.blocks = nn.ModuleList()
+ for in_ch, out_ch, n_convs in block_configs:
+ layers = []
+ for i in range(n_convs):
+ ch_in = in_ch if i == 0 else out_ch
+ layers.append(nn.Conv2d(ch_in, out_ch, 3, padding=1, bias=False))
+ layers.append(nn.BatchNorm2d(out_ch))
+ self.blocks.append(nn.ModuleList(layers))
+
+ # LIF neurons for each conv layer
+ self.lifs = nn.ModuleList([
+ snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False)
+ for _ in range(6) # 2 convs × 3 blocks
+ ])
+
+ self.pools = nn.ModuleList([nn.AvgPool2d(2) for _ in range(3)])
+
+ # FC layers
+ self.fc1 = nn.Linear(256 * 4 * 4, 512)
+ self.lif_fc = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False)
+ self.dropout = nn.Dropout(dropout)
+ self.fc2 = nn.Linear(512, num_classes)
+
+ self._init_weights()
+
+ def _init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out')
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ compute_lyapunov: bool = False,
+ lyap_eps: float = 1e-4,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]:
+
+ if self.encoder is not None:
+ x = self.encoder(x)
+
+ B, T, C, H, W = x.shape
+ device, dtype = x.device, x.dtype
+
+ # Initialize all membrane potentials
+ # For each conv layer output
+ mem_shapes = [
+ (B, 64, 32, 32), (B, 64, 32, 32), # Block 1
+ (B, 128, 16, 16), (B, 128, 16, 16), # Block 2
+ (B, 256, 8, 8), (B, 256, 8, 8), # Block 3
+ (B, 512), # FC
+ ]
+ mems = [torch.zeros(s, device=device, dtype=dtype) for s in mem_shapes]
+
+ spike_sum = torch.zeros(B, 512, device=device, dtype=dtype)
+
+        if compute_lyapunov:
+            # Simplified estimate: perturb and track only the FC membrane, with
+            # the perturbation rescaled to L2 norm lyap_eps per sample
+            noise = torch.randn_like(mems[6])
+            mem_fc_p = mems[6] + lyap_eps * noise / (noise.norm(dim=1, keepdim=True) + 1e-12)
+            lyap_accum = torch.zeros(B, device=device, dtype=dtype)
+
+ for t in range(T):
+ h = x[:, t]
+
+ lif_idx = 0
+ for block_idx, (block_layers, pool) in enumerate(zip(self.blocks, self.pools)):
+ for i in range(0, len(block_layers), 2): # Conv, BN pairs
+ conv, bn = block_layers[i], block_layers[i + 1]
+ h = bn(conv(h))
+ h, mems[lif_idx] = self.lifs[lif_idx](h, mems[lif_idx])
+ lif_idx += 1
+ h = pool(h)
+
+            # FC readout
+            h = h.view(B, -1)
+            cur_fc = self.fc1(h)
+            spk_fc, mems[6] = self.lif_fc(cur_fc, mems[6])
+            spike_sum = spike_sum + spk_fc
+
+            # Lyapunov (simplified: divergence tracked on the FC membrane only).
+            # The perturbed membrane is driven by the same FC current, and the
+            # perturbation is reset to magnitude lyap_eps after each step.
+            if compute_lyapunov:
+                _, mem_fc_p = self.lif_fc(cur_fc, mem_fc_p)
+                diff = mem_fc_p - mems[6]
+                delta = torch.norm(diff, dim=1) + 1e-12
+                lyap_accum = lyap_accum + torch.log(delta / lyap_eps)
+                mem_fc_p = mems[6] + lyap_eps * diff / delta.view(B, 1)
+
+ out = self.dropout(spike_sum)
+ logits = self.fc2(out)
+
+ lyap_est = (lyap_accum / T).mean() if compute_lyapunov else None
+
+ return logits, lyap_est, None
+
+
+def create_conv_snn(
+ model_type: str = 'simple',
+ **kwargs,
+) -> nn.Module:
+ """
+ Factory function for Conv-SNN models.
+
+ Args:
+        model_type: 'simple' (3 conv layers) or 'vgg' (deeper, VGG-9 style)
+ **kwargs: Arguments passed to model constructor
+ """
+ if model_type == 'simple':
+ return ConvLyapunovSNN(**kwargs)
+ elif model_type == 'vgg':
+ return VGGLyapunovSNN(**kwargs)
+ else:
+ raise ValueError(f"Unknown model_type: {model_type}")
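+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative sketch; assumes torch and snntorch are
+    # installed). Runs a tiny random batch through the 'simple' model with a
+    # short time window and prints the logit shape and the Lyapunov estimate.
+    # The 'vgg' variant can be exercised the same way via
+    # create_conv_snn(model_type='vgg', T=4).
+    model = create_conv_snn(model_type='simple', T=4, channels=[16, 32, 64])
+    model.eval()
+    images = torch.rand(2, 3, 32, 32)  # fake CIFAR-sized batch in [0, 1]
+    with torch.no_grad():
+        logits, lyap_est, _ = model(images, compute_lyapunov=True)
+    print(logits.shape)     # expected: torch.Size([2, 10])
+    print(float(lyap_est))  # finite-time Lyapunov estimate (scalar)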