summaryrefslogtreecommitdiff
path: root/files/data_io/utils
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-13 23:50:59 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-13 23:50:59 -0600
commit00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch)
tree77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files/data_io/utils
parentc53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff)
parentcd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff)
Merge master into main
Diffstat (limited to 'files/data_io/utils')
-rw-r--r--files/data_io/utils/__init__.py1
-rw-r--r--files/data_io/utils/file_utils.py15
-rw-r--r--files/data_io/utils/spike_tools.py10
-rw-r--r--files/data_io/utils/visualize.py19
4 files changed, 45 insertions, 0 deletions
diff --git a/files/data_io/utils/__init__.py b/files/data_io/utils/__init__.py
new file mode 100644
index 0000000..ee3ab2f
--- /dev/null
+++ b/files/data_io/utils/__init__.py
@@ -0,0 +1 @@
+"""Utility functions for file management, spike tools, and visualization."""
diff --git a/files/data_io/utils/file_utils.py b/files/data_io/utils/file_utils.py
new file mode 100644
index 0000000..0a1e846
--- /dev/null
+++ b/files/data_io/utils/file_utils.py
@@ -0,0 +1,15 @@
+import os
+
+def ensure_dir(path: str):
+ """Ensure that a directory exists."""
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+def list_files(root: str, suffix: str):
+ """Recursively list files ending with suffix."""
+ matches = []
+ for dirpath, _, filenames in os.walk(root):
+ for f in filenames:
+ if f.endswith(suffix):
+ matches.append(os.path.join(dirpath, f))
+ return matches
diff --git a/files/data_io/utils/spike_tools.py b/files/data_io/utils/spike_tools.py
new file mode 100644
index 0000000..968ee72
--- /dev/null
+++ b/files/data_io/utils/spike_tools.py
@@ -0,0 +1,10 @@
+import torch
+import numpy as np
+
+def to_raster(spikes: torch.Tensor) -> np.ndarray:
+ """Convert spike tensor (T,B,N) to raster array (T,N)."""
+ return spikes.detach().cpu().numpy().mean(axis=1)
+
+def firing_rate(spikes: torch.Tensor, dt=1.0):
+ """Compute firing rate per neuron."""
+ return spikes.sum(dim=0) / (spikes.shape[0] * dt)
diff --git a/files/data_io/utils/visualize.py b/files/data_io/utils/visualize.py
new file mode 100644
index 0000000..6f0de95
--- /dev/null
+++ b/files/data_io/utils/visualize.py
@@ -0,0 +1,19 @@
+import matplotlib.pyplot as plt
+import torch
+
+def plot_raster(spikes: torch.Tensor, title=None):
+ """
+ Plot raster diagram of spike activity (T,B,N) or (T,N).
+ """
+ s = spikes.detach().cpu()
+ if s.ndim == 3:
+ s = s[:, 0, :] # take first batch
+ t, n = s.shape
+ for i in range(n):
+ times = torch.nonzero(s[:, i]).squeeze().numpy()
+ plt.scatter(times, i * np.ones_like(times), s=2, c='black')
+ plt.xlabel("Time step")
+ plt.ylabel("Neuron index")
+ if title:
+ plt.title(title)
+ plt.show()