diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
| commit | 00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch) | |
| tree | 77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files/data_io/encoders/poisson_encoder.py | |
| parent | c53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff) | |
| parent | cd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff) | |
Merge master into main
Diffstat (limited to 'files/data_io/encoders/poisson_encoder.py')
| -rw-r--r-- | files/data_io/encoders/poisson_encoder.py | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/files/data_io/encoders/poisson_encoder.py b/files/data_io/encoders/poisson_encoder.py new file mode 100644 index 0000000..0b404f7 --- /dev/null +++ b/files/data_io/encoders/poisson_encoder.py @@ -0,0 +1,91 @@ +from typing import Optional, Union +import numpy as np +import torch +from .base_encoder import BaseEncoder + +class PoissonEncoder(BaseEncoder): + r""" + PoissonEncoder: convert static intensities to spike trains via per-time-step Bernoulli sampling. + + Given a static input vector x \in [0,1]^{D}, we produce a spike tensor S \in {0,1}^{T \times D} + by sampling at each time step t and dimension d: + S[t, d] ~ Bernoulli( p[d] ), where p[d] = clip( x[d] * max_rate * (dt_ms / 1000), 0, 1 ). + + Parameters + ---------- + max_rate : float + Maximum firing rate under unit intensity (Hz). Typical: 20~200. Default: 20. + T : int + Number of discrete time steps in the encoded spike train. Default: 50. + dt_ms : float + Time resolution per step in milliseconds. Default: 1.0 ms. + Effective per-step probability uses factor (dt_ms/1000) to convert Hz to per-step probability. + seed : int or None + Optional RNG seed for reproducibility. If None, uses global RNG state. + + Notes + ----- + - Input `data` is expected to be a NumPy 1D array (shape [D]) or 2D array ([B, D]). + If 1D, we return S with shape (T, D). + If 2D (batched), we broadcast probabilities across batch and return (T, B, D). + - Intensities outside [0,1] will be clipped to [0,1]. + - Device of returned tensor follows torch default device (CPU) unless you move it later. + """ + + def __init__(self, max_rate: float = 20.0, T: int = 50, dt_ms: float = 1.0, seed: Optional[int] = None): + super().__init__() + self.max_rate = float(max_rate) + self.T = int(T) + self.dt_ms = float(dt_ms) + self.seed = seed + # local generator for reproducibility if seed is provided + self._g = torch.Generator() + if seed is not None: + self._g.manual_seed(int(seed)) + + def _ensure_numpy(self, data: Union[np.ndarray, torch.Tensor]) -> np.ndarray: + if isinstance(data, torch.Tensor): + data = data.detach().cpu().numpy() + return np.asarray(data) + + def encode(self, data: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + """ + Convert input intensities to Poisson spike trains. + + Parameters + ---------- + data : np.ndarray or torch.Tensor + Shape [D] or [B, D]. Values are assumed in [0,1] (we clip if not). + + Returns + ------- + spikes : torch.Tensor + Shape (T, D) or (T, B, D), dtype=torch.float32 with values {0.,1.}. + """ + x = self._ensure_numpy(data) + # clip to [0,1] + x = np.clip(x, 0.0, 1.0) + + # compute per-step probability + p_step = float(self.max_rate) * (self.dt_ms / 1000.0) # probability per step for unit intensity + # probability tensor (broadcast-friendly) + probs = x * p_step + probs = np.clip(probs, 0.0, 1.0) + + # create Bernoulli samples for T steps + # If input is 1D: probs.shape = (D,) -> output (T, D) + # If input is 2D: probs.shape = (B, D) -> output (T, B, D) + if probs.ndim == 1: + D = probs.shape[0] + probs_t = np.broadcast_to(probs, (self.T, D)) # (T, D) + probs_t = torch.from_numpy(probs_t.astype(np.float32)) + spikes = torch.bernoulli(probs_t, generator=self._g) + return spikes + elif probs.ndim == 2: + B, D = probs.shape + probs_t = np.broadcast_to(probs, (self.T, B, D)) # (T, B, D) + probs_t = torch.from_numpy(probs_t.astype(np.float32)) + spikes = torch.bernoulli(probs_t, generator=self._g) + return spikes + else: + raise ValueError(f"PoissonEncoder expects data with ndim 1 or 2, got shape {probs.shape}") |
