Diffstat (limited to 'files/data_io/transforms')
 -rw-r--r--   files/data_io/transforms/__init__.py            |  1
 -rw-r--r--   files/data_io/transforms/normalization.py       | 52
 -rw-r--r--   files/data_io/transforms/spike_augmentation.py  | 10
3 files changed, 63 insertions, 0 deletions
diff --git a/files/data_io/transforms/__init__.py b/files/data_io/transforms/__init__.py
new file mode 100644
index 0000000..f6aa496
--- /dev/null
+++ b/files/data_io/transforms/__init__.py
@@ -0,0 +1 @@
+"""Transforms for spike data (normalization, augmentation, etc.)."""
diff --git a/files/data_io/transforms/normalization.py b/files/data_io/transforms/normalization.py
new file mode 100644
index 0000000..c86b8c7
--- /dev/null
+++ b/files/data_io/transforms/normalization.py
@@ -0,0 +1,52 @@
+import torch
+
+class Normalize:
+    """Normalize spike input (z-score or min-max)."""
+    def __init__(self, mode: str = "zscore", eps: float = 1e-6):
+        """
+        Parameters
+        ----------
+        mode : str
+            One of {"zscore", "minmax"}.
+            - "zscore": per-sample, per-channel standardization across time.
+            - "minmax": per-sample, per-channel min-max scaling across time.
+        eps : float
+            Small constant to avoid division by zero.
+        """
+        mode_l = str(mode).lower()
+        if mode_l not in {"zscore", "minmax"}:
+            raise ValueError(f"Normalize mode must be 'zscore' or 'minmax', got {mode}")
+        self.mode = mode_l
+        self.eps = float(eps)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Apply normalization.
+        Accepts tensors of shape (T, D) or (B, T, D).
+        Normalization is computed per-sample and per-channel over the time axis T.
+        """
+        if not isinstance(x, torch.Tensor):
+            raise TypeError("Normalize expects a torch.Tensor")
+        if x.ndim == 2:
+            # (T, D) -> time dim = 0
+            time_dim = 0
+            keep_dims = True
+        elif x.ndim == 3:
+            # (B, T, D) -> time dim = 1
+            time_dim = 1
+            keep_dims = True
+        else:
+            raise ValueError(f"Expected x.ndim in {{2,3}}, got {x.ndim}")
+
+        if self.mode == "zscore":
+            mean_t = x.mean(dim=time_dim, keepdim=keep_dims)
+            # population std (unbiased=False) to avoid NaNs for small T
+            std_t = x.std(dim=time_dim, keepdim=keep_dims, unbiased=False)
+            x_norm = (x - mean_t) / (std_t + self.eps)
+            return x_norm
+        else:  # "minmax"
+            min_t = x.amin(dim=time_dim, keepdim=keep_dims)
+            max_t = x.amax(dim=time_dim, keepdim=keep_dims)
+            denom = (max_t - min_t).clamp_min(self.eps)
+            x_scaled = (x - min_t) / denom
+            return x_scaled
diff --git a/files/data_io/transforms/spike_augmentation.py b/files/data_io/transforms/spike_augmentation.py
new file mode 100644
index 0000000..9b7b687
--- /dev/null
+++ b/files/data_io/transforms/spike_augmentation.py
@@ -0,0 +1,10 @@
+import torch
+
+class SpikeJitter:
+    """Add temporal jitter noise to spikes."""
+    def __init__(self, std=0.01):
+        self.std = std
+
+    def __call__(self, spikes: torch.Tensor) -> torch.Tensor:
+        # TODO: add random jitter to spike timings
+        return spikes
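Usage note (not part of the commit): a minimal sketch of how the new Normalize transform could be exercised, assuming the package is importable as files.data_io.transforms and that inputs are dense float tensors of shape (T, D); the tensor below is synthetic and only for illustration.

    import torch
    from files.data_io.transforms.normalization import Normalize  # import path assumed from the diffstat

    x = torch.randn(100, 32)            # synthetic activity: T=100 time steps, D=32 channels

    zscore = Normalize(mode="zscore")
    minmax = Normalize(mode="minmax")

    x_z = zscore(x)   # per-channel mean ~0 and std ~1 over the time axis
    x_m = minmax(x)   # per-channel values rescaled to roughly [0, 1] over the time axis

    print(x_z.mean(dim=0).abs().max().item())   # close to 0
    print(x_m.min().item(), x_m.max().item())   # ~0.0 and ~1.0

The same call works for batched (B, T, D) input, where the statistics are taken over dim=1.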
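SpikeJitter.__call__ is still a TODO in this commit and currently returns its input unchanged. Purely as an illustration of one way the jitter could later be filled in (this is not the author's implementation, and the class name below is hypothetical), a dense float spike tensor can be perturbed with zero-mean Gaussian noise scaled by std:

    import torch

    class GaussianSpikeJitter:
        """Illustrative stand-in: adds zero-mean Gaussian noise to a dense float spike tensor."""
        def __init__(self, std: float = 0.01):
            self.std = float(std)

        def __call__(self, spikes: torch.Tensor) -> torch.Tensor:
            # Assumes a floating-point tensor; the noise matches its shape, dtype and device.
            noise = torch.randn_like(spikes) * self.std
            return spikes + noise

For binary spike trains, temporal jitter is more often implemented by shifting individual spike times by a few random bins rather than perturbing amplitudes; which variant applies here depends on the intended data representation.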
