summaryrefslogtreecommitdiff
path: root/files/data_io/transforms
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-13 23:50:59 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-13 23:50:59 -0600
commit00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch)
tree77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files/data_io/transforms
parentc53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff)
parentcd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff)
Merge master into main
Diffstat (limited to 'files/data_io/transforms')
-rw-r--r--files/data_io/transforms/__init__.py1
-rw-r--r--files/data_io/transforms/normalization.py52
-rw-r--r--files/data_io/transforms/spike_augmentation.py10
3 files changed, 63 insertions, 0 deletions
diff --git a/files/data_io/transforms/__init__.py b/files/data_io/transforms/__init__.py
new file mode 100644
index 0000000..f6aa496
--- /dev/null
+++ b/files/data_io/transforms/__init__.py
@@ -0,0 +1 @@
+"""Transforms for spike data (normalization, augmentation, etc.)."""
diff --git a/files/data_io/transforms/normalization.py b/files/data_io/transforms/normalization.py
new file mode 100644
index 0000000..c86b8c7
--- /dev/null
+++ b/files/data_io/transforms/normalization.py
@@ -0,0 +1,52 @@
+import torch
+
+class Normalize:
+ """Normalize spike input (z-score or min-max)."""
+ def __init__(self, mode: str = "zscore", eps: float = 1e-6):
+ """
+ Parameters
+ ----------
+ mode : str
+ One of {"zscore", "minmax"}.
+ - "zscore": per-sample, per-channel standardization across time.
+ - "minmax": per-sample, per-channel min-max scaling across time.
+ eps : float
+ Small constant to avoid division by zero.
+ """
+ mode_l = str(mode).lower()
+ if mode_l not in {"zscore", "minmax"}:
+ raise ValueError(f"Normalize mode must be 'zscore' or 'minmax', got {mode}")
+ self.mode = mode_l
+ self.eps = float(eps)
+
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Apply normalization.
+ Accepts tensors of shape (T, D) or (B, T, D).
+ Normalization is computed per-sample and per-channel over the time axis T.
+ """
+ if not isinstance(x, torch.Tensor):
+ raise TypeError("Normalize expects a torch.Tensor")
+ if x.ndim == 2:
+ # (T, D) -> time dim = 0
+ time_dim = 0
+ keep_dims = True
+ elif x.ndim == 3:
+ # (B, T, D) -> time dim = 1
+ time_dim = 1
+ keep_dims = True
+ else:
+ raise ValueError(f"Expected x.ndim in {{2,3}}, got {x.ndim}")
+
+ if self.mode == "zscore":
+ mean_t = x.mean(dim=time_dim, keepdim=keep_dims)
+ # population std (unbiased=False) to avoid NaNs for small T
+ std_t = x.std(dim=time_dim, keepdim=keep_dims, unbiased=False)
+ x_norm = (x - mean_t) / (std_t + self.eps)
+ return x_norm
+ else: # "minmax"
+ min_t = x.amin(dim=time_dim, keepdim=keep_dims)
+ max_t = x.amax(dim=time_dim, keepdim=keep_dims)
+ denom = (max_t - min_t).clamp_min(self.eps)
+ x_scaled = (x - min_t) / denom
+ return x_scaled
diff --git a/files/data_io/transforms/spike_augmentation.py b/files/data_io/transforms/spike_augmentation.py
new file mode 100644
index 0000000..9b7b687
--- /dev/null
+++ b/files/data_io/transforms/spike_augmentation.py
@@ -0,0 +1,10 @@
+import torch
+
+class SpikeJitter:
+ """Add temporal jitter noise to spikes."""
+ def __init__(self, std=0.01):
+ self.std = std
+
+ def __call__(self, spikes: torch.Tensor) -> torch.Tensor:
+ # TODO: add random jitter to spike timings
+ return spikes