import torch


class Normalize:
    """Normalize spike input (z-score or min-max)."""

    def __init__(self, mode: str = "zscore", eps: float = 1e-6):
        """
        Parameters
        ----------
        mode : str
            One of {"zscore", "minmax"}.
            - "zscore": per-sample, per-channel standardization across time.
            - "minmax": per-sample, per-channel min-max scaling across time.
        eps : float
            Small constant to avoid division by zero.
        """
        mode_l = str(mode).lower()
        if mode_l not in {"zscore", "minmax"}:
            raise ValueError(f"Normalize mode must be 'zscore' or 'minmax', got {mode}")
        self.mode = mode_l
        self.eps = float(eps)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply normalization.

        Accepts tensors of shape (T, D) or (B, T, D). Normalization is computed
        per-sample and per-channel over the time axis T.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Normalize expects a torch.Tensor")

        if x.ndim == 2:
            # (T, D) -> time dim = 0
            time_dim = 0
        elif x.ndim == 3:
            # (B, T, D) -> time dim = 1
            time_dim = 1
        else:
            raise ValueError(f"Expected x.ndim in {{2,3}}, got {x.ndim}")

        if self.mode == "zscore":
            mean_t = x.mean(dim=time_dim, keepdim=True)
            # Population std (unbiased=False) to avoid NaNs for small T.
            std_t = x.std(dim=time_dim, keepdim=True, unbiased=False)
            x_norm = (x - mean_t) / (std_t + self.eps)
            return x_norm

        # "minmax"
        min_t = x.amin(dim=time_dim, keepdim=True)
        max_t = x.amax(dim=time_dim, keepdim=True)
        denom = (max_t - min_t).clamp_min(self.eps)
        x_scaled = (x - min_t) / denom
        return x_scaled
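

# A minimal usage sketch (not part of the class): applies both modes to
# synthetic spike data. The shapes and random tensors below are illustrative
# assumptions, not values prescribed by this module.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Single sample, assumed shape (T, D) = (100 time bins, 32 channels).
    x_single = torch.rand(100, 32)
    z = Normalize(mode="zscore")(x_single)
    # Per-channel mean ~0 and population std ~1 over the time axis.
    print(z.mean(dim=0).abs().max().item(), z.std(dim=0, unbiased=False).mean().item())

    # Batched samples, assumed shape (B, T, D) = (4, 100, 32).
    x_batch = torch.rand(4, 100, 32)
    m = Normalize(mode="minmax")(x_batch)
    # Each sample/channel is rescaled into [0, 1] over its own time axis.
    print(m.amin(dim=1).min().item(), m.amax(dim=1).max().item())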