"""
Challenging benchmark datasets for deep SNN evaluation.

Datasets:
    1. Sequential MNIST (sMNIST) - pixel-by-pixel, 784 timesteps
    2. Permuted Sequential MNIST (psMNIST) - shuffled pixel order
    3. CIFAR-10 with rate coding
    4. DVS-CIFAR10 (requires tonic library; a deterministic train/test split
       is derived here because the dataset ships without one)

These benchmarks are harder than SHD and benefit from deeper networks.
"""

import os
from typing import Optional, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Subset


class SequentialMNIST(Dataset):
    """
    Sequential MNIST - feed pixels one at a time.

    Each 28x28 image becomes a sequence of ``784 * n_repeat`` timesteps with a
    single input channel, so every sample has shape ``(784 * n_repeat, 1)``.

    This is MUCH harder than standard MNIST because:
        - Network must remember information across 784 timesteps
        - Tests long-range temporal dependencies
        - Shallow networks fail due to vanishing gradients

    Args:
        root: Data directory.
        train: Train or test split.
        permute: If True, apply a fixed random pixel permutation (psMNIST).
        spike_encoding: How each pixel intensity becomes input:
            'rate'    - Bernoulli spikes with per-bin probability
                        ``intensity * max_rate / 1000`` (1 ms bins assumed),
                        clamped to [0, 1];
            'latency' - at most one spike per pixel inside that pixel's
                        n_repeat-bin window, earlier for brighter pixels;
                        pixels with intensity <= 0.1 emit no spike;
            'direct'  - the raw intensity, repeated n_repeat times.
        max_rate: Maximum firing rate (Hz) used by rate coding.
        n_repeat: Number of consecutive timesteps each pixel occupies.
        seed: Random seed for the psMNIST permutation.
        download: Forwarded to ``torchvision.datasets.MNIST``.

    Raises:
        ValueError: If ``spike_encoding`` is unknown or ``n_repeat < 1``.
        ImportError: If torchvision is not installed.
    """

    _ENCODINGS = ("rate", "latency", "direct")

    def __init__(
        self,
        root: str = "./data",
        train: bool = True,
        permute: bool = False,
        spike_encoding: str = "rate",
        max_rate: float = 100.0,
        n_repeat: int = 1,  # Repeat each pixel n times for more spikes
        seed: int = 42,
        download: bool = True,
    ):
        # Validate cheap arguments before importing torchvision / downloading
        # data, so a typo fails immediately instead of at first sample access.
        if spike_encoding not in self._ENCODINGS:
            raise ValueError(
                f"Unknown spike_encoding {spike_encoding!r}; "
                f"expected one of {self._ENCODINGS}"
            )
        if n_repeat < 1:
            raise ValueError(f"n_repeat must be >= 1, got {n_repeat}")

        try:
            from torchvision import datasets, transforms
        except ImportError as err:
            raise ImportError("torchvision required: pip install torchvision") from err

        self.train = train
        self.permute = permute
        self.spike_encoding = spike_encoding
        self.max_rate = max_rate
        self.n_repeat = n_repeat

        # ToTensor yields float images with intensities in [0, 1].
        self.mnist = datasets.MNIST(
            root=root,
            train=train,
            download=download,
            transform=transforms.ToTensor(),
        )

        # The permutation depends only on `seed`, so train and test splits
        # built with the same seed share the same pixel order.
        if permute:
            rng = np.random.RandomState(seed)
            self.perm = torch.from_numpy(rng.permutation(784))
        else:
            self.perm = None

    def __len__(self):
        return len(self.mnist)

    def __getitem__(self, idx):
        """Return ``(spikes, label)`` with spikes of shape ``(784 * n_repeat, 1)``."""
        img, label = self.mnist[idx]

        # Flatten to (784,) pixels
        pixels = img.reshape(-1)
        if self.perm is not None:
            pixels = pixels[self.perm]

        n_repeat = self.n_repeat
        n_pixels = pixels.numel()
        T = n_pixels * n_repeat

        if self.spike_encoding == "direct":
            # Each pixel's intensity held for n_repeat consecutive steps.
            spikes = pixels.repeat_interleave(n_repeat).unsqueeze(1)
        elif self.spike_encoding == "rate":
            # Per-bin spike probability, assuming 1 ms bins.
            probs = (pixels * (self.max_rate / 1000.0)).clamp(0, 1)
            spikes = torch.bernoulli(probs.repeat_interleave(n_repeat).unsqueeze(1))
        elif self.spike_encoding == "latency":
            # One spike per sufficiently bright pixel; the offset inside the
            # pixel's window shrinks as intensity grows (bright = early).
            # `.long()` truncates toward zero, matching int() on these values;
            # the clamp only matters for intensities outside [0, 1].
            spikes = torch.zeros(T, 1)
            active = pixels > 0.1  # Threshold for spiking
            offsets = ((1 - pixels) * (n_repeat - 1)).long().clamp(0, n_repeat - 1)
            times = torch.arange(n_pixels) * n_repeat + offsets
            spikes[times[active], 0] = 1.0
        else:
            # Only reachable if spike_encoding was modified after __init__.
            raise ValueError(f"Unknown spike_encoding {self.spike_encoding!r}")

        return spikes, label


