blob: 6f0de9506338857450ad475bd0cd8965214a7af4 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
import matplotlib.pyplot as plt
import torch
def plot_raster(spikes: torch.Tensor, title=None):
    """
    Plot a raster diagram of spike activity.

    Args:
        spikes: Tensor of shape (T, B, N) or (T, N). For 3-D input only
            the first batch element (index 0 along dim 1) is drawn.
            Every non-zero entry is plotted as a spike.
        title: Optional figure title; skipped when falsy.

    Raises:
        ValueError: If ``spikes`` is neither 2-D nor 3-D.
    """
    s = spikes.detach().cpu()
    if s.ndim == 3:
        s = s[:, 0, :]  # take first batch
    if s.ndim != 2:
        raise ValueError(
            f"spikes must have shape (T, B, N) or (T, N), got {tuple(spikes.shape)}"
        )

    # One nonzero() over the whole (T, N) matrix yields matching time and
    # neuron index vectors. as_tuple=True always returns 1-D tensors, so
    # neurons with zero or exactly one spike need no special handling, and
    # a single scatter call replaces the per-neuron loop.
    times, neurons = torch.nonzero(s, as_tuple=True)
    plt.scatter(times.numpy(), neurons.numpy(), s=2, c='black')

    plt.xlabel("Time step")
    plt.ylabel("Neuron index")
    if title:
        plt.title(title)
    plt.show()
|