summaryrefslogtreecommitdiff
path: root/files/data_io/encoders/poisson_encoder.py
blob: 0b404f70d3ca75d49a4a6dc8f70d914f80e2d3e5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from typing import Optional, Union
import numpy as np
import torch
from .base_encoder import BaseEncoder

class PoissonEncoder(BaseEncoder):
    r"""
    PoissonEncoder: convert static intensities to spike trains via per-time-step Bernoulli sampling.

    Given a static input vector x \in [0,1]^{D}, we produce a spike tensor S \in {0,1}^{T \times D}
    by sampling independently at each time step t and dimension d:
        S[t, d] ~ Bernoulli( p[d] ),  where  p[d] = clip( x[d] * max_rate * (dt_ms / 1000), 0, 1 ).

    Parameters
    ----------
    max_rate : float
        Maximum firing rate under unit intensity (Hz). Typical: 20~200. Default: 20.
    T : int
        Number of discrete time steps in the encoded spike train. Default: 50.
    dt_ms : float
        Time resolution per step in milliseconds. Default: 1.0 ms.
        The factor (dt_ms / 1000) converts a rate in Hz into a per-step spike probability;
        if max_rate * dt_ms / 1000 exceeds 1 the probability saturates at 1.
    seed : int or None
        Optional RNG seed for reproducibility. If given, the encoder owns a private
        ``torch.Generator`` seeded with it, so a freshly built encoder always produces the
        same sequence of spike trains. If None, sampling uses torch's global RNG
        (and is therefore controlled by ``torch.manual_seed``).

    Notes
    -----
    - Input `data` is expected to be a 1D (shape [D]) or 2D ([B, D]) NumPy array or torch tensor.
      If 1D, we return S with shape (T, D).
      If 2D (batched), each sample's probabilities are repeated over time and we return (T, B, D).
    - Intensities outside [0,1] are clipped to [0,1].
    - The returned tensor is always a CPU tensor (torch inputs are detached and copied to the
      CPU before encoding); move it to another device afterwards if needed.
    """

    def __init__(self, max_rate: float = 20.0, T: int = 50, dt_ms: float = 1.0, seed: Optional[int] = None):
        super().__init__()
        self.max_rate = float(max_rate)
        self.T = int(T)
        self.dt_ms = float(dt_ms)
        self.seed = seed
        # Private generator only when a seed is requested. An unseeded ``torch.Generator()``
        # starts from a fixed default seed, so every such encoder would emit the *same*
        # spike trains; leaving it as None makes ``torch.bernoulli`` use the global RNG.
        self._g: Optional[torch.Generator] = None
        if seed is not None:
            self._g = torch.Generator()
            self._g.manual_seed(int(seed))

    def _ensure_numpy(self, data: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
        """Return `data` as a NumPy array (torch tensors are detached and moved to CPU first)."""
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        return np.asarray(data)

    def encode(self, data: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """
        Convert input intensities to Poisson spike trains.

        Parameters
        ----------
        data : np.ndarray or torch.Tensor
            Shape [D] or [B, D]. Values are assumed in [0,1] (we clip if not).

        Returns
        -------
        spikes : torch.Tensor
            CPU tensor of shape (T, D) or (T, B, D), dtype=torch.float32 with values {0.,1.}.

        Raises
        ------
        ValueError
            If `data` is not 1- or 2-dimensional.
        """
        # Work in floating point so integer/bool inputs clip and scale as real intensities.
        x = self._ensure_numpy(data).astype(np.float64, copy=False)
        if x.ndim not in (1, 2):
            raise ValueError(f"PoissonEncoder expects data with ndim 1 or 2, got shape {x.shape}")
        x = np.clip(x, 0.0, 1.0)

        # Per-step spike probability for unit intensity (Hz -> probability per dt_ms step).
        p_step = self.max_rate * (self.dt_ms / 1000.0)
        probs = np.clip(x * p_step, 0.0, 1.0)

        # Repeat the same probabilities at every time step: (D,) -> (T, D), (B, D) -> (T, B, D).
        # astype() materialises the read-only broadcast view into a contiguous, writable
        # array that torch.from_numpy can wrap safely.
        probs_t = np.broadcast_to(probs, (self.T,) + probs.shape).astype(np.float32)

        # generator=None -> global torch RNG; otherwise the encoder's seeded generator.
        return torch.bernoulli(torch.from_numpy(probs_t), generator=self._g)