class RateCodingCIFAR10(Dataset):
    """
    CIFAR-10 with rate coding for SNNs.

    Converts 32x32x3 images to spike trains:
        - Each pixel channel becomes an independent Bernoulli spike train
          with per-step probability ``intensity * max_rate / 1000``
          (1 ms steps assumed), clamped to [0, 1]
        - Total input dimension: 32*32*3 = 3072
        - Sequence length: T timesteps

    Args:
        root: Data directory.
        train: Train or test split.
        T: Number of timesteps.
        max_rate: Maximum firing rate (Hz).
        flatten: If True, samples have shape ``(T, 3072)``; otherwise
            ``(T, 3, 32, 32)``.
        download: Forwarded to ``torchvision.datasets.CIFAR10``.

    Raises:
        ImportError: If torchvision is not installed.
    """

    def __init__(
        self,
        root: str = "./data",
        train: bool = True,
        T: int = 100,
        max_rate: float = 200.0,
        flatten: bool = True,
        download: bool = True,
    ):
        try:
            from torchvision import datasets, transforms
        except ImportError as err:
            raise ImportError("torchvision required: pip install torchvision") from err

        self.T = T
        self.max_rate = max_rate
        self.flatten = flatten

        # ToTensor normalizes to [0, 1].
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        self.cifar = datasets.CIFAR10(
            root=root,
            train=train,
            download=download,
            transform=transform,
        )

    def __len__(self):
        return len(self.cifar)

    def __getitem__(self, idx):
        """Return ``(spikes, label)``; spikes are freshly sampled on each call."""
        img, label = self.cifar[idx]  # (3, 32, 32)

        if self.flatten:
            img = img.reshape(-1)  # (3072,)

        prob_per_step = (img * (self.max_rate / 1000.0)).clamp(0, 1)

        # Broadcast the same probabilities over T steps without copying;
        # bernoulli() then draws independent samples per element.
        probs = prob_per_step.unsqueeze(0).expand(self.T, *prob_per_step.shape)
        spikes = torch.bernoulli(probs)

        return spikes, label


