summaryrefslogtreecommitdiff
path: root/files/data_io/utils/visualize.py
blob: 6f0de9506338857450ad475bd0cd8965214a7af4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import matplotlib.pyplot as plt
import torch

def plot_raster(spikes: torch.Tensor, title=None):
    """
    Plot raster diagram of spike activity (T,B,N) or (T,N).

    Every non-zero entry of ``spikes`` is drawn as a small black dot at
    (time step, neuron index).

    Args:
        spikes: Spike tensor shaped (T, B, N) or (T, N). Any non-zero value
            counts as a spike. For a 3-D input only batch element 0 is
            plotted. The tensor is detached and copied to CPU, so it may
            require grad or live on another device.
        title: Optional plot title; skipped when falsy.

    Raises:
        ValueError: If ``spikes`` is not 2-D or 3-D.
    """
    s = spikes.detach().cpu()
    if s.ndim == 3:
        s = s[:, 0, :]  # take first batch
    if s.ndim != 2:
        raise ValueError(
            f"expected spikes shaped (T,B,N) or (T,N), got {tuple(spikes.shape)}"
        )
    # nonzero on the 2-D (T, N) tensor yields a (K, 2) tensor of
    # (time step, neuron index) pairs, one row per spike. A single scatter
    # call replaces the per-neuron loop. It also needs no numpy and avoids
    # the .squeeze() 0-d edge case when a neuron fires exactly once.
    coords = torch.nonzero(s)
    times = coords[:, 0].numpy()
    neurons = coords[:, 1].numpy()
    plt.scatter(times, neurons, s=2, c='black')
    plt.xlabel("Time step")
    plt.ylabel("Neuron index")
    if title:
        plt.title(title)
    plt.show()