1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
|
from typing import Optional, Union
import numpy as np
import torch
from .base_encoder import BaseEncoder
class PoissonEncoder(BaseEncoder):
    r"""
    PoissonEncoder: convert static intensities to spike trains via per-time-step Bernoulli sampling.

    Given a static input vector x \in [0,1]^{D}, we produce a spike tensor S \in {0,1}^{T \times D}
    by sampling independently at each time step t and dimension d:

        S[t, d] ~ Bernoulli( p[d] ),  where  p[d] = clip( x[d] * max_rate * (dt_ms / 1000), 0, 1 ).

    Parameters
    ----------
    max_rate : float
        Maximum firing rate under unit intensity (Hz). Typical: 20~200. Default: 20.
    T : int
        Number of discrete time steps in the encoded spike train. Default: 50.
    dt_ms : float
        Time resolution per step in milliseconds. Default: 1.0 ms.
        The factor (dt_ms / 1000) converts a rate in Hz into a per-step probability.
    seed : int or None
        Optional RNG seed for reproducibility. If given, a private ``torch.Generator``
        seeded with it drives all sampling, so encoders built with the same seed produce
        the same spike sequence. If None, sampling uses torch's global default generator
        (so ``torch.manual_seed`` controls it).

    Notes
    -----
    - Input `data` may be a NumPy array, a torch tensor, or anything ``np.asarray`` accepts,
      with shape [D] (1D) or [B, D] (2D).
      If 1D, we return S with shape (T, D).
      If 2D (batched), probabilities are repeated along the time axis and S has shape (T, B, D).
    - Intensities outside [0,1] are clipped to [0,1].
    - The returned tensor is always a CPU tensor (it is built with ``torch.from_numpy``);
      move it to another device yourself if needed.
    """

    def __init__(self, max_rate: float = 20.0, T: int = 50, dt_ms: float = 1.0, seed: Optional[int] = None):
        super().__init__()
        self.max_rate = float(max_rate)
        self.T = int(T)
        self.dt_ms = float(dt_ms)
        self.seed = seed
        # A private generator is created only when a seed is requested. A freshly
        # constructed torch.Generator() starts from a fixed built-in seed, so using one
        # in the unseeded case would make every unseeded encoder emit identical spike
        # trains and ignore torch.manual_seed. None makes torch.bernoulli fall back to
        # the global default generator.
        self._g: Optional[torch.Generator] = None
        if seed is not None:
            self._g = torch.Generator()
            self._g.manual_seed(int(seed))

    def _ensure_numpy(self, data: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
        """Return `data` as a NumPy array (torch tensors are detached and copied to CPU first)."""
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        return np.asarray(data)

    def encode(self, data: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """
        Convert input intensities to Poisson spike trains.

        Parameters
        ----------
        data : np.ndarray or torch.Tensor
            Shape [D] or [B, D]. Values are assumed in [0,1] (we clip if not).

        Returns
        -------
        spikes : torch.Tensor
            Shape (T, D) or (T, B, D), dtype=torch.float32 with values {0.,1.}, on CPU.

        Raises
        ------
        ValueError
            If `data` is not 1- or 2-dimensional.
        """
        x = self._ensure_numpy(data)
        # Reject unsupported ranks before doing any work or consuming RNG state.
        if x.ndim not in (1, 2):
            raise ValueError(f"PoissonEncoder expects data with ndim 1 or 2, got shape {x.shape}")

        # Intensities outside [0, 1] are clipped before scaling.
        x = np.clip(x, 0.0, 1.0)
        # Per-step spike probability for unit intensity: rate (Hz) * step length (s).
        p_step = float(self.max_rate) * (self.dt_ms / 1000.0)
        # Clip again: p_step exceeds 1 when max_rate * dt_ms > 1000, or is negative
        # when max_rate < 0; Bernoulli probabilities must lie in [0, 1].
        probs = np.clip(x * p_step, 0.0, 1.0)

        # Repeat the static probabilities along a new leading time axis:
        # (D,) -> (T, D) or (B, D) -> (T, B, D). astype() materializes the broadcast view.
        probs_t = np.broadcast_to(probs, (self.T, *probs.shape)).astype(np.float32)
        # generator=None (unseeded encoder) means torch's global default generator.
        return torch.bernoulli(torch.from_numpy(probs_t), generator=self._g)
|