| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:49:05 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:49:05 -0600 |
| commit | cd99d6b874d9d09b3bb87b8485cc787885af71f1 (patch) | |
| tree | 59a233959932ca0e4f12f196275e07fcf443b33f /files/models/snn.py | |
init commit
Diffstat (limited to 'files/models/snn.py')
| -rw-r--r-- | files/models/snn.py | 141 |
1 file changed, 141 insertions, 0 deletions
```diff
diff --git a/files/models/snn.py b/files/models/snn.py
new file mode 100644
index 0000000..b1cf633
--- /dev/null
+++ b/files/models/snn.py
@@ -0,0 +1,141 @@
+import math
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SurrogateStep(torch.autograd.Function):
+    """
+    Heaviside with a smooth surrogate gradient (fast sigmoid).
+    """
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, alpha: float):
+        ctx.save_for_backward(x)
+        ctx.alpha = alpha
+        return (x > 0).to(x.dtype)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (x,) = ctx.saved_tensors
+        alpha = ctx.alpha
+        # d/dx sigmoid(alpha*x) ~ alpha * sigmoid * (1 - sigmoid)
+        # Use fast sigmoid: s = 1 / (1 + |alpha*x|)
+        s = 1.0 / (1.0 + (alpha * x).abs())
+        grad = grad_output * s * s
+        return grad, None
+
+
+def surrogate_heaviside(x: torch.Tensor, alpha: float = 5.0) -> torch.Tensor:
+    return SurrogateStep.apply(x, alpha)
+
+
+class LIFLayer(nn.Module):
+    """
+    Single LIF layer with optional recurrent synapses between neurons.
+    Dynamics per neuron i:
+        v_t = decay * v_{t-1} + W x_t + rec_strength * W_rec s_{t-1} - v_th * s_{t-1}
+        s_t = H( v_t - v_th )   with surrogate gradient
+    """
+    def __init__(self, input_dim: int, hidden_dim: int, v_threshold: float = 1.0, decay: float = 0.95, spike_alpha: float = 5.0, rec_strength: float = 0.0, rec_init_scale: float = 1.0):
+        super().__init__()
+        self.linear = nn.Linear(input_dim, hidden_dim, bias=True)
+        self.v_threshold = float(v_threshold)
+        self.decay = float(decay)
+        self.spike_alpha = float(spike_alpha)
+        self.rec_strength = float(rec_strength)
+        self.rec = None
+        if self.rec_strength != 0.0:
+            self.rec = nn.Linear(hidden_dim, hidden_dim, bias=False)
+
+        nn.init.xavier_uniform_(self.linear.weight)
+        nn.init.zeros_(self.linear.bias)
+        if self.rec is not None:
+            nn.init.xavier_uniform_(self.rec.weight, gain=rec_init_scale)
+
+    def forward(self, x_t: torch.Tensor, v_prev: torch.Tensor, s_prev: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        # x_t: (B, D_in), v_prev: (B, H), s_prev: (B, H)
+        I_t = self.linear(x_t)  # (B, H)
+        R_t = 0.0
+        if self.rec is not None:
+            R_t = self.rec_strength * self.rec(s_prev)
+        v_t = self.decay * v_prev + I_t + R_t - self.v_threshold * s_prev
+        s_t = surrogate_heaviside(v_t - self.v_threshold, alpha=self.spike_alpha)
+        return v_t, s_t
+
+
+class SimpleSNN(nn.Module):
+    """
+    Minimal SNN for SHD-like input (B,T,D):
+    - One LIF hidden layer
+    - Readout linear on time-summed spikes
+    """
+    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, v_threshold: float = 1.0, decay: float = 0.95, spike_alpha: float = 5.0, rec_strength: float = 0.0, rec_init_scale: float = 1.0):
+        super().__init__()
+        self.lif = LIFLayer(input_dim, hidden_dim, v_threshold=v_threshold, decay=decay, spike_alpha=spike_alpha, rec_strength=rec_strength, rec_init_scale=rec_init_scale)
+        self.readout = nn.Linear(hidden_dim, num_classes)
+        nn.init.xavier_uniform_(self.readout.weight)
+        nn.init.zeros_(self.readout.bias)
+
+    @torch.no_grad()
+    def _init_states(self, batch_size: int, hidden_dim: int, device, dtype):
+        v0 = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype)
+        s0 = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype)
+        return v0, s0
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        compute_lyapunov: bool = False,
+        lyap_eps: float = 1e-3,
+        lyap_safe_eps: float = 1e-8,
+        lyap_measure: str = "v",  # "v" or "s"
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        x: (B, T, D)
+        Returns:
+            logits: (B, C)
+            lyap_est: scalar tensor if compute_lyapunov else None
+        """
+        assert x.ndim == 3, f"Expected (B,T,D), got {x.shape}"
+        B, T, D = x.shape
+        device, dtype = x.device, x.dtype
+        H = self.readout.in_features
+
+        v, s = self._init_states(B, H, device, dtype)
+        spike_sum = torch.zeros(B, H, device=device, dtype=dtype)
+
+        if compute_lyapunov:
+            v_p = v + lyap_eps * torch.randn_like(v)
+            s_p = s.clone()
+            delta_prev = torch.norm((v_p - v).reshape(B, -1), dim=1) + lyap_safe_eps
+            lyap_terms = []
+
+        for t in range(T):
+            x_t = x[:, t, :]
+            v, s = self.lif(x_t, v, s)
+            spike_sum = spike_sum + s
+
+            if compute_lyapunov:
+                # run perturbed trajectory through same ops
+                v_p, s_p = self.lif(x_t, v_p, s_p)
+                if lyap_measure == "s":
+                    delta_t = torch.norm((s_p - s).reshape(B, -1), dim=1) + lyap_safe_eps
+                else:
+                    delta_t = torch.norm((v_p - v).reshape(B, -1), dim=1) + lyap_safe_eps
+                ratio = delta_t / delta_prev
+                lyap_terms.append(torch.log(ratio + lyap_safe_eps))
+                delta_prev = delta_t
+
+        logits = self.readout(spike_sum)  # (B, C)
+
+        if compute_lyapunov:
+            lyap_batch = torch.stack(lyap_terms, dim=0).mean(dim=0)  # (B,)
+            lyap_est = lyap_batch.mean()  # scalar
+        else:
+            lyap_est = None
+
+        return logits, lyap_est
```
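
As a quick sanity check on the surrogate gradient above, the sketch below is illustrative only; it assumes the repository root is on `PYTHONPATH` and that `files/models` is importable as a package. It confirms that `surrogate_heaviside` emits hard 0/1 spikes in the forward pass while still propagating a smooth, non-zero gradient backward.

```python
# Illustrative sketch; assumes `files.models.snn` is importable from the repo root.
import torch

from files.models.snn import surrogate_heaviside

x = torch.linspace(-2.0, 2.0, steps=9, requires_grad=True)
spikes = surrogate_heaviside(x, alpha=5.0)
print(spikes)   # hard threshold: exactly 0.0 at or below the origin, 1.0 above it

# The backward pass uses the fast-sigmoid surrogate 1 / (1 + |alpha * x|)^2,
# so gradients are non-zero even though the step function is flat almost everywhere.
spikes.sum().backward()
print(x.grad)   # peaks near x = 0 and decays with distance from the threshold
```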

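For context on the `compute_lyapunov` path: the model evolves a second, perturbed copy of the hidden state through the same dynamics and averages the per-step log expansion ratios, roughly λ ≈ (1/T) · Σ_t log(‖δ_t‖ / ‖δ_{t-1}‖), so a negative estimate suggests perturbations shrink on average. The usage sketch below is illustrative only; the SHD-style shapes (700 input channels, 20 classes), batch size, and hyperparameters are assumptions, not values taken from this repository.

```python
# Illustrative usage sketch; shapes follow the SHD convention (700 channels, 20 classes).
import torch

from files.models.snn import SimpleSNN

model = SimpleSNN(
    input_dim=700,       # SHD-like spike channels
    hidden_dim=128,
    num_classes=20,
    decay=0.95,
    rec_strength=0.1,    # enable the optional recurrent synapses
)

x = torch.rand(8, 100, 700)  # (B, T, D) stand-in input; real data would be binned spike counts

logits, lyap_est = model(x, compute_lyapunov=True, lyap_measure="v")
print(logits.shape)  # torch.Size([8, 20])
print(lyap_est)      # scalar finite-difference Lyapunov estimate

# When the estimate is not requested, the second return value is None.
logits, no_est = model(x, compute_lyapunov=False)
```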