diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:50:59 -0600 |
| commit | 00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch) | |
| tree | 77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files/data_io/utils | |
| parent | c53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff) | |
| parent | cd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff) | |
Merge master into main
Diffstat (limited to 'files/data_io/utils')
| -rw-r--r-- | files/data_io/utils/__init__.py | 1 | ||||
| -rw-r--r-- | files/data_io/utils/file_utils.py | 15 | ||||
| -rw-r--r-- | files/data_io/utils/spike_tools.py | 10 | ||||
| -rw-r--r-- | files/data_io/utils/visualize.py | 19 |
4 files changed, 45 insertions, 0 deletions
diff --git a/files/data_io/utils/__init__.py b/files/data_io/utils/__init__.py new file mode 100644 index 0000000..ee3ab2f --- /dev/null +++ b/files/data_io/utils/__init__.py @@ -0,0 +1 @@ +"""Utility functions for file management, spike tools, and visualization.""" diff --git a/files/data_io/utils/file_utils.py b/files/data_io/utils/file_utils.py new file mode 100644 index 0000000..0a1e846 --- /dev/null +++ b/files/data_io/utils/file_utils.py @@ -0,0 +1,15 @@ +import os + +def ensure_dir(path: str): + """Ensure that a directory exists.""" + if not os.path.exists(path): + os.makedirs(path) + +def list_files(root: str, suffix: str): + """Recursively list files ending with suffix.""" + matches = [] + for dirpath, _, filenames in os.walk(root): + for f in filenames: + if f.endswith(suffix): + matches.append(os.path.join(dirpath, f)) + return matches diff --git a/files/data_io/utils/spike_tools.py b/files/data_io/utils/spike_tools.py new file mode 100644 index 0000000..968ee72 --- /dev/null +++ b/files/data_io/utils/spike_tools.py @@ -0,0 +1,10 @@ +import torch +import numpy as np + +def to_raster(spikes: torch.Tensor) -> np.ndarray: + """Convert spike tensor (T,B,N) to raster array (T,N).""" + return spikes.detach().cpu().numpy().mean(axis=1) + +def firing_rate(spikes: torch.Tensor, dt=1.0): + """Compute firing rate per neuron.""" + return spikes.sum(dim=0) / (spikes.shape[0] * dt) diff --git a/files/data_io/utils/visualize.py b/files/data_io/utils/visualize.py new file mode 100644 index 0000000..6f0de95 --- /dev/null +++ b/files/data_io/utils/visualize.py @@ -0,0 +1,19 @@ +import matplotlib.pyplot as plt +import torch + +def plot_raster(spikes: torch.Tensor, title=None): + """ + Plot raster diagram of spike activity (T,B,N) or (T,N). + """ + s = spikes.detach().cpu() + if s.ndim == 3: + s = s[:, 0, :] # take first batch + t, n = s.shape + for i in range(n): + times = torch.nonzero(s[:, i]).squeeze().numpy() + plt.scatter(times, i * np.ones_like(times), s=2, c='black') + plt.xlabel("Time step") + plt.ylabel("Neuron index") + if title: + plt.title(title) + plt.show() |
