summaryrefslogtreecommitdiff
path: root/files/data_io/encoders/poisson_encoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'files/data_io/encoders/poisson_encoder.py')
-rw-r--r--files/data_io/encoders/poisson_encoder.py91
1 files changed, 91 insertions, 0 deletions
diff --git a/files/data_io/encoders/poisson_encoder.py b/files/data_io/encoders/poisson_encoder.py
new file mode 100644
index 0000000..0b404f7
--- /dev/null
+++ b/files/data_io/encoders/poisson_encoder.py
@@ -0,0 +1,91 @@
+from typing import Optional, Union
+import numpy as np
+import torch
+from .base_encoder import BaseEncoder
+
+class PoissonEncoder(BaseEncoder):
+ r"""
+ PoissonEncoder: convert static intensities to spike trains via per-time-step Bernoulli sampling.
+
+ Given a static input vector x \in [0,1]^{D}, we produce a spike tensor S \in {0,1}^{T \times D}
+ by sampling at each time step t and dimension d:
+ S[t, d] ~ Bernoulli( p[d] ), where p[d] = clip( x[d] * max_rate * (dt_ms / 1000), 0, 1 ).
+
+ Parameters
+ ----------
+ max_rate : float
+ Maximum firing rate under unit intensity (Hz). Typical: 20~200. Default: 20.
+ T : int
+ Number of discrete time steps in the encoded spike train. Default: 50.
+ dt_ms : float
+ Time resolution per step in milliseconds. Default: 1.0 ms.
+ Effective per-step probability uses factor (dt_ms/1000) to convert Hz to per-step probability.
+ seed : int or None
+ Optional RNG seed for reproducibility. If None, uses global RNG state.
+
+ Notes
+ -----
+ - Input `data` is expected to be a NumPy 1D array (shape [D]) or 2D array ([B, D]).
+ If 1D, we return S with shape (T, D).
+ If 2D (batched), we broadcast probabilities across batch and return (T, B, D).
+ - Intensities outside [0,1] will be clipped to [0,1].
+ - Device of returned tensor follows torch default device (CPU) unless you move it later.
+ """
+
+ def __init__(self, max_rate: float = 20.0, T: int = 50, dt_ms: float = 1.0, seed: Optional[int] = None):
+ super().__init__()
+ self.max_rate = float(max_rate)
+ self.T = int(T)
+ self.dt_ms = float(dt_ms)
+ self.seed = seed
+ # local generator for reproducibility if seed is provided
+ self._g = torch.Generator()
+ if seed is not None:
+ self._g.manual_seed(int(seed))
+
+ def _ensure_numpy(self, data: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
+ if isinstance(data, torch.Tensor):
+ data = data.detach().cpu().numpy()
+ return np.asarray(data)
+
+ def encode(self, data: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ """
+ Convert input intensities to Poisson spike trains.
+
+ Parameters
+ ----------
+ data : np.ndarray or torch.Tensor
+ Shape [D] or [B, D]. Values are assumed in [0,1] (we clip if not).
+
+ Returns
+ -------
+ spikes : torch.Tensor
+ Shape (T, D) or (T, B, D), dtype=torch.float32 with values {0.,1.}.
+ """
+ x = self._ensure_numpy(data)
+ # clip to [0,1]
+ x = np.clip(x, 0.0, 1.0)
+
+ # compute per-step probability
+ p_step = float(self.max_rate) * (self.dt_ms / 1000.0) # probability per step for unit intensity
+ # probability tensor (broadcast-friendly)
+ probs = x * p_step
+ probs = np.clip(probs, 0.0, 1.0)
+
+ # create Bernoulli samples for T steps
+ # If input is 1D: probs.shape = (D,) -> output (T, D)
+ # If input is 2D: probs.shape = (B, D) -> output (T, B, D)
+ if probs.ndim == 1:
+ D = probs.shape[0]
+ probs_t = np.broadcast_to(probs, (self.T, D)) # (T, D)
+ probs_t = torch.from_numpy(probs_t.astype(np.float32))
+ spikes = torch.bernoulli(probs_t, generator=self._g)
+ return spikes
+ elif probs.ndim == 2:
+ B, D = probs.shape
+ probs_t = np.broadcast_to(probs, (self.T, B, D)) # (T, B, D)
+ probs_t = torch.from_numpy(probs_t.astype(np.float32))
+ spikes = torch.bernoulli(probs_t, generator=self._g)
+ return spikes
+ else:
+ raise ValueError(f"PoissonEncoder expects data with ndim 1 or 2, got shape {probs.shape}")