from typing import Optional, Union

import numpy as np
import torch

from .base_encoder import BaseEncoder


class PoissonEncoder(BaseEncoder):
    r"""
    PoissonEncoder: convert static intensities to spike trains via
    per-time-step Bernoulli sampling.

    Given a static input vector x \in [0,1]^{D}, we produce a spike tensor
    S \in {0,1}^{T \times D} by sampling at each time step t and dimension d:

        S[t, d] ~ Bernoulli( p[d] ),
        where p[d] = clip( x[d] * max_rate * (dt_ms / 1000), 0, 1 ).

    Parameters
    ----------
    max_rate : float
        Maximum firing rate under unit intensity (Hz). Typical: 20~200.
        Default: 20.
    T : int
        Number of discrete time steps in the encoded spike train. Default: 50.
    dt_ms : float
        Time resolution per step in milliseconds. Default: 1.0 ms.
        The factor (dt_ms / 1000) converts a rate in Hz to a per-step
        probability.
    seed : int or None
        Optional RNG seed for reproducibility. If given, the encoder samples
        from its own private ``torch.Generator`` seeded with this value.
        If None, sampling uses torch's global RNG state (and therefore
        honours ``torch.manual_seed``).

    Notes
    -----
    - Input `data` is expected to be a 1D array (shape [D]) or a 2D array
      ([B, D]), given as a NumPy array or a torch tensor.
      If 1D, we return S with shape (T, D).
      If 2D (batched), we broadcast probabilities across time and return
      (T, B, D).
    - Intensities outside [0,1] will be clipped to [0,1].
    - The returned tensor is created from a NumPy array and therefore lives
      on the CPU; move it to another device afterwards if needed.
    """

    def __init__(self,
                 max_rate: float = 20.0,
                 T: int = 50,
                 dt_ms: float = 1.0,
                 seed: Optional[int] = None):
        super().__init__()
        self.max_rate = float(max_rate)
        self.T = int(T)
        self.dt_ms = float(dt_ms)
        self.seed = seed

        # A freshly constructed torch.Generator always starts from the same
        # fixed default seed, so creating one unconditionally would make every
        # unseeded encoder emit the identical "random" spike train. Only build
        # a private generator when a seed was requested; ``None`` tells
        # torch.bernoulli to draw from the global RNG instead.
        self._g: Optional[torch.Generator] = None
        if seed is not None:
            self._g = torch.Generator()
            self._g.manual_seed(int(seed))

    def _ensure_numpy(self, data: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
        """Return `data` as a NumPy array, detaching/copying tensors to CPU."""
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        return np.asarray(data)

    def encode(self, data: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """
        Convert input intensities to Poisson spike trains.

        Parameters
        ----------
        data : np.ndarray or torch.Tensor
            Shape [D] or [B, D]. Values are assumed in [0,1] (we clip if not).

        Returns
        -------
        spikes : torch.Tensor
            Shape (T, D) or (T, B, D), dtype=torch.float32 with values {0.,1.}.

        Raises
        ------
        ValueError
            If `data` is neither 1D nor 2D.
        """
        # clip intensities to [0,1]
        x = np.clip(self._ensure_numpy(data), 0.0, 1.0)
        if x.ndim not in (1, 2):
            raise ValueError(f"PoissonEncoder expects data with ndim 1 or 2, got shape {x.shape}")

        # spike probability per step for unit intensity (Hz -> per-step)
        p_step = self.max_rate * (self.dt_ms / 1000.0)
        # re-clip because max_rate * dt can push the probability above 1
        probs = np.clip(x * p_step, 0.0, 1.0)

        # Prepend the time axis: (D,) -> (T, D), (B, D) -> (T, B, D).
        # astype() materialises the read-only broadcast view into a fresh
        # float32 array that torch.from_numpy can wrap safely.
        probs_t = np.broadcast_to(probs, (self.T, *probs.shape))
        probs_t = torch.from_numpy(probs_t.astype(np.float32))

        # One independent Bernoulli draw per (time step, element).
        return torch.bernoulli(probs_t, generator=self._g)