summary | refs | log | tree | commit | diff
path: root/files/data_io/transforms/normalization.py
diff options
context:
space:
mode:
Diffstat (limited to 'files/data_io/transforms/normalization.py')
-rw-r--r-- files/data_io/transforms/normalization.py 52
1 files changed, 52 insertions, 0 deletions
diff --git a/files/data_io/transforms/normalization.py b/files/data_io/transforms/normalization.py
new file mode 100644
index 0000000..c86b8c7
--- /dev/null
+++ b/files/data_io/transforms/normalization.py
@@ -0,0 +1,52 @@
+import torch
+
class Normalize:
    """Normalize spike input over the time axis (z-score or min-max).

    Statistics are computed independently for every sample and every channel,
    reducing only over the time axis ``T`` of a ``(T, D)`` or ``(B, T, D)``
    tensor.

    Parameters
    ----------
    mode : str
        One of {"zscore", "minmax"} (case-insensitive).
        - "zscore": subtract the mean and divide by ``std + eps``, where the
          standard deviation is the population one (``unbiased=False``).
        - "minmax": subtract the minimum and divide by ``max - min``, with
          the range clamped from below to ``eps``.
    eps : float
        Small constant to avoid division by zero.

    Raises
    ------
    ValueError
        If ``mode`` is not a supported mode.
    """

    _MODES = ("zscore", "minmax")

    def __init__(self, mode: str = "zscore", eps: float = 1e-6):
        mode_l = str(mode).lower()
        if mode_l not in self._MODES:
            raise ValueError(f"Normalize mode must be 'zscore' or 'minmax', got {mode}")
        self.mode = mode_l
        self.eps = float(eps)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(mode={self.mode!r}, eps={self.eps})"

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return a normalized copy of ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape (T, D) or (B, T, D). Integer and boolean tensors
            are cast to the default floating dtype before normalization.

        Returns
        -------
        torch.Tensor
            Tensor with the same shape as ``x``. An input whose time axis is
            empty is returned as an (empty) copy.

        Raises
        ------
        TypeError
            If ``x`` is not a ``torch.Tensor``.
        ValueError
            If ``x`` is not 2- or 3-dimensional.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Normalize expects a torch.Tensor")
        if x.ndim not in (2, 3):
            raise ValueError(f"Expected x.ndim in {{2,3}}, got {x.ndim}")

        # (T, D) -> time axis 0; (B, T, D) -> time axis 1.
        time_dim = x.ndim - 2

        # mean()/std() reject integer and bool dtypes, and spike counts are
        # frequently stored as integers, so promote them up front.
        if not (x.is_floating_point() or x.is_complex()):
            x = x.to(torch.get_default_dtype())

        # Nothing to reduce over: amin()/amax() would raise on a zero-size
        # dimension, so return an empty result for both modes.
        if x.shape[time_dim] == 0:
            return x.clone()

        if self.mode == "zscore":
            mean_t = x.mean(dim=time_dim, keepdim=True)
            # population std (unbiased=False) to avoid NaNs for small T
            std_t = x.std(dim=time_dim, keepdim=True, unbiased=False)
            return (x - mean_t) / (std_t + self.eps)

        # "minmax"
        min_t = x.amin(dim=time_dim, keepdim=True)
        max_t = x.amax(dim=time_dim, keepdim=True)
        # A constant channel has zero range; clamping keeps the division finite.
        denom = (max_t - min_t).clamp_min(self.eps)
        return (x - min_t) / denom