summaryrefslogtreecommitdiff
path: root/files/data_io/transforms/spike_augmentation.py
blob: 9b7b6872cf09bed9248982118a2b9ed4ad37b4b0 (plain)
1
2
3
4
5
6
7
8
9
10
import torch

class SpikeJitter:
    """Spike-tensor transform intended to apply temporal jitter.

    The jitter itself is not implemented yet: calling the transform hands
    back the input tensor untouched. ``std`` is stored on the instance for
    the planned implementation and is not read anywhere else in this class.
    """

    def __init__(self, std=0.01):
        # Kept for the planned jitter; its units/scale are not established
        # by this class -- TODO confirm against the caller's spike format.
        self.std = std

    def __call__(self, spikes: torch.Tensor) -> torch.Tensor:
        """Return ``spikes`` as-is: same object, no copy, no noise added."""
        # TODO: add random jitter to spike timings
        return spikes