class DVSCIFAR10(Dataset):
    """
    DVS-CIFAR10 dataset wrapper.

    Requires the 'tonic' library for neuromorphic datasets:
        pip install tonic

    DVS-CIFAR10 is recorded from a Dynamic Vision Sensor watching CIFAR-10
    images on a monitor. It's a standard neuromorphic benchmark.

    tonic's ``CIFAR10DVS`` provides no train/test split, so this wrapper
    builds a deterministic one: all sample indices are shuffled with
    ``split_seed``; the first ``round(N * train_ratio)`` form the train split
    and the remainder the test split. Both splits are disjoint as long as
    they use the same ``train_ratio`` and ``split_seed``.

    Events are binned into frames of ``dt_ms`` milliseconds, then subsampled
    (if longer) or zero-padded at the end (if shorter) to exactly ``T``
    frames, and flattened to shape ``(T, 2 * H * W)``.

    Args:
        root: Data directory.
        train: Return the train split (True) or the test split (False).
        T: Number of time bins per sample.
        dt_ms: Width of each event-binning window in milliseconds.
        download: Accepted for API symmetry with the other datasets; it is
            not forwarded to tonic.
        train_ratio: Fraction of samples assigned to the train split,
            strictly between 0 and 1.
        split_seed: Seed for the shuffle that defines the split.

    Raises:
        ImportError: If tonic is not installed.
        ValueError: If ``train_ratio`` is not strictly between 0 and 1.
    """

    def __init__(
        self,
        root: str = "./data",
        train: bool = True,
        T: int = 100,
        dt_ms: float = 10.0,
        download: bool = True,
        train_ratio: float = 0.9,
        split_seed: int = 42,
    ):
        try:
            import tonic
            from tonic import transforms as tonic_transforms
        except ImportError as err:
            raise ImportError(
                "tonic library required for DVS datasets: pip install tonic"
            ) from err

        self.T = T

        # Time binning transform
        sensor_size = tonic.datasets.CIFAR10DVS.sensor_size
        frame_transform = tonic_transforms.ToFrame(
            sensor_size=sensor_size,
            time_window=dt_ms * 1000,  # Convert to microseconds
        )

        # No `train` argument here: the underlying dataset is unsplit.
        full_dataset = tonic.datasets.CIFAR10DVS(
            save_to=root,
            transform=frame_transform,
        )
        indices = self._split_indices(len(full_dataset), train, train_ratio, split_seed)
        self.dataset = Subset(full_dataset, indices)

    @staticmethod
    def _split_indices(n: int, train: bool, train_ratio: float, seed: int) -> list:
        """Return the indices of ``range(n)`` belonging to the requested split."""
        if not 0.0 < train_ratio < 1.0:
            raise ValueError(f"train_ratio must be in (0, 1), got {train_ratio}")
        order = np.random.RandomState(seed).permutation(n).tolist()
        n_train = int(round(n * train_ratio))
        return order[:n_train] if train else order[n_train:]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(frames, label)`` with float32 frames of shape ``(T, -1)``."""
        frames, label = self.dataset[idx]  # (T', 2, H, W) - 2 polarities

        # Convert via float32 numpy so any integer frame dtype is accepted.
        frames = torch.as_tensor(np.asarray(frames, dtype=np.float32))

        # Adjust to target T
        T_actual = frames.shape[0]
        if T_actual > self.T:
            # Subsample evenly spaced frames, keeping the first and last.
            indices = torch.linspace(0, T_actual - 1, self.T).long()
            frames = frames[indices]
        elif T_actual < self.T:
            # Pad with zeros at the end
            pad = torch.zeros(self.T - T_actual, *frames.shape[1:])
            frames = torch.cat([frames, pad], dim=0)

        # Flatten: (T, 2, H, W) -> (T, 2*H*W)
        frames = frames.reshape(self.T, -1)

        return frames, label


def get_benchmark_dataloader(
    dataset_name: str,
    batch_size: int = 64,
    root: str = "./data",
    num_workers: int = 4,
    **kwargs,
) -> Tuple[DataLoader, DataLoader, dict]:
    """
    Get train and validation dataloaders for a benchmark dataset.

    Args:
        dataset_name: One of 'smnist', 'psmnist', 'cifar10', 'dvs_cifar10'.
        batch_size: Batch size.
        root: Data directory.
        num_workers: Worker processes per DataLoader (0 loads in-process).
        **kwargs: Additional arguments passed to both dataset constructors.

    Returns:
        train_loader, val_loader, info_dict. The info dict holds the sequence
        length "T", input dimension "D", number of "classes" and a
        human-readable "description".

    Raises:
        ValueError: If ``dataset_name`` is not recognized.
    """
    if dataset_name == "smnist":
        train_ds = SequentialMNIST(root, train=True, permute=False, **kwargs)
        val_ds = SequentialMNIST(root, train=False, permute=False, **kwargs)
        info = {"T": 784 * kwargs.get("n_repeat", 1), "D": 1, "classes": 10,
                "description": "Sequential MNIST - 784 timesteps, 1 pixel at a time"}

    elif dataset_name == "psmnist":
        # Both splits receive the same `seed` (default or from kwargs), so
        # they share one pixel permutation.
        train_ds = SequentialMNIST(root, train=True, permute=True, **kwargs)
        val_ds = SequentialMNIST(root, train=False, permute=True, **kwargs)
        info = {"T": 784 * kwargs.get("n_repeat", 1), "D": 1, "classes": 10,
                "description": "Permuted Sequential MNIST - shuffled pixel order, tests long-range memory"}

    elif dataset_name == "cifar10":
        T = kwargs.pop("T", 100)
        train_ds = RateCodingCIFAR10(root, train=True, T=T, **kwargs)
        val_ds = RateCodingCIFAR10(root, train=False, T=T, **kwargs)
        info = {"T": T, "D": 3072, "classes": 10,
                "description": "CIFAR-10 with rate coding"}

    elif dataset_name == "dvs_cifar10":
        train_ds = DVSCIFAR10(root, train=True, **kwargs)
        val_ds = DVSCIFAR10(root, train=False, **kwargs)
        info = {"T": kwargs.get("T", 100), "D": 2 * 128 * 128, "classes": 10,
                "description": "DVS-CIFAR10 neuromorphic dataset"}

    else:
        raise ValueError(f"Unknown dataset: {dataset_name}. "
                         f"Options: smnist, psmnist, cifar10, dvs_cifar10")

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)

    return train_loader, val_loader, info


# Quick test
if __name__ == "__main__":
    print("Testing benchmark datasets...\n")

    # (header printed before the test, dataset name, extra dataset kwargs)
    smoke_tests = [
        ("1. Sequential MNIST", "smnist",
         {"n_repeat": 1, "spike_encoding": "direct"}),
        ("\n2. Permuted Sequential MNIST", "psmnist",
         {"n_repeat": 1, "spike_encoding": "direct"}),
        ("\n3. CIFAR-10 (rate coded)", "cifar10", {"T": 50}),
    ]

    for header, name, extra in smoke_tests:
        print(header)
        try:
            train_loader, val_loader, info = get_benchmark_dataloader(
                name, batch_size=32, **extra
            )
            x, y = next(iter(train_loader))
            print(f" Shape: {x.shape}, Labels: {y.shape}")
            print(f" Info: {info}")
        except Exception as e:  # smoke test: report and continue with the next dataset
            print(f" Error: {e}")

    print("\nDone!")