Diffstat (limited to 'files/data_io/benchmark_datasets.py')
-rw-r--r--  files/data_io/benchmark_datasets.py  360
1 files changed, 360 insertions, 0 deletions
diff --git a/files/data_io/benchmark_datasets.py b/files/data_io/benchmark_datasets.py
new file mode 100644
index 0000000..302ed51
--- /dev/null
+++ b/files/data_io/benchmark_datasets.py
@@ -0,0 +1,360 @@
+"""
+Challenging benchmark datasets for deep SNN evaluation.
+
+Datasets:
+1. Sequential MNIST (sMNIST) - pixel-by-pixel, 784 timesteps
+2. Permuted Sequential MNIST (psMNIST) - shuffled pixel order
+3. CIFAR-10 with rate coding
+4. DVS-CIFAR10 (requires tonic library)
+
+These benchmarks are harder than SHD and benefit from deeper networks.
+"""
+
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader
+
+
+class SequentialMNIST(Dataset):
+ """
+ Sequential MNIST - feed pixels one at a time.
+
+ Each 28x28 image becomes a sequence of 784 timesteps,
+ each with a single pixel intensity converted to spike probability.
+
+ This is MUCH harder than standard MNIST because:
+ - Network must remember information across 784 timesteps
+ - Tests long-range temporal dependencies
+ - Shallow networks fail due to vanishing gradients
+
+    Args:
+        root: Data directory
+        train: Train or test split
+        permute: If True, apply a fixed random permutation to the pixel order (psMNIST)
+        spike_encoding: 'rate' (Poisson), 'latency', or 'direct' (raw intensity)
+        max_rate: Maximum firing rate in Hz for rate coding
+        n_repeat: Number of timesteps per pixel (sequence length = 784 * n_repeat)
+        seed: Random seed for the fixed permutation
+        download: Download MNIST if not already present
+    """
+
+ def __init__(
+ self,
+ root: str = "./data",
+ train: bool = True,
+ permute: bool = False,
+ spike_encoding: str = "rate",
+ max_rate: float = 100.0,
+ n_repeat: int = 1, # Repeat each pixel n times for more spikes
+ seed: int = 42,
+ download: bool = True,
+ ):
+ try:
+ from torchvision import datasets, transforms
+ except ImportError:
+ raise ImportError("torchvision required: pip install torchvision")
+
+ self.train = train
+ self.permute = permute
+ self.spike_encoding = spike_encoding
+ self.max_rate = max_rate
+ self.n_repeat = n_repeat
+
+ # Load MNIST
+ self.mnist = datasets.MNIST(
+ root=root,
+ train=train,
+ download=download,
+ transform=transforms.ToTensor(),
+ )
+
+ # Create fixed permutation for psMNIST
+ if permute:
+ rng = np.random.RandomState(seed)
+ self.perm = torch.from_numpy(rng.permutation(784))
+ else:
+ self.perm = None
+
+ def __len__(self):
+ return len(self.mnist)
+
+ def __getitem__(self, idx):
+ img, label = self.mnist[idx]
+
+ # Flatten to (784,)
+ pixels = img.view(-1)
+
+ # Apply permutation
+ if self.perm is not None:
+ pixels = pixels[self.perm]
+
+ # Convert to spike sequence (T, 1) where T = 784 * n_repeat
+ T = 784 * self.n_repeat
+
+ if self.spike_encoding == "direct":
+ # Direct intensity: repeat each pixel n_repeat times
+ spikes = pixels.unsqueeze(1).repeat(1, self.n_repeat).view(T, 1)
+
+ elif self.spike_encoding == "rate":
+ # Rate coding: Poisson spikes based on intensity
+ probs = pixels * (self.max_rate / 1000.0) # Assuming 1ms bins
+ probs = probs.clamp(0, 1)
+ # Repeat and sample
+ probs_expanded = probs.unsqueeze(1).repeat(1, self.n_repeat).view(T, 1)
+ spikes = torch.bernoulli(probs_expanded)
+
+ elif self.spike_encoding == "latency":
+            # Latency coding: spike time decreases with intensity
+            # (high intensity = early spike, low intensity = late spike)
+ spikes = torch.zeros(T, 1)
+ for i, p in enumerate(pixels):
+ if p > 0.1: # Threshold for spiking
+ # Spike time: higher intensity = earlier
+ spike_time = int((1 - p) * (self.n_repeat - 1))
+ t = i * self.n_repeat + spike_time
+ spikes[t, 0] = 1.0
+
+ return spikes, label
+
+
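+# Worked example of the rate code above (illustrative numbers only): with
+# spike_encoding="rate", max_rate=100.0 and 1 ms bins, a pixel of intensity 0.5
+# fires with probability 0.5 * 100 / 1000 = 0.05 at its timestep (n_repeat=1).
+# A minimal usage sketch, assuming MNIST can be downloaded to ./data:
+#
+#     ds = SequentialMNIST(root="./data", permute=False, spike_encoding="rate")
+#     spikes, label = ds[0]
+#     print(spikes.shape)  # torch.Size([784, 1]) with the default n_repeat=1
+
+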
+class RateCodingCIFAR10(Dataset):
+ """
+ CIFAR-10 with rate coding for SNNs.
+
+ Converts 32x32x3 images to spike trains:
+ - Each pixel channel becomes a Poisson spike train
+ - Total input dimension: 32*32*3 = 3072
+ - Sequence length: T timesteps
+
+ Args:
+ root: Data directory
+ train: Train or test split
+ T: Number of timesteps
+ max_rate: Maximum firing rate (Hz)
+        flatten: If True, flatten spatial dimensions to shape (T, 3072)
+        download: Download CIFAR-10 if not already present
+    """
+
+ def __init__(
+ self,
+ root: str = "./data",
+ train: bool = True,
+ T: int = 100,
+ max_rate: float = 200.0,
+ flatten: bool = True,
+ download: bool = True,
+ ):
+ try:
+ from torchvision import datasets, transforms
+ except ImportError:
+ raise ImportError("torchvision required: pip install torchvision")
+
+ self.T = T
+ self.max_rate = max_rate
+ self.flatten = flatten
+
+ # Normalize to [0, 1]
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+
+ self.cifar = datasets.CIFAR10(
+ root=root,
+ train=train,
+ download=download,
+ transform=transform,
+ )
+
+ def __len__(self):
+ return len(self.cifar)
+
+ def __getitem__(self, idx):
+ img, label = self.cifar[idx] # (3, 32, 32)
+
+ if self.flatten:
+ img = img.view(-1) # (3072,)
+
+ # Rate coding
+ prob_per_step = img * (self.max_rate / 1000.0) # Assuming 1ms steps
+ prob_per_step = prob_per_step.clamp(0, 1)
+
+ # Generate spikes for T timesteps
+ if self.flatten:
+ probs = prob_per_step.unsqueeze(0).expand(self.T, -1) # (T, 3072)
+ else:
+ probs = prob_per_step.unsqueeze(0).expand(self.T, -1, -1, -1) # (T, 3, 32, 32)
+
+ spikes = torch.bernoulli(probs)
+
+ return spikes, label
+
+
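+# Rough sanity check for the rate code above (illustrative, not enforced): with
+# max_rate=200.0 Hz and 1 ms steps, a fully bright channel (intensity 1.0)
+# spikes with probability 0.2 per step, so over T=100 steps it emits about 20
+# spikes on average, while a black pixel emits none. Minimal usage sketch,
+# assuming CIFAR-10 can be downloaded to ./data:
+#
+#     ds = RateCodingCIFAR10(root="./data", T=100, flatten=True)
+#     spikes, label = ds[0]
+#     print(spikes.shape)  # torch.Size([100, 3072])
+
+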
+class DVSCIFAR10(Dataset):
+ """
+ DVS-CIFAR10 dataset wrapper.
+
+ Requires the 'tonic' library for neuromorphic datasets:
+ pip install tonic
+
+ DVS-CIFAR10 is recorded from a Dynamic Vision Sensor watching
+ CIFAR-10 images on a monitor. It's a standard neuromorphic benchmark.
+
+    Args:
+        root: Data directory
+        train: Train or test split (DVS-CIFAR10 has no official split; see the
+            note in __init__)
+        T: Number of time bins
+        dt_ms: Width of each time bin in milliseconds
+        download: Kept for API symmetry with the other datasets (not forwarded to tonic)
+    """
+
+ def __init__(
+ self,
+ root: str = "./data",
+ train: bool = True,
+ T: int = 100,
+ dt_ms: float = 10.0,
+ download: bool = True,
+ ):
+ try:
+ import tonic
+ from tonic import transforms as tonic_transforms
+ except ImportError:
+ raise ImportError(
+ "tonic library required for DVS datasets: pip install tonic"
+ )
+
+ self.T = T
+
+ # Time binning transform
+ sensor_size = tonic.datasets.CIFAR10DVS.sensor_size
+ frame_transform = tonic_transforms.ToFrame(
+ sensor_size=sensor_size,
+ time_window=dt_ms * 1000, # Convert to microseconds
+ )
+
+        # Note: DVS-CIFAR10 ships without an official train/test split, and
+        # depending on the tonic version CIFAR10DVS may not accept a `train`
+        # argument, so it is not forwarded here. Callers who need a split can
+        # partition indices themselves (e.g. a fixed 9:1 split per class).
+        self.train = train
+        self.dataset = tonic.datasets.CIFAR10DVS(
+            save_to=root,
+            transform=frame_transform,
+        )
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, idx):
+ frames, label = self.dataset[idx] # (T', 2, H, W) - 2 polarities
+
+ # Convert to tensor and flatten spatial dims
+ frames = torch.from_numpy(frames).float()
+
+ # Adjust to target T
+ T_actual = frames.shape[0]
+ if T_actual > self.T:
+ # Subsample
+ indices = torch.linspace(0, T_actual - 1, self.T).long()
+ frames = frames[indices]
+ elif T_actual < self.T:
+ # Pad with zeros
+ pad = torch.zeros(self.T - T_actual, *frames.shape[1:])
+ frames = torch.cat([frames, pad], dim=0)
+
+ # Flatten: (T, 2, H, W) -> (T, 2*H*W)
+ frames = frames.view(self.T, -1)
+
+ return frames, label
+
+
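+# Minimal usage sketch for the DVS wrapper (requires `pip install tonic` and a
+# local copy or download of DVS-CIFAR10; the exact tonic API can differ between
+# versions, so treat this as illustrative):
+#
+#     ds = DVSCIFAR10(root="./data", T=100, dt_ms=10.0)
+#     frames, label = ds[0]
+#     print(frames.shape)  # torch.Size([100, 32768]) == (T, 2 * 128 * 128)
+
+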
+def get_benchmark_dataloader(
+ dataset_name: str,
+ batch_size: int = 64,
+ root: str = "./data",
+ **kwargs,
+) -> Tuple[DataLoader, DataLoader, dict]:
+ """
+ Get train and validation dataloaders for a benchmark dataset.
+
+ Args:
+ dataset_name: One of 'smnist', 'psmnist', 'cifar10', 'dvs_cifar10'
+ batch_size: Batch size
+ root: Data directory
+ **kwargs: Additional arguments passed to dataset
+
+ Returns:
+ train_loader, val_loader, info_dict
+ """
+
+ if dataset_name == "smnist":
+ train_ds = SequentialMNIST(root, train=True, permute=False, **kwargs)
+ val_ds = SequentialMNIST(root, train=False, permute=False, **kwargs)
+ info = {"T": 784 * kwargs.get("n_repeat", 1), "D": 1, "classes": 10,
+ "description": "Sequential MNIST - 784 timesteps, 1 pixel at a time"}
+
+ elif dataset_name == "psmnist":
+ train_ds = SequentialMNIST(root, train=True, permute=True, **kwargs)
+ val_ds = SequentialMNIST(root, train=False, permute=True, **kwargs)
+ info = {"T": 784 * kwargs.get("n_repeat", 1), "D": 1, "classes": 10,
+ "description": "Permuted Sequential MNIST - shuffled pixel order, tests long-range memory"}
+
+ elif dataset_name == "cifar10":
+ T = kwargs.pop("T", 100)
+ train_ds = RateCodingCIFAR10(root, train=True, T=T, **kwargs)
+ val_ds = RateCodingCIFAR10(root, train=False, T=T, **kwargs)
+ info = {"T": T, "D": 3072, "classes": 10,
+ "description": "CIFAR-10 with rate coding"}
+
+ elif dataset_name == "dvs_cifar10":
+ train_ds = DVSCIFAR10(root, train=True, **kwargs)
+ val_ds = DVSCIFAR10(root, train=False, **kwargs)
+ info = {"T": kwargs.get("T", 100), "D": 2 * 128 * 128, "classes": 10,
+ "description": "DVS-CIFAR10 neuromorphic dataset"}
+
+ else:
+ raise ValueError(f"Unknown dataset: {dataset_name}. "
+ f"Options: smnist, psmnist, cifar10, dvs_cifar10")
+
+ train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
+ val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4)
+
+ return train_loader, val_loader, info
+
+
+# Quick test
+if __name__ == "__main__":
+ print("Testing benchmark datasets...\n")
+
+ # Test sMNIST
+ print("1. Sequential MNIST")
+ try:
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "smnist", batch_size=32, n_repeat=1, spike_encoding="direct"
+ )
+ x, y = next(iter(train_loader))
+ print(f" Shape: {x.shape}, Labels: {y.shape}")
+ print(f" Info: {info}")
+ except Exception as e:
+ print(f" Error: {e}")
+
+ # Test psMNIST
+ print("\n2. Permuted Sequential MNIST")
+ try:
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "psmnist", batch_size=32, n_repeat=1, spike_encoding="direct"
+ )
+ x, y = next(iter(train_loader))
+ print(f" Shape: {x.shape}, Labels: {y.shape}")
+ print(f" Info: {info}")
+ except Exception as e:
+ print(f" Error: {e}")
+
+ # Test CIFAR-10
+ print("\n3. CIFAR-10 (rate coded)")
+ try:
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "cifar10", batch_size=32, T=50
+ )
+ x, y = next(iter(train_loader))
+ print(f" Shape: {x.shape}, Labels: {y.shape}")
+ print(f" Info: {info}")
+ except Exception as e:
+ print(f" Error: {e}")
+
+ print("\nDone!")