summaryrefslogtreecommitdiff
path: root/files/data_io/utils/visualize.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-13 23:49:05 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-13 23:49:05 -0600
commitcd99d6b874d9d09b3bb87b8485cc787885af71f1 (patch)
tree59a233959932ca0e4f12f196275e07fcf443b33f /files/data_io/utils/visualize.py
init commit
Diffstat (limited to 'files/data_io/utils/visualize.py')
-rw-r--r--files/data_io/utils/visualize.py19
1 files changed, 19 insertions, 0 deletions
diff --git a/files/data_io/utils/visualize.py b/files/data_io/utils/visualize.py
new file mode 100644
index 0000000..6f0de95
--- /dev/null
+++ b/files/data_io/utils/visualize.py
@@ -0,0 +1,19 @@
+import matplotlib.pyplot as plt
+import torch
+
+def plot_raster(spikes: torch.Tensor, title=None):
+ """
+ Plot raster diagram of spike activity (T,B,N) or (T,N).
+ """
+ s = spikes.detach().cpu()
+ if s.ndim == 3:
+ s = s[:, 0, :] # take first batch
+ t, n = s.shape
+ for i in range(n):
+ times = torch.nonzero(s[:, i]).squeeze().numpy()
+ plt.scatter(times, i * np.ones_like(times), s=2, c='black')
+ plt.xlabel("Time step")
+ plt.ylabel("Neuron index")
+ if title:
+ plt.title(title)
+ plt.show()