from typing import Optional, Tuple

import torch
import torch.nn as nn


class SurrogateStep(torch.autograd.Function):
    """
    Heaviside step with a smooth surrogate gradient (fast sigmoid,
    SuperSpike-style) for backpropagation through spikes.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, alpha: float):
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        return (x > 0).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        alpha = ctx.alpha
        # Fast-sigmoid surrogate: with s = 1 / (1 + |alpha * x|), the
        # surrogate derivative is s^2 = 1 / (1 + |alpha * x|)^2.
        s = 1.0 / (1.0 + (alpha * x).abs())
        grad = grad_output * s * s
        return grad, None


def surrogate_heaviside(x: torch.Tensor, alpha: float = 5.0) -> torch.Tensor:
    return SurrogateStep.apply(x, alpha)


class LIFLayer(nn.Module):
    """
    Single LIF layer with optional recurrent synapses between neurons
    (enabled when rec_strength != 0).

    Dynamics per neuron (soft reset):
        v_t = decay * v_{t-1} + W x_t + R_t - v_th * s_{t-1}
        s_t = H(v_t - v_th)   with surrogate gradient
    where R_t is the recurrent input (zero if recurrence is disabled).
    """

    def __init__(self, input_dim: int, hidden_dim: int,
                 v_threshold: float = 1.0, decay: float = 0.95,
                 spike_alpha: float = 5.0, rec_strength: float = 0.0,
                 rec_init_scale: float = 1.0):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim, bias=True)
        self.v_threshold = float(v_threshold)
        self.decay = float(decay)
        self.spike_alpha = float(spike_alpha)
        self.rec_strength = float(rec_strength)
        self.rec = None
        if self.rec_strength != 0.0:
            self.rec = nn.Linear(hidden_dim, hidden_dim, bias=False)
        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
        if self.rec is not None:
            nn.init.xavier_uniform_(self.rec.weight, gain=rec_init_scale)

    def forward(self, x_t: torch.Tensor, v_prev: torch.Tensor,
                s_prev: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # x_t: (B, D_in), v_prev: (B, H), s_prev: (B, H)
        I_t = self.linear(x_t)  # feedforward input current, (B, H)
        R_t = 0.0
        if self.rec is not None:
            R_t = self.rec_strength * self.rec(s_prev)  # recurrent input from previous spikes
        v_t = self.decay * v_prev + I_t + R_t - self.v_threshold * s_prev
        s_t = surrogate_heaviside(v_t - self.v_threshold, alpha=self.spike_alpha)
        return v_t, s_t


class SimpleSNN(nn.Module):
    """
    Minimal SNN for SHD-like input (B, T, D):
      - one LIF hidden layer
      - linear readout on the time-summed spike counts
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int,
                 v_threshold: float = 1.0, decay: float = 0.95,
                 spike_alpha: float = 5.0, rec_strength: float = 0.0,
                 rec_init_scale: float = 1.0):
        super().__init__()
        self.lif = LIFLayer(input_dim, hidden_dim,
                            v_threshold=v_threshold, decay=decay,
                            spike_alpha=spike_alpha,
                            rec_strength=rec_strength,
                            rec_init_scale=rec_init_scale)
        self.readout = nn.Linear(hidden_dim, num_classes)
        nn.init.xavier_uniform_(self.readout.weight)
        nn.init.zeros_(self.readout.bias)

    @torch.no_grad()
    def _init_states(self, batch_size: int, hidden_dim: int, device, dtype):
        v0 = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype)
        s0 = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype)
        return v0, s0

    def forward(
        self,
        x: torch.Tensor,
        compute_lyapunov: bool = False,
        lyap_eps: float = 1e-3,
        lyap_safe_eps: float = 1e-8,
        lyap_measure: str = "v",  # "v" (membrane potentials) or "s" (spikes)
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        x: (B, T, D)
        Returns:
            logits: (B, C)
            lyap_est: scalar tensor if compute_lyapunov else None
        """
        assert x.ndim == 3, f"Expected (B,T,D), got {x.shape}"
        B, T, D = x.shape
        device, dtype = x.device, x.dtype
        H = self.readout.in_features

        v, s = self._init_states(B, H, device, dtype)
        spike_sum = torch.zeros(B, H, device=device, dtype=dtype)

        if compute_lyapunov:
            # Perturb the membrane potentials only; spike states start identical.
            v_p = v + lyap_eps * torch.randn_like(v)
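            # The perturbed copy (v_p, s_p) is advanced through the same LIF
            # updates as (v, s) below; averaging the per-step log growth of
            # their separation yields a finite-time Lyapunov-exponent estimate.
            # Note this is a Benettin-style estimator *without* renormalizing
            # the perturbation each step, so the separation may saturate on
            # long sequences.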
            s_p = s.clone()
            # The initial separation is measured on v regardless of
            # lyap_measure, since only v is perturbed and the spike
            # separation starts at exactly zero.
            delta_prev = torch.norm((v_p - v).reshape(B, -1), dim=1) + lyap_safe_eps
            lyap_terms = []

        for t in range(T):
            x_t = x[:, t, :]
            v, s = self.lif(x_t, v, s)
            spike_sum = spike_sum + s
            if compute_lyapunov:
                # Run the perturbed trajectory through the same ops.
                v_p, s_p = self.lif(x_t, v_p, s_p)
                if lyap_measure == "s":
                    delta_t = torch.norm((s_p - s).reshape(B, -1), dim=1) + lyap_safe_eps
                else:
                    delta_t = torch.norm((v_p - v).reshape(B, -1), dim=1) + lyap_safe_eps
                ratio = delta_t / delta_prev
                lyap_terms.append(torch.log(ratio + lyap_safe_eps))
                delta_prev = delta_t

        logits = self.readout(spike_sum)  # (B, C)

        if compute_lyapunov:
            lyap_batch = torch.stack(lyap_terms, dim=0).mean(dim=0)  # (B,): mean log growth per step
            lyap_est = lyap_batch.mean()  # scalar, averaged over the batch
        else:
            lyap_est = None
        return logits, lyap_est
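

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch only; the shapes below are
    # assumptions for SHD-like data, not loaded from the real dataset):
    # push random spike trains through the model, with and without the
    # Lyapunov estimate.
    torch.manual_seed(0)
    model = SimpleSNN(input_dim=700, hidden_dim=128, num_classes=20,
                      rec_strength=0.5)
    # (B=4, T=100, D=700) random binary spikes at ~5% firing probability.
    x = (torch.rand(4, 100, 700) < 0.05).float()

    logits, _ = model(x)
    print(logits.shape)  # torch.Size([4, 20])

    logits, lyap = model(x, compute_lyapunov=True, lyap_measure="v")
    print(float(lyap))  # finite-time Lyapunov estimate (log growth per step)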