diff options
Diffstat (limited to 'files/data_io/benchmark_datasets.py')
| -rw-r--r-- | files/data_io/benchmark_datasets.py | 360 |
1 files changed, 360 insertions, 0 deletions
diff --git a/files/data_io/benchmark_datasets.py b/files/data_io/benchmark_datasets.py new file mode 100644 index 0000000..302ed51 --- /dev/null +++ b/files/data_io/benchmark_datasets.py @@ -0,0 +1,360 @@ +""" +Challenging benchmark datasets for deep SNN evaluation. + +Datasets: +1. Sequential MNIST (sMNIST) - pixel-by-pixel, 784 timesteps +2. Permuted Sequential MNIST (psMNIST) - shuffled pixel order +3. CIFAR-10 with rate coding +4. DVS-CIFAR10 (requires tonic library) + +These benchmarks are harder than SHD and benefit from deeper networks. +""" + +import os +from typing import Optional, Tuple + +import numpy as np +import torch +from torch.utils.data import Dataset, DataLoader + + +class SequentialMNIST(Dataset): + """ + Sequential MNIST - feed pixels one at a time. + + Each 28x28 image becomes a sequence of 784 timesteps, + each with a single pixel intensity converted to spike probability. + + This is MUCH harder than standard MNIST because: + - Network must remember information across 784 timesteps + - Tests long-range temporal dependencies + - Shallow networks fail due to vanishing gradients + + Args: + root: Data directory + train: Train or test split + permute: If True, use fixed random permutation (psMNIST) + spike_encoding: 'rate' (Poisson) or 'latency' or 'direct' (raw intensity) + max_rate: Maximum firing rate for rate coding + seed: Random seed for permutation + """ + + def __init__( + self, + root: str = "./data", + train: bool = True, + permute: bool = False, + spike_encoding: str = "rate", + max_rate: float = 100.0, + n_repeat: int = 1, # Repeat each pixel n times for more spikes + seed: int = 42, + download: bool = True, + ): + try: + from torchvision import datasets, transforms + except ImportError: + raise ImportError("torchvision required: pip install torchvision") + + self.train = train + self.permute = permute + self.spike_encoding = spike_encoding + self.max_rate = max_rate + self.n_repeat = n_repeat + + # Load MNIST + self.mnist = datasets.MNIST( + root=root, + train=train, + download=download, + transform=transforms.ToTensor(), + ) + + # Create fixed permutation for psMNIST + if permute: + rng = np.random.RandomState(seed) + self.perm = torch.from_numpy(rng.permutation(784)) + else: + self.perm = None + + def __len__(self): + return len(self.mnist) + + def __getitem__(self, idx): + img, label = self.mnist[idx] + + # Flatten to (784,) + pixels = img.view(-1) + + # Apply permutation + if self.perm is not None: + pixels = pixels[self.perm] + + # Convert to spike sequence (T, 1) where T = 784 * n_repeat + T = 784 * self.n_repeat + + if self.spike_encoding == "direct": + # Direct intensity: repeat each pixel n_repeat times + spikes = pixels.unsqueeze(1).repeat(1, self.n_repeat).view(T, 1) + + elif self.spike_encoding == "rate": + # Rate coding: Poisson spikes based on intensity + probs = pixels * (self.max_rate / 1000.0) # Assuming 1ms bins + probs = probs.clamp(0, 1) + # Repeat and sample + probs_expanded = probs.unsqueeze(1).repeat(1, self.n_repeat).view(T, 1) + spikes = torch.bernoulli(probs_expanded) + + elif self.spike_encoding == "latency": + # Latency coding: spike time proportional to intensity + # High intensity = early spike, low = late spike + spikes = torch.zeros(T, 1) + for i, p in enumerate(pixels): + if p > 0.1: # Threshold for spiking + # Spike time: higher intensity = earlier + spike_time = int((1 - p) * (self.n_repeat - 1)) + t = i * self.n_repeat + spike_time + spikes[t, 0] = 1.0 + + return spikes, label + + +class RateCodingCIFAR10(Dataset): + """ + CIFAR-10 with rate coding for SNNs. + + Converts 32x32x3 images to spike trains: + - Each pixel channel becomes a Poisson spike train + - Total input dimension: 32*32*3 = 3072 + - Sequence length: T timesteps + + Args: + root: Data directory + train: Train or test split + T: Number of timesteps + max_rate: Maximum firing rate (Hz) + flatten: If True, flatten spatial dimensions + """ + + def __init__( + self, + root: str = "./data", + train: bool = True, + T: int = 100, + max_rate: float = 200.0, + flatten: bool = True, + download: bool = True, + ): + try: + from torchvision import datasets, transforms + except ImportError: + raise ImportError("torchvision required: pip install torchvision") + + self.T = T + self.max_rate = max_rate + self.flatten = flatten + + # Normalize to [0, 1] + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + + self.cifar = datasets.CIFAR10( + root=root, + train=train, + download=download, + transform=transform, + ) + + def __len__(self): + return len(self.cifar) + + def __getitem__(self, idx): + img, label = self.cifar[idx] # (3, 32, 32) + + if self.flatten: + img = img.view(-1) # (3072,) + + # Rate coding + prob_per_step = img * (self.max_rate / 1000.0) # Assuming 1ms steps + prob_per_step = prob_per_step.clamp(0, 1) + + # Generate spikes for T timesteps + if self.flatten: + probs = prob_per_step.unsqueeze(0).expand(self.T, -1) # (T, 3072) + else: + probs = prob_per_step.unsqueeze(0).expand(self.T, -1, -1, -1) # (T, 3, 32, 32) + + spikes = torch.bernoulli(probs) + + return spikes, label + + +class DVSCIFAR10(Dataset): + """ + DVS-CIFAR10 dataset wrapper. + + Requires the 'tonic' library for neuromorphic datasets: + pip install tonic + + DVS-CIFAR10 is recorded from a Dynamic Vision Sensor watching + CIFAR-10 images on a monitor. It's a standard neuromorphic benchmark. + + Args: + root: Data directory + train: Train or test split + T: Number of time bins + spatial_downsample: Downsample spatial resolution + """ + + def __init__( + self, + root: str = "./data", + train: bool = True, + T: int = 100, + dt_ms: float = 10.0, + download: bool = True, + ): + try: + import tonic + from tonic import transforms as tonic_transforms + except ImportError: + raise ImportError( + "tonic library required for DVS datasets: pip install tonic" + ) + + self.T = T + + # Time binning transform + sensor_size = tonic.datasets.CIFAR10DVS.sensor_size + frame_transform = tonic_transforms.ToFrame( + sensor_size=sensor_size, + time_window=dt_ms * 1000, # Convert to microseconds + ) + + self.dataset = tonic.datasets.CIFAR10DVS( + save_to=root, + train=train, + transform=frame_transform, + ) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + frames, label = self.dataset[idx] # (T', 2, H, W) - 2 polarities + + # Convert to tensor and flatten spatial dims + frames = torch.from_numpy(frames).float() + + # Adjust to target T + T_actual = frames.shape[0] + if T_actual > self.T: + # Subsample + indices = torch.linspace(0, T_actual - 1, self.T).long() + frames = frames[indices] + elif T_actual < self.T: + # Pad with zeros + pad = torch.zeros(self.T - T_actual, *frames.shape[1:]) + frames = torch.cat([frames, pad], dim=0) + + # Flatten: (T, 2, H, W) -> (T, 2*H*W) + frames = frames.view(self.T, -1) + + return frames, label + + +def get_benchmark_dataloader( + dataset_name: str, + batch_size: int = 64, + root: str = "./data", + **kwargs, +) -> Tuple[DataLoader, DataLoader, dict]: + """ + Get train and validation dataloaders for a benchmark dataset. + + Args: + dataset_name: One of 'smnist', 'psmnist', 'cifar10', 'dvs_cifar10' + batch_size: Batch size + root: Data directory + **kwargs: Additional arguments passed to dataset + + Returns: + train_loader, val_loader, info_dict + """ + + if dataset_name == "smnist": + train_ds = SequentialMNIST(root, train=True, permute=False, **kwargs) + val_ds = SequentialMNIST(root, train=False, permute=False, **kwargs) + info = {"T": 784 * kwargs.get("n_repeat", 1), "D": 1, "classes": 10, + "description": "Sequential MNIST - 784 timesteps, 1 pixel at a time"} + + elif dataset_name == "psmnist": + train_ds = SequentialMNIST(root, train=True, permute=True, **kwargs) + val_ds = SequentialMNIST(root, train=False, permute=True, **kwargs) + info = {"T": 784 * kwargs.get("n_repeat", 1), "D": 1, "classes": 10, + "description": "Permuted Sequential MNIST - shuffled pixel order, tests long-range memory"} + + elif dataset_name == "cifar10": + T = kwargs.pop("T", 100) + train_ds = RateCodingCIFAR10(root, train=True, T=T, **kwargs) + val_ds = RateCodingCIFAR10(root, train=False, T=T, **kwargs) + info = {"T": T, "D": 3072, "classes": 10, + "description": "CIFAR-10 with rate coding"} + + elif dataset_name == "dvs_cifar10": + train_ds = DVSCIFAR10(root, train=True, **kwargs) + val_ds = DVSCIFAR10(root, train=False, **kwargs) + info = {"T": kwargs.get("T", 100), "D": 2 * 128 * 128, "classes": 10, + "description": "DVS-CIFAR10 neuromorphic dataset"} + + else: + raise ValueError(f"Unknown dataset: {dataset_name}. " + f"Options: smnist, psmnist, cifar10, dvs_cifar10") + + train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4) + val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4) + + return train_loader, val_loader, info + + +# Quick test +if __name__ == "__main__": + print("Testing benchmark datasets...\n") + + # Test sMNIST + print("1. Sequential MNIST") + try: + train_loader, val_loader, info = get_benchmark_dataloader( + "smnist", batch_size=32, n_repeat=1, spike_encoding="direct" + ) + x, y = next(iter(train_loader)) + print(f" Shape: {x.shape}, Labels: {y.shape}") + print(f" Info: {info}") + except Exception as e: + print(f" Error: {e}") + + # Test psMNIST + print("\n2. Permuted Sequential MNIST") + try: + train_loader, val_loader, info = get_benchmark_dataloader( + "psmnist", batch_size=32, n_repeat=1, spike_encoding="direct" + ) + x, y = next(iter(train_loader)) + print(f" Shape: {x.shape}, Labels: {y.shape}") + print(f" Info: {info}") + except Exception as e: + print(f" Error: {e}") + + # Test CIFAR-10 + print("\n3. CIFAR-10 (rate coded)") + try: + train_loader, val_loader, info = get_benchmark_dataloader( + "cifar10", batch_size=32, T=50 + ) + x, y = next(iter(train_loader)) + print(f" Shape: {x.shape}, Labels: {y.shape}") + print(f" Info: {info}") + except Exception as e: + print(f" Error: {e}") + + print("\nDone!") |
