summaryrefslogtreecommitdiff
path: root/files/data_io/utils/spike_tools.py
blob: 968ee72c423c6ccecce5c7d6b1b5546b2d64faf4 (plain)
1
2
3
4
5
6
7
8
9
10
import torch
import numpy as np

def to_raster(spikes: torch.Tensor) -> np.ndarray:
    """Collapse a spike tensor into a NumPy raster by averaging over axis 1.

    The tensor is detached from autograd and copied to host memory before
    being converted to NumPy. For the documented ``(T, B, N)`` layout the
    result has shape ``(T, N)``.

    Args:
        spikes: Spike tensor, documented as ``(T, B, N)``.

    Returns:
        NumPy array holding the mean of ``spikes`` along axis 1.
    """
    # detach() drops the autograd graph; cpu() is required before .numpy().
    host_array = spikes.detach().cpu().numpy()
    return np.mean(host_array, axis=1)

def firing_rate(spikes: torch.Tensor, dt=1.0):
    """Return spike totals along dim 0 divided by the covered duration.

    The duration is the size of dim 0 multiplied by ``dt``.

    Args:
        spikes: Spike tensor; dim 0 is the one summed over.
        dt: Multiplier applied to the size of dim 0 to form the divisor.

    Returns:
        Tensor with dim 0 reduced, holding the sum divided by the duration.
    """
    num_steps = spikes.shape[0]
    duration = num_steps * dt
    spike_counts = spikes.sum(dim=0)
    return spike_counts / duration