summaryrefslogtreecommitdiff
path: root/files/data_io/transforms/normalization.py
blob: c86b8c7d7520397d4e5c419c93733ea32abfd5b9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch

class Normalize:
    """Per-sample, per-channel normalization of spike tensors over time.

    Supports z-score standardization and min-max scaling. Statistics are
    computed independently for every sample and every channel along the
    time axis, so each channel is normalized by its own history only.
    """

    # Closed set of accepted modes; membership checked in __init__.
    _VALID_MODES = frozenset({"zscore", "minmax"})

    def __init__(self, mode: str = "zscore", eps: float = 1e-6):
        """
        Parameters
        ----------
        mode : str
            Either "zscore" (standardize each channel to zero mean and unit
            population variance over time) or "minmax" (rescale each channel
            to the [0, 1] span of its own min/max over time). Case-insensitive.
        eps : float
            Small constant guarding against division by zero.
        """
        normalized_mode = str(mode).lower()
        if normalized_mode not in self._VALID_MODES:
            raise ValueError(f"Normalize mode must be 'zscore' or 'minmax', got {mode}")
        self.mode = normalized_mode
        self.eps = float(eps)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` along its time axis.

        Accepts tensors shaped (T, D) or (B, T, D); statistics are reduced
        over T separately for each sample and channel, with the reduced
        dimension kept so broadcasting lines back up against ``x``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Normalize expects a torch.Tensor")
        if x.ndim not in (2, 3):
            raise ValueError(f"Expected x.ndim in {{2,3}}, got {x.ndim}")
        # Time axis sits two positions from the end: dim 0 for (T, D),
        # dim 1 for (B, T, D).
        axis = x.ndim - 2

        if self.mode == "minmax":
            lo = x.amin(dim=axis, keepdim=True)
            hi = x.amax(dim=axis, keepdim=True)
            # clamp_min keeps constant channels finite: (x - lo) is 0 there.
            return (x - lo) / (hi - lo).clamp_min(self.eps)

        # z-score: population std (unbiased=False) stays finite for tiny T.
        center = x.mean(dim=axis, keepdim=True)
        spread = x.std(dim=axis, keepdim=True, unbiased=False)
        return (x - center) / (spread + self.eps)