Diffstat (limited to 'files/models/snn.py')
-rw-r--r--   files/models/snn.py   141
1 file changed, 141 insertions, 0 deletions
diff --git a/files/models/snn.py b/files/models/snn.py
new file mode 100644
index 0000000..b1cf633
--- /dev/null
+++ b/files/models/snn.py
@@ -0,0 +1,141 @@
+import math
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SurrogateStep(torch.autograd.Function):
+ """
+ Heaviside with a smooth surrogate gradient (fast sigmoid).
+ """
+ @staticmethod
+ def forward(ctx, x: torch.Tensor, alpha: float):
+ ctx.save_for_backward(x)
+ ctx.alpha = alpha
+ return (x > 0).to(x.dtype)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (x,) = ctx.saved_tensors
+        alpha = ctx.alpha
+        # Fast-sigmoid (SuperSpike-style) surrogate derivative:
+        # dH/dx ~= 1 / (1 + alpha * |x|)^2
+        s = 1.0 / (1.0 + (alpha * x).abs())
+        grad = grad_output * s * s
+        return grad, None
+
+
+def surrogate_heaviside(x: torch.Tensor, alpha: float = 5.0) -> torch.Tensor:
+    return SurrogateStep.apply(x, alpha)
+
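+# Example (minimal sketch) of how gradients pass through the hard threshold:
+#   x = torch.randn(4, requires_grad=True)
+#   surrogate_heaviside(x, alpha=5.0).sum().backward()
+#   x.grad  # ~ 1 / (1 + 5 * |x|)^2, nonzero even though the forward output is binary
+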
+
+class LIFLayer(nn.Module):
+ """
+ Single LIF layer without recurrent synapses between neurons.
+ Dynamics per neuron i:
+ v_t = decay * v_{t-1} + W x_t - v_th * s_{t-1}
+ s_t = H( v_t - v_th ) with surrogate gradient
+ """
+ def __init__(self, input_dim: int, hidden_dim: int, v_threshold: float = 1.0, decay: float = 0.95, spike_alpha: float = 5.0, rec_strength: float = 0.0, rec_init_scale: float = 1.0):
+ super().__init__()
+ self.linear = nn.Linear(input_dim, hidden_dim, bias=True)
+ self.v_threshold = float(v_threshold)
+ self.decay = float(decay)
+ self.spike_alpha = float(spike_alpha)
+ self.rec_strength = float(rec_strength)
+ self.rec = None
+ if self.rec_strength != 0.0:
+ self.rec = nn.Linear(hidden_dim, hidden_dim, bias=False)
+
+ nn.init.xavier_uniform_(self.linear.weight)
+ nn.init.zeros_(self.linear.bias)
+ if self.rec is not None:
+ nn.init.xavier_uniform_(self.rec.weight, gain=rec_init_scale)
+
+    def forward(
+        self, x_t: torch.Tensor, v_prev: torch.Tensor, s_prev: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # x_t: (B, D_in), v_prev: (B, H), s_prev: (B, H)
+        I_t = self.linear(x_t)  # (B, H)
+        R_t = 0.0
+        if self.rec is not None:
+            R_t = self.rec_strength * self.rec(s_prev)
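+        # Soft reset: subtracting v_threshold * s_prev lowers the potential of
+        # neurons that spiked at t-1 instead of clamping it back to zero.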
+        v_t = self.decay * v_prev + I_t + R_t - self.v_threshold * s_prev
+        s_t = surrogate_heaviside(v_t - self.v_threshold, alpha=self.spike_alpha)
+        return v_t, s_t
+
+
+class SimpleSNN(nn.Module):
+ """
+ Minimal SNN for SHD-like input (B,T,D):
+ - One LIF hidden layer
+ - Readout linear on time-summed spikes
+ """
+ def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, v_threshold: float = 1.0, decay: float = 0.95, spike_alpha: float = 5.0, rec_strength: float = 0.0, rec_init_scale: float = 1.0):
+ super().__init__()
+ self.lif = LIFLayer(input_dim, hidden_dim, v_threshold=v_threshold, decay=decay, spike_alpha=spike_alpha, rec_strength=rec_strength, rec_init_scale=rec_init_scale)
+ self.readout = nn.Linear(hidden_dim, num_classes)
+ nn.init.xavier_uniform_(self.readout.weight)
+ nn.init.zeros_(self.readout.bias)
+
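+    # no_grad only affects the zero-initialized states below; gradients still
+    # flow through the recurrence because forward() rebinds v and s each step.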
+    @torch.no_grad()
+    def _init_states(self, batch_size: int, hidden_dim: int, device, dtype):
+        v0 = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype)
+        s0 = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype)
+        return v0, s0
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        compute_lyapunov: bool = False,
+        lyap_eps: float = 1e-3,
+        lyap_safe_eps: float = 1e-8,
+        lyap_measure: str = "v",  # "v" or "s"
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        x: (B, T, D)
+        Returns:
+            logits: (B, C)
+            lyap_est: scalar tensor if compute_lyapunov else None
+        """
+        assert x.ndim == 3, f"Expected (B,T,D), got {x.shape}"
+        B, T, D = x.shape
+        device, dtype = x.device, x.dtype
+        H = self.readout.in_features
+
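+        # spike_sum accumulates per-neuron spike counts over time (a rate code);
+        # the linear readout maps these summed counts to class logits.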
+        v, s = self._init_states(B, H, device, dtype)
+        spike_sum = torch.zeros(B, H, device=device, dtype=dtype)
+
+        if compute_lyapunov:
+            v_p = v + lyap_eps * torch.randn_like(v)
+            s_p = s.clone()
+            delta_prev = torch.norm((v_p - v).reshape(B, -1), dim=1) + lyap_safe_eps
+            lyap_terms = []
+
+        for t in range(T):
+            x_t = x[:, t, :]
+            v, s = self.lif(x_t, v, s)
+            spike_sum = spike_sum + s
+
+            if compute_lyapunov:
+                # run perturbed trajectory through same ops
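+                # (no renormalization of the perturbation between steps, unlike the
+                # classic Benettin procedure; keep lyap_eps small and T moderate so
+                # the divergence estimate does not saturate)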
+                v_p, s_p = self.lif(x_t, v_p, s_p)
+                if lyap_measure == "s":
+                    delta_t = torch.norm((s_p - s).reshape(B, -1), dim=1) + lyap_safe_eps
+                else:
+                    delta_t = torch.norm((v_p - v).reshape(B, -1), dim=1) + lyap_safe_eps
+                ratio = delta_t / delta_prev
+                lyap_terms.append(torch.log(ratio + lyap_safe_eps))
+                delta_prev = delta_t
+
+        logits = self.readout(spike_sum)  # (B, C)
+
+        if compute_lyapunov:
+            lyap_batch = torch.stack(lyap_terms, dim=0).mean(dim=0)  # (B,)
+            lyap_est = lyap_batch.mean()  # scalar
+        else:
+            lyap_est = None
+
+        return logits, lyap_est
+
+
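+if __name__ == "__main__":
+    # Minimal smoke-test sketch. The 700-channel / 20-class / 100-step shapes are
+    # illustrative assumptions (SHD-like), not values fixed anywhere in this file.
+    torch.manual_seed(0)
+    model = SimpleSNN(input_dim=700, hidden_dim=128, num_classes=20, rec_strength=0.1)
+    x = torch.rand(4, 100, 700)  # (B, T, D) stand-in for binned input spikes
+    logits, lyap = model(x, compute_lyapunov=True)
+    print(logits.shape, None if lyap is None else float(lyap))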