# files/models/snn.py
from typing import Optional, Tuple

import torch
import torch.nn as nn


class SurrogateStep(torch.autograd.Function):
    """
    Heaviside with a smooth surrogate gradient (fast sigmoid).
    """
    @staticmethod
    def forward(ctx, x: torch.Tensor, alpha: float):
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        return (x > 0).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        alpha = ctx.alpha
        # Fast-sigmoid surrogate: in the backward pass the spike is treated as
        # f(x) = x / (1 + alpha*|x|), so grad = grad_output / (1 + alpha*|x|)^2
        # (SuperSpike-style).
        s = 1.0 / (1.0 + (alpha * x).abs())
        grad = grad_output * s * s
        return grad, None


def surrogate_heaviside(x: torch.Tensor, alpha: float = 5.0) -> torch.Tensor:
    return SurrogateStep.apply(x, alpha)
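

# Illustrative sketch (not part of the original file): a quick sanity check of the
# surrogate spike function. The helper name `_demo_surrogate` is hypothetical.
def _demo_surrogate() -> None:
    # Forward pass is a hard threshold; backward pass follows 1 / (1 + alpha*|x|)^2,
    # so gradients are largest near the threshold and decay with |x|.
    x = torch.linspace(-2.0, 2.0, steps=5, requires_grad=True)
    spikes = surrogate_heaviside(x, alpha=5.0)
    assert torch.equal(spikes, (x > 0).float())  # binary forward output
    spikes.sum().backward()
    assert torch.allclose(x.grad, 1.0 / (1.0 + 5.0 * x.detach().abs()) ** 2)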


class LIFLayer(nn.Module):
    """
    Single LIF layer with optional recurrent synapses between neurons.
    Dynamics (r = rec_strength; W_rec is only present when r != 0):
        v_t = decay * v_{t-1} + W x_t + r * W_rec s_{t-1} - v_th * s_{t-1}
        s_t = H( v_t - v_th )   with surrogate gradient
    The last term in v_t is a soft reset: the potential drops by v_th after a spike.
    """
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        v_threshold: float = 1.0,
        decay: float = 0.95,
        spike_alpha: float = 5.0,
        rec_strength: float = 0.0,
        rec_init_scale: float = 1.0,
    ):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim, bias=True)
        self.v_threshold = float(v_threshold)
        self.decay = float(decay)
        self.spike_alpha = float(spike_alpha)
        self.rec_strength = float(rec_strength)
        self.rec = None
        if self.rec_strength != 0.0:
            self.rec = nn.Linear(hidden_dim, hidden_dim, bias=False)

        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
        if self.rec is not None:
            nn.init.xavier_uniform_(self.rec.weight, gain=rec_init_scale)

    def forward(
        self, x_t: torch.Tensor, v_prev: torch.Tensor, s_prev: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # x_t: (B, D_in), v_prev: (B, H), s_prev: (B, H)
        I_t = self.linear(x_t)  # (B, H)
        R_t = 0.0
        if self.rec is not None:
            R_t = self.rec_strength * self.rec(s_prev)
        v_t = self.decay * v_prev + I_t + R_t - self.v_threshold * s_prev
        s_t = surrogate_heaviside(v_t - self.v_threshold, alpha=self.spike_alpha)
        return v_t, s_t
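

# Illustrative sketch (not part of the original file): stepping a LIFLayer over a
# random spike train by hand. The helper name `_demo_lif_rollout` is hypothetical.
def _demo_lif_rollout(batch_size: int = 4, steps: int = 20) -> torch.Tensor:
    layer = LIFLayer(input_dim=16, hidden_dim=32, rec_strength=0.1)
    x = (torch.rand(batch_size, steps, 16) < 0.2).float()  # Bernoulli input spikes
    v = torch.zeros(batch_size, 32)
    s = torch.zeros(batch_size, 32)
    spike_sum = torch.zeros(batch_size, 32)
    for t in range(steps):
        v, s = layer(x[:, t, :], v, s)  # carry (v, s) across time steps
        spike_sum = spike_sum + s
    return spike_sum  # (batch_size, hidden_dim) spike counts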


class SimpleSNN(nn.Module):
    """
    Minimal SNN for SHD-like input (B,T,D):
      - One LIF hidden layer
      - Readout linear on time-summed spikes
    """
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_classes: int,
        v_threshold: float = 1.0,
        decay: float = 0.95,
        spike_alpha: float = 5.0,
        rec_strength: float = 0.0,
        rec_init_scale: float = 1.0,
    ):
        super().__init__()
        self.lif = LIFLayer(
            input_dim,
            hidden_dim,
            v_threshold=v_threshold,
            decay=decay,
            spike_alpha=spike_alpha,
            rec_strength=rec_strength,
            rec_init_scale=rec_init_scale,
        )
        self.readout = nn.Linear(hidden_dim, num_classes)
        nn.init.xavier_uniform_(self.readout.weight)
        nn.init.zeros_(self.readout.bias)

    @torch.no_grad()
    def _init_states(self, batch_size: int, hidden_dim: int, device, dtype):
        v0 = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype)
        s0 = torch.zeros(batch_size, hidden_dim, device=device, dtype=dtype)
        return v0, s0

    def forward(
        self,
        x: torch.Tensor,
        compute_lyapunov: bool = False,
        lyap_eps: float = 1e-3,
        lyap_safe_eps: float = 1e-8,
        lyap_measure: str = "v",  # "v" or "s"
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        x: (B, T, D)
        Returns:
          logits: (B, C)
          lyap_est: scalar finite-perturbation estimate of the largest Lyapunov
            exponent (nats per time step) if compute_lyapunov, else None
        """
        assert x.ndim == 3, f"Expected (B,T,D), got {x.shape}"
        B, T, D = x.shape
        device, dtype = x.device, x.dtype
        H = self.readout.in_features

        v, s = self._init_states(B, H, device, dtype)
        spike_sum = torch.zeros(B, H, device=device, dtype=dtype)

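        # Finite-perturbation Lyapunov estimate: run a second, perturbed copy of the
        # hidden state through the same dynamics and average log(|delta_t| / |delta_{t-1}|)
        # over time. Note that the perturbation is not renormalized between steps
        # (no Benettin-style rescaling), and the initial separation is measured on the
        # membrane potential even when lyap_measure == "s".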
        if compute_lyapunov:
            v_p = v + lyap_eps * torch.randn_like(v)
            s_p = s.clone()
            delta_prev = torch.norm((v_p - v).reshape(B, -1), dim=1) + lyap_safe_eps
            lyap_terms = []

        for t in range(T):
            x_t = x[:, t, :]
            v, s = self.lif(x_t, v, s)
            spike_sum = spike_sum + s

            if compute_lyapunov:
                # run perturbed trajectory through same ops
                v_p, s_p = self.lif(x_t, v_p, s_p)
                if lyap_measure == "s":
                    delta_t = torch.norm((s_p - s).reshape(B, -1), dim=1) + lyap_safe_eps
                else:
                    delta_t = torch.norm((v_p - v).reshape(B, -1), dim=1) + lyap_safe_eps
                ratio = delta_t / delta_prev
                lyap_terms.append(torch.log(ratio + lyap_safe_eps))
                delta_prev = delta_t

        logits = self.readout(spike_sum)  # (B, C)

        if compute_lyapunov:
            lyap_batch = torch.stack(lyap_terms, dim=0).mean(dim=0)  # (B,)
            lyap_est = lyap_batch.mean()  # scalar
        else:
            lyap_est = None

        return logits, lyap_est
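

# Minimal usage sketch (not part of the original file), assuming SHD-like input with
# 700 channels and 20 classes; batch size, sequence length, and hyperparameters here
# are placeholders.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = SimpleSNN(input_dim=700, hidden_dim=128, num_classes=20, rec_strength=0.1)
    x = (torch.rand(8, 100, 700) < 0.05).float()  # (B=8, T=100, D=700) input spike raster
    with torch.no_grad():
        logits, lyap = model(x, compute_lyapunov=True)
    print("logits:", tuple(logits.shape))      # -> (8, 20)
    print("lyapunov estimate:", float(lyap))   # scalar, nats per time step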