from typing import Optional, Tuple

import torch
import torch.nn as nn


class SurrogateStep(torch.autograd.Function):
    """
    Heaviside step with a smooth surrogate gradient (fast sigmoid).

    Forward is the exact hard threshold; backward substitutes the
    SuperSpike-style pseudo-derivative 1 / (1 + alpha * |x|)^2.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, alpha: float):
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        return (x > 0).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        alpha = ctx.alpha
        # Fast-sigmoid pseudo-derivative: with s = 1 / (1 + |alpha * x|),
        # use s^2 in place of the true gradient of the hard step, which is
        # zero almost everywhere.
        s = 1.0 / (1.0 + (alpha * x).abs())
        grad = grad_output * s * s
        return grad, None


def surrogate_heaviside(x: torch.Tensor, alpha: float = 5.0) -> torch.Tensor:
    return SurrogateStep.apply(x, alpha)
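

# A small sanity check (illustrative only, not part of the model): the hard
# threshold has zero gradient almost everywhere, but backward() receives the
# fast-sigmoid pseudo-derivative 1 / (1 + alpha * |x|)^2 instead.
def _demo_surrogate_gradient() -> None:
    x = torch.linspace(-2.0, 2.0, steps=5, requires_grad=True)
    y = surrogate_heaviside(x, alpha=5.0)
    y.sum().backward()
    print(x.grad)  # equals 1 / (1 + 5 * |x|)^2 elementwise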


class LIFLayer(nn.Module):
    """
    Single LIF layer, optionally with recurrent synapses between neurons
    (enabled when rec_strength != 0).

    Dynamics per neuron:
        v_t = decay * v_{t-1} + W x_t + rec_strength * W_rec s_{t-1} - v_th * s_{t-1}
        s_t = H(v_t - v_th)   with surrogate gradient
    The -v_th * s_{t-1} term is a soft reset by threshold subtraction.
    """

    def __init__(
        self, input_dim: int, hidden_dim: int,
        v_threshold: float = 1.0, decay: float = 0.95, spike_alpha: float = 5.0,
        rec_strength: float = 0.0, rec_init_scale: float = 1.0,
    ):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim, bias=True)
        self.v_threshold = float(v_threshold)
        self.decay = float(decay)
        self.spike_alpha = float(spike_alpha)
        self.rec_strength = float(rec_strength)
        self.rec = None
        if self.rec_strength != 0.0:
            self.rec = nn.Linear(hidden_dim, hidden_dim, bias=False)
        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
        if self.rec is not None:
            nn.init.xavier_uniform_(self.rec.weight, gain=rec_init_scale)
    def forward(
        self, x_t: torch.Tensor, v_prev: torch.Tensor, s_prev: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # x_t: (B, D_in), v_prev: (B, H), s_prev: (B, H)
        I_t = self.linear(x_t)  # feed-forward input current, (B, H)
        R_t = 0.0
        if self.rec is not None:
            R_t = self.rec_strength * self.rec(s_prev)  # recurrent current
        # Leaky integration with soft reset by threshold subtraction.
        v_t = self.decay * v_prev + I_t + R_t - self.v_threshold * s_prev
        s_t = surrogate_heaviside(v_t - self.v_threshold, alpha=self.spike_alpha)
        return v_t, s_t
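

# Hypothetical usage sketch (dimensions are arbitrary): one LIF update step.
def _demo_lif_step() -> None:
    layer = LIFLayer(input_dim=4, hidden_dim=3)
    x_t = torch.randn(2, 4)  # batch of 2 inputs
    v = torch.zeros(2, 3)    # initial membrane potentials
    s = torch.zeros(2, 3)    # initial spike states
    v, s = layer(x_t, v, s)
    print(v.shape, s.shape)  # both (2, 3); s is binary {0, 1}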


class SimpleSNN(nn.Module):
    """
    Minimal SNN for SHD-like input of shape (B, T, D):
      - one LIF hidden layer
      - linear readout on the time-summed spike counts
    """

    def __init__(
        self, input_dim: int, hidden_dim: int, num_classes: int,
        v_threshold: float = 1.0, decay: float = 0.95, spike_alpha: float = 5.0,
        rec_strength: float = 0.0, rec_init_scale: float = 1.0,
    ):
        super().__init__()
        self.lif = LIFLayer(
            input_dim, hidden_dim,
            v_threshold=v_threshold, decay=decay, spike_alpha=spike_alpha,
            rec_strength=rec_strength, rec_init_scale=rec_init_scale,
        )
        self.readout = nn.Linear(hidden_dim, num_classes)
        nn.init.xavier_uniform_(self.readout.weight)
        nn.init.zeros_(self.readout.bias)

    @torch.no_grad()
    def _init_states(self, batch_size: int, hidden_dim: int, device, dtype):
        v0 = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype)
        s0 = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype)
        return v0, s0
    def forward(
        self,
        x: torch.Tensor,
        compute_lyapunov: bool = False,
        lyap_eps: float = 1e-3,
        lyap_safe_eps: float = 1e-8,
        lyap_measure: str = "v",  # "v" (membrane) or "s" (spikes)
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        x: (B, T, D)
        Returns:
            logits: (B, C)
            lyap_est: scalar tensor if compute_lyapunov else None
        """
        assert x.ndim == 3, f"Expected (B,T,D), got {x.shape}"
        B, T, D = x.shape
        device, dtype = x.device, x.dtype
        H = self.readout.in_features
        v, s = self._init_states(B, H, device, dtype)
        spike_sum = torch.zeros(B, H, device=device, dtype=dtype)
        if compute_lyapunov:
            # Perturb the membrane state and track how fast the two
            # trajectories separate. The initial distance is measured on v
            # even when lyap_measure == "s", since both spike states start
            # identical (distance zero).
            v_p = v + lyap_eps * torch.randn_like(v)
            s_p = s.clone()
            delta_prev = torch.norm((v_p - v).reshape(B, -1), dim=1) + lyap_safe_eps
            lyap_terms = []
        for t in range(T):
            x_t = x[:, t, :]
            v, s = self.lif(x_t, v, s)
            spike_sum = spike_sum + s
            if compute_lyapunov:
                # Run the perturbed trajectory through the same ops. The
                # perturbation is not renormalized between steps, so this is
                # a finite-time estimate that can saturate on long sequences.
                v_p, s_p = self.lif(x_t, v_p, s_p)
                if lyap_measure == "s":
                    delta_t = torch.norm((s_p - s).reshape(B, -1), dim=1) + lyap_safe_eps
                else:
                    delta_t = torch.norm((v_p - v).reshape(B, -1), dim=1) + lyap_safe_eps
                ratio = delta_t / delta_prev
                lyap_terms.append(torch.log(ratio + lyap_safe_eps))
                delta_prev = delta_t
        logits = self.readout(spike_sum)  # (B, C)
        if compute_lyapunov:
            # Per-sample mean log expansion rate per step, then batch mean.
            lyap_batch = torch.stack(lyap_terms, dim=0).mean(dim=0)  # (B,)
            lyap_est = lyap_batch.mean()  # scalar
        else:
            lyap_est = None
        return logits, lyap_est
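

# Hypothetical end-to-end sketch on SHD-like shapes (SHD has 700 input
# channels and 20 classes); hidden size and spike density are illustrative.
def _demo_simple_snn() -> None:
    torch.manual_seed(0)
    model = SimpleSNN(input_dim=700, hidden_dim=128, num_classes=20,
                      rec_strength=0.1)
    x = (torch.rand(2, 10, 700) < 0.05).float()  # B=2, T=10 sparse spike trains
    logits, lyap = model(x, compute_lyapunov=True)
    print(logits.shape)  # torch.Size([2, 20])
    print(lyap)          # scalar finite-time Lyapunov estimate


if __name__ == "__main__":
    _demo_surrogate_gradient()
    _demo_lif_step()
    _demo_simple_snn()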