import matplotlib.pyplot as plt
import torch


def plot_raster(
    spikes: torch.Tensor,
    title=None,
    *,
    batch_index: int = 0,
    show: bool = True,
):
    """Plot a raster diagram of spike activity.

    Every non-zero entry of ``spikes`` is drawn as one small black dot at
    (time step, neuron index).

    Args:
        spikes: Spike tensor of shape ``(T, N)`` or ``(T, B, N)``. For the
            3-D form only the batch element selected by ``batch_index`` is
            plotted.
        title: Optional plot title; skipped when falsy.
        batch_index: Batch element to plot for 3-D input (default 0, the
            first batch).
        show: Whether to call ``plt.show()`` at the end (default True). Pass
            False to keep drawing on, or save, the current figure instead.

    Raises:
        ValueError: If ``spikes`` is not 2-D or 3-D.
    """
    s = spikes.detach().cpu()
    if s.ndim == 3:
        s = s[:, batch_index, :]
    elif s.ndim != 2:
        raise ValueError(
            f"expected spikes of shape (T, N) or (T, B, N), got {tuple(s.shape)}"
        )

    # One vectorised lookup of every (time, neuron) spike coordinate, drawn with
    # a single scatter call rather than one call per neuron. as_tuple=True
    # yields 1-D index tensors, so single-spike neurons need no squeeze().
    times, neurons = torch.nonzero(s, as_tuple=True)
    plt.scatter(times.numpy(), neurons.numpy(), s=2, c="black")

    plt.xlabel("Time step")
    plt.ylabel("Neuron index")
    if title:
        plt.title(title)
    if show:
        plt.show()