From cd99d6b874d9d09b3bb87b8485cc787885af71f1 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Tue, 13 Jan 2026 23:49:05 -0600 Subject: init commit --- files/models/conv_snn.py | 483 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 483 insertions(+) create mode 100644 files/models/conv_snn.py (limited to 'files/models/conv_snn.py') diff --git a/files/models/conv_snn.py b/files/models/conv_snn.py new file mode 100644 index 0000000..69f77d6 --- /dev/null +++ b/files/models/conv_snn.py @@ -0,0 +1,483 @@ +""" +Convolutional SNN with Lyapunov regularization for image classification. + +Properly handles spatial structure: +- Input: (B, C, H, W) static image OR (B, T, C, H, W) spike tensor +- Uses Conv-LIF layers to preserve spatial hierarchy +- Rate encoding converts images to spike trains + +Based on standard SNN vision practices: +- Rate/Poisson encoding for input +- Conv → BatchNorm → LIF → Pool architecture +- Time comes from encoding + LIF dynamics, not flattening +""" + +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import snntorch as snn +from snntorch import surrogate + + +class RateEncoder(nn.Module): + """ + Rate (Poisson/Bernoulli) encoder for static images. + + Converts intensity x ∈ [0,1] to spike probability per timestep. + Each pixel independently fires with P(spike) = x * gain. + + Args: + T: Number of timesteps + gain: Scaling factor for firing probability (default 1.0) + + Input: (B, C, H, W) normalized image in [0, 1] + Output: (B, T, C, H, W) binary spike tensor + """ + + def __init__(self, T: int = 25, gain: float = 1.0): + super().__init__() + self.T = T + self.gain = gain + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: (B, C, H, W) image tensor, values in [0, 1] + Returns: + spikes: (B, T, C, H, W) binary spike tensor + """ + # Clamp to valid probability range + prob = (x * self.gain).clamp(0, 1) + + # Expand for T timesteps: (B, C, H, W) -> (B, T, C, H, W) + prob = prob.unsqueeze(1).expand(-1, self.T, -1, -1, -1) + + # Sample spikes + spikes = torch.bernoulli(prob) + + return spikes + + +class DirectEncoder(nn.Module): + """ + Direct encoding - feed static image as constant current. + + Common in surrogate gradient papers: no spike encoding at input, + let spiking emerge from the network dynamics. + + Input: (B, C, H, W) image + Output: (B, T, C, H, W) repeated image (as analog current) + """ + + def __init__(self, T: int = 25): + super().__init__() + self.T = T + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Simply repeat across time + return x.unsqueeze(1).expand(-1, self.T, -1, -1, -1) + + +class ConvLIFBlock(nn.Module): + """ + Conv → BatchNorm → LIF block. + + Maintains spatial structure while adding spiking dynamics. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: int = 1, + beta: float = 0.9, + threshold: float = 1.0, + spike_grad=None, + ): + super().__init__() + + if spike_grad is None: + spike_grad = surrogate.fast_sigmoid(slope=25) + + self.conv = nn.Conv2d( + in_channels, out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=False, + ) + self.bn = nn.BatchNorm2d(out_channels) + self.lif = snn.Leaky( + beta=beta, + threshold=threshold, + spike_grad=spike_grad, + init_hidden=False, + ) + + def forward( + self, + x: torch.Tensor, + mem: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: (B, C_in, H, W) input (spikes or current) + mem: (B, C_out, H', W') membrane potential + Returns: + spk: (B, C_out, H', W') output spikes + mem: (B, C_out, H', W') updated membrane + """ + cur = self.bn(self.conv(x)) + spk, mem = self.lif(cur, mem) + return spk, mem + + +class ConvLyapunovSNN(nn.Module): + """ + Convolutional SNN with Lyapunov exponent regularization. + + Architecture for CIFAR-10 (32x32x3): + Input → Encoder → [Conv-LIF-Pool] × N → FC → Output + + Properly preserves spatial structure for hierarchical feature learning. + + Args: + in_channels: Input channels (3 for RGB) + num_classes: Output classes + channels: List of channel sizes for conv layers + T: Number of timesteps + beta: LIF membrane decay + threshold: LIF firing threshold + encoding: 'rate', 'direct', or 'none' (pre-encoded input) + encoding_gain: Gain for rate encoding + """ + + def __init__( + self, + in_channels: int = 3, + num_classes: int = 10, + channels: List[int] = [64, 128, 256], + T: int = 25, + beta: float = 0.9, + threshold: float = 1.0, + encoding: str = 'rate', + encoding_gain: float = 1.0, + dropout: float = 0.2, + ): + super().__init__() + + self.T = T + self.encoding_type = encoding + self.channels = channels + self.num_layers = len(channels) + + # Input encoder + if encoding == 'rate': + self.encoder = RateEncoder(T=T, gain=encoding_gain) + elif encoding == 'direct': + self.encoder = DirectEncoder(T=T) + else: + self.encoder = None # Expect pre-encoded (B, T, C, H, W) input + + # Build conv-LIF layers + self.blocks = nn.ModuleList() + self.pools = nn.ModuleList() + + ch_in = in_channels + for ch_out in channels: + self.blocks.append( + ConvLIFBlock(ch_in, ch_out, beta=beta, threshold=threshold) + ) + self.pools.append(nn.AvgPool2d(2)) + ch_in = ch_out + + # Calculate output spatial size after pooling + # CIFAR: 32 -> 16 -> 8 -> 4 (for 3 layers) + spatial_size = 32 // (2 ** len(channels)) + fc_input = channels[-1] * spatial_size * spatial_size + + # Fully connected readout + self.dropout = nn.Dropout(dropout) + self.fc = nn.Linear(fc_input, num_classes) + + self._init_weights() + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def _init_mem(self, batch_size: int, device, dtype) -> List[torch.Tensor]: + """Initialize membrane potentials for all layers.""" + mems = [] + H, W = 32, 32 + for i, ch in enumerate(self.channels): + H, W = H // 2, W // 2 # After pooling + # Actually we need size BEFORE pooling for LIF + H_pre, W_pre = H * 2, W * 2 + mems.append(torch.zeros(batch_size, ch, H_pre, W_pre, device=device, dtype=dtype)) + return mems + + def forward( + self, + x: torch.Tensor, + compute_lyapunov: bool = False, + lyap_eps: float = 1e-4, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]: + """ + Forward pass with optional Lyapunov computation. + + Args: + x: Input tensor + - If encoder: (B, C, H, W) static image + - If no encoder: (B, T, C, H, W) pre-encoded spikes + compute_lyapunov: Whether to compute Lyapunov exponent + lyap_eps: Perturbation magnitude + + Returns: + logits: (B, num_classes) + lyap_est: Scalar Lyapunov estimate or None + recordings: Optional dict with spike recordings + """ + # Encode input if needed + if self.encoder is not None: + x = self.encoder(x) # (B, C, H, W) -> (B, T, C, H, W) + + B, T, C, H, W = x.shape + device, dtype = x.device, x.dtype + + # Initialize membrane potentials + mems = self._init_mem(B, device, dtype) + + # For accumulating output spikes + spike_sum = None + + # Lyapunov setup + if compute_lyapunov: + mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems] + lyap_accum = torch.zeros(B, device=device, dtype=dtype) + + # Time loop + for t in range(T): + x_t = x[:, t] # (B, C, H, W) + + # Forward through conv-LIF blocks + h = x_t + new_mems = [] + for i, (block, pool) in enumerate(zip(self.blocks, self.pools)): + h, mem = block(h, mems[i]) + new_mems.append(mem) + h = pool(h) # Spatial downsampling + + mems = new_mems + + # Accumulate final layer spikes + if spike_sum is None: + spike_sum = h.view(B, -1) + else: + spike_sum = spike_sum + h.view(B, -1) + + # Lyapunov computation + if compute_lyapunov: + h_p = x_t + new_mems_p = [] + for i, (block, pool) in enumerate(zip(self.blocks, self.pools)): + h_p, mem_p = block(h_p, mems_p[i]) + new_mems_p.append(mem_p) + h_p = pool(h_p) + + # Compute divergence + delta_sq = torch.zeros(B, device=device, dtype=dtype) + for i in range(self.num_layers): + diff = new_mems_p[i] - new_mems[i] + delta_sq += (diff ** 2).sum(dim=(1, 2, 3)) + + delta = torch.sqrt(delta_sq + 1e-12) + lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12) + + # Renormalize perturbation + for i in range(self.num_layers): + diff = new_mems_p[i] - new_mems[i] + norm = torch.sqrt((diff ** 2).sum(dim=(1, 2, 3), keepdim=True) + 1e-12) + # Broadcast norm to spatial dimensions + norm = norm.view(B, 1, 1, 1) + new_mems_p[i] = new_mems[i] + lyap_eps * diff / norm + + mems_p = new_mems_p + + # Readout + out = self.dropout(spike_sum) + logits = self.fc(out) + + if compute_lyapunov: + lyap_est = (lyap_accum / T).mean() + else: + lyap_est = None + + return logits, lyap_est, None + + +class VGGLyapunovSNN(nn.Module): + """ + VGG-style deep Conv-SNN with Lyapunov regularization. + + Deeper architecture for more challenging benchmarks. + Uses multiple conv layers between pooling to increase depth. + + Architecture (VGG-9 style): + [Conv-LIF × 2, Pool] → [Conv-LIF × 2, Pool] → [Conv-LIF × 2, Pool] → FC + """ + + def __init__( + self, + in_channels: int = 3, + num_classes: int = 10, + T: int = 25, + beta: float = 0.9, + threshold: float = 1.0, + encoding: str = 'rate', + dropout: float = 0.3, + ): + super().__init__() + + self.T = T + self.encoding_type = encoding + + spike_grad = surrogate.fast_sigmoid(slope=25) + + if encoding == 'rate': + self.encoder = RateEncoder(T=T) + elif encoding == 'direct': + self.encoder = DirectEncoder(T=T) + else: + self.encoder = None + + # VGG-style blocks: (in_ch, out_ch, num_convs) + block_configs = [ + (in_channels, 64, 2), # 32x32 -> 16x16 + (64, 128, 2), # 16x16 -> 8x8 + (128, 256, 2), # 8x8 -> 4x4 + ] + + self.blocks = nn.ModuleList() + for in_ch, out_ch, n_convs in block_configs: + layers = [] + for i in range(n_convs): + ch_in = in_ch if i == 0 else out_ch + layers.append(nn.Conv2d(ch_in, out_ch, 3, padding=1, bias=False)) + layers.append(nn.BatchNorm2d(out_ch)) + self.blocks.append(nn.ModuleList(layers)) + + # LIF neurons for each conv layer + self.lifs = nn.ModuleList([ + snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False) + for _ in range(6) # 2 convs × 3 blocks + ]) + + self.pools = nn.ModuleList([nn.AvgPool2d(2) for _ in range(3)]) + + # FC layers + self.fc1 = nn.Linear(256 * 4 * 4, 512) + self.lif_fc = snn.Leaky(beta=beta, threshold=threshold, spike_grad=spike_grad, init_hidden=False) + self.dropout = nn.Dropout(dropout) + self.fc2 = nn.Linear(512, num_classes) + + self._init_weights() + + def _init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward( + self, + x: torch.Tensor, + compute_lyapunov: bool = False, + lyap_eps: float = 1e-4, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Dict]]: + + if self.encoder is not None: + x = self.encoder(x) + + B, T, C, H, W = x.shape + device, dtype = x.device, x.dtype + + # Initialize all membrane potentials + # For each conv layer output + mem_shapes = [ + (B, 64, 32, 32), (B, 64, 32, 32), # Block 1 + (B, 128, 16, 16), (B, 128, 16, 16), # Block 2 + (B, 256, 8, 8), (B, 256, 8, 8), # Block 3 + (B, 512), # FC + ] + mems = [torch.zeros(s, device=device, dtype=dtype) for s in mem_shapes] + + spike_sum = torch.zeros(B, 512, device=device, dtype=dtype) + + if compute_lyapunov: + mems_p = [m + lyap_eps * torch.randn_like(m) for m in mems] + lyap_accum = torch.zeros(B, device=device, dtype=dtype) + + for t in range(T): + h = x[:, t] + + lif_idx = 0 + for block_idx, (block_layers, pool) in enumerate(zip(self.blocks, self.pools)): + for i in range(0, len(block_layers), 2): # Conv, BN pairs + conv, bn = block_layers[i], block_layers[i + 1] + h = bn(conv(h)) + h, mems[lif_idx] = self.lifs[lif_idx](h, mems[lif_idx]) + lif_idx += 1 + h = pool(h) + + # FC layers + h = h.view(B, -1) + h = self.fc1(h) + h, mems[6] = self.lif_fc(h, mems[6]) + spike_sum = spike_sum + h + + # Lyapunov (simplified - just on last layer) + if compute_lyapunov: + diff = mems[6] - mems_p[6] if t > 0 else torch.zeros_like(mems[6]) + delta = torch.norm(diff.view(B, -1), dim=1) + 1e-12 + if t > 0: + lyap_accum = lyap_accum + torch.log(delta / lyap_eps + 1e-12) + mems_p[6] = mems[6] + lyap_eps * torch.randn_like(mems[6]) + + out = self.dropout(spike_sum) + logits = self.fc2(out) + + lyap_est = (lyap_accum / T).mean() if compute_lyapunov else None + + return logits, lyap_est, None + + +def create_conv_snn( + model_type: str = 'simple', + **kwargs, +) -> nn.Module: + """ + Factory function for Conv-SNN models. + + Args: + model_type: 'simple' (3-layer) or 'vgg' (6-layer VGG-style) + **kwargs: Arguments passed to model constructor + """ + if model_type == 'simple': + return ConvLyapunovSNN(**kwargs) + elif model_type == 'vgg': + return VGGLyapunovSNN(**kwargs) + else: + raise ValueError(f"Unknown model_type: {model_type}") -- cgit v1.2.3