Diffstat (limited to 'files/data_io')
-rw-r--r--  files/data_io/__init__.py                        |   3
-rw-r--r--  files/data_io/benchmark_datasets.py              | 360
-rw-r--r--  files/data_io/configs/__init__.py                |   0
-rw-r--r--  files/data_io/configs/dvs.yaml                   |   0
-rw-r--r--  files/data_io/configs/shd.yaml                   |  11
-rw-r--r--  files/data_io/configs/ssc.yaml                   |   0
-rw-r--r--  files/data_io/dataset_loader.py                  | 281
-rw-r--r--  files/data_io/encoders/__init__.py               |   1
-rw-r--r--  files/data_io/encoders/base_encoder.py           |  10
-rw-r--r--  files/data_io/encoders/latency_encoder.py        |  13
-rw-r--r--  files/data_io/encoders/poisson_encoder.py        |  91
-rw-r--r--  files/data_io/encoders/rank_order_encoder.py     |  10
-rw-r--r--  files/data_io/transforms/__init__.py             |   1
-rw-r--r--  files/data_io/transforms/normalization.py        |  52
-rw-r--r--  files/data_io/transforms/spike_augmentation.py   |  10
-rw-r--r--  files/data_io/utils/__init__.py                  |   1
-rw-r--r--  files/data_io/utils/file_utils.py                |  15
-rw-r--r--  files/data_io/utils/spike_tools.py               |  10
-rw-r--r--  files/data_io/utils/visualize.py                 |  19
19 files changed, 888 insertions, 0 deletions
diff --git a/files/data_io/__init__.py b/files/data_io/__init__.py
new file mode 100644
index 0000000..de423db
--- /dev/null
+++ b/files/data_io/__init__.py
@@ -0,0 +1,3 @@
+"""
+data_io package: unified data loading, encoding, and preprocessing for spiking neural networks.
+"""
diff --git a/files/data_io/benchmark_datasets.py b/files/data_io/benchmark_datasets.py
new file mode 100644
index 0000000..302ed51
--- /dev/null
+++ b/files/data_io/benchmark_datasets.py
@@ -0,0 +1,360 @@
+"""
+Challenging benchmark datasets for deep SNN evaluation.
+
+Datasets:
+1. Sequential MNIST (sMNIST) - pixel-by-pixel, 784 timesteps
+2. Permuted Sequential MNIST (psMNIST) - shuffled pixel order
+3. CIFAR-10 with rate coding
+4. DVS-CIFAR10 (requires tonic library)
+
+These benchmarks are harder than SHD and benefit from deeper networks.
+"""
+
+import os
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader
+
+
+class SequentialMNIST(Dataset):
+    """
+    Sequential MNIST - feed pixels one at a time.
+
+    Each 28x28 image becomes a sequence of 784 timesteps,
+    each with a single pixel intensity converted to spike probability.
+
+    This is MUCH harder than standard MNIST because:
+    - Network must remember information across 784 timesteps
+    - Tests long-range temporal dependencies
+    - Shallow networks fail due to vanishing gradients
+
+    Args:
+        root: Data directory
+        train: Train or test split
+        permute: If True, use fixed random permutation (psMNIST)
+        spike_encoding: 'rate' (Poisson) or 'latency' or 'direct' (raw intensity)
+        max_rate: Maximum firing rate for rate coding
+        seed: Random seed for permutation
+    """
+
+    def __init__(
+        self,
+        root: str = "./data",
+        train: bool = True,
+        permute: bool = False,
+        spike_encoding: str = "rate",
+        max_rate: float = 100.0,
+        n_repeat: int = 1,  # Repeat each pixel n times for more spikes
+        seed: int = 42,
+        download: bool = True,
+    ):
+        try:
+            from torchvision import datasets, transforms
+        except ImportError:
+            raise ImportError("torchvision required: pip install torchvision")
+
+        self.train = train
+        self.permute = permute
+        self.spike_encoding = spike_encoding
+        self.max_rate = max_rate
+        self.n_repeat = n_repeat
+
+        # Load MNIST
+        self.mnist = datasets.MNIST(
+            root=root,
+            train=train,
+            download=download,
+            transform=transforms.ToTensor(),
+        )
+
+        # Create fixed permutation for psMNIST
+        if permute:
+            rng = np.random.RandomState(seed)
+            self.perm = torch.from_numpy(rng.permutation(784))
+        else:
+            self.perm = None
+
+    def __len__(self):
+        return len(self.mnist)
+
+    def __getitem__(self, idx):
+        img, label = self.mnist[idx]
+
+        # Flatten to (784,)
+        pixels = img.view(-1)
+
+        # Apply permutation
+        if self.perm is not None:
+            pixels = pixels[self.perm]
+
+        # Convert to spike sequence (T, 1) where T = 784 * n_repeat
+        T = 784 * self.n_repeat
+
+        if self.spike_encoding == "direct":
+            # Direct intensity: repeat each pixel n_repeat times
+            spikes = pixels.unsqueeze(1).repeat(1, self.n_repeat).view(T, 1)
+
+        elif self.spike_encoding == "rate":
+            # Rate coding: Poisson spikes based on intensity
+            probs = pixels * (self.max_rate / 1000.0)  # Assuming 1ms bins
+            probs = probs.clamp(0, 1)
+            # Repeat and sample
+            probs_expanded = probs.unsqueeze(1).repeat(1, self.n_repeat).view(T, 1)
+            spikes = torch.bernoulli(probs_expanded)
+
+        elif self.spike_encoding == "latency":
+            # Latency coding: spike time proportional to intensity
+            # High intensity = early spike, low = late spike
+            spikes = torch.zeros(T, 1)
+            for i, p in enumerate(pixels):
+                if p > 0.1:  # Threshold for spiking
+                    # Spike time: higher intensity = earlier
+                    spike_time = int((1 - p) * (self.n_repeat - 1))
+                    t = i * self.n_repeat + spike_time
+                    spikes[t, 0] = 1.0
+
+        return spikes, label
+
+
+class RateCodingCIFAR10(Dataset):
+    """
+    CIFAR-10 with rate coding for SNNs.
+
+    Converts 32x32x3 images to spike trains:
+    - Each pixel channel becomes a Poisson spike train
+    - Total input dimension: 32*32*3 = 3072
+    - Sequence length: T timesteps
+
+    Args:
+        root: Data directory
+        train: Train or test split
+        T: Number of timesteps
+        max_rate: Maximum firing rate (Hz)
+        flatten: If True, flatten spatial dimensions
+    """
+
+    def __init__(
+        self,
+        root: str = "./data",
+        train: bool = True,
+        T: int = 100,
+        max_rate: float = 200.0,
+        flatten: bool = True,
+        download: bool = True,
+    ):
+        try:
+            from torchvision import datasets, transforms
+        except ImportError:
+            raise ImportError("torchvision required: pip install torchvision")
+
+        self.T = T
+        self.max_rate = max_rate
+        self.flatten = flatten
+
+        # Normalize to [0, 1]
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+        ])
+
+        self.cifar = datasets.CIFAR10(
+            root=root,
+            train=train,
+            download=download,
+            transform=transform,
+        )
+
+    def __len__(self):
+        return len(self.cifar)
+
+    def __getitem__(self, idx):
+        img, label = self.cifar[idx]  # (3, 32, 32)
+
+        if self.flatten:
+            img = img.view(-1)  # (3072,)
+
+        # Rate coding
+        prob_per_step = img * (self.max_rate / 1000.0)  # Assuming 1ms steps
+        prob_per_step = prob_per_step.clamp(0, 1)
+
+        # Generate spikes for T timesteps
+        if self.flatten:
+            probs = prob_per_step.unsqueeze(0).expand(self.T, -1)  # (T, 3072)
+        else:
+            probs = prob_per_step.unsqueeze(0).expand(self.T, -1, -1, -1)  # (T, 3, 32, 32)
+
+        spikes = torch.bernoulli(probs)
+
+        return spikes, label
+
+
+class DVSCIFAR10(Dataset):
+    """
+    DVS-CIFAR10 dataset wrapper.
+
+    Requires the 'tonic' library for neuromorphic datasets:
+        pip install tonic
+
+    DVS-CIFAR10 is recorded from a Dynamic Vision Sensor watching
+    CIFAR-10 images on a monitor. It's a standard neuromorphic benchmark.
+
+    Args:
+        root: Data directory
+        train: Train or test split
+        T: Number of time bins
+        dt_ms: Width of each time bin in milliseconds
+    """
+
+    def __init__(
+        self,
+        root: str = "./data",
+        train: bool = True,
+        T: int = 100,
+        dt_ms: float = 10.0,
+        download: bool = True,
+    ):
+        try:
+            import tonic
+            from tonic import transforms as tonic_transforms
+        except ImportError:
+            raise ImportError(
+                "tonic library required for DVS datasets: pip install tonic"
+            )
+
+        self.T = T
+
+        # Time binning transform
+        sensor_size = tonic.datasets.CIFAR10DVS.sensor_size
+        frame_transform = tonic_transforms.ToFrame(
+            sensor_size=sensor_size,
+            time_window=dt_ms * 1000,  # Convert to microseconds
+        )
+
+        self.dataset = tonic.datasets.CIFAR10DVS(
+            save_to=root,
+            train=train,
+            transform=frame_transform,
+        )
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        frames, label = self.dataset[idx]  # (T', 2, H, W) - 2 polarities
+
+        # Convert to tensor and flatten spatial dims
+        frames = torch.from_numpy(frames).float()
+
+        # Adjust to target T
+        T_actual = frames.shape[0]
+        if T_actual > self.T:
+            # Subsample
+            indices = torch.linspace(0, T_actual - 1, self.T).long()
+            frames = frames[indices]
+        elif T_actual < self.T:
+            # Pad with zeros
+            pad = torch.zeros(self.T - T_actual, *frames.shape[1:])
+            frames = torch.cat([frames, pad], dim=0)
+
+        # Flatten: (T, 2, H, W) -> (T, 2*H*W)
+        frames = frames.view(self.T, -1)
+
+        return frames, label
+
+
+def get_benchmark_dataloader(
+    dataset_name: str,
+    batch_size: int = 64,
+    root: str = "./data",
+    **kwargs,
+) -> Tuple[DataLoader, DataLoader, dict]:
+    """
+    Get train and validation dataloaders for a benchmark dataset.
+
+    Args:
+        dataset_name: One of 'smnist', 'psmnist', 'cifar10', 'dvs_cifar10'
+        batch_size: Batch size
+        root: Data directory
+        **kwargs: Additional arguments passed to dataset
+
+    Returns:
+        train_loader, val_loader, info_dict
+    """
+
+    if dataset_name == "smnist":
+        train_ds = SequentialMNIST(root, train=True, permute=False, **kwargs)
+        val_ds = SequentialMNIST(root, train=False, permute=False, **kwargs)
+        info = {"T": 784 * kwargs.get("n_repeat", 1), "D": 1, "classes": 10,
+                "description": "Sequential MNIST - 784 timesteps, 1 pixel at a time"}
+
+    elif dataset_name == "psmnist":
+        train_ds = SequentialMNIST(root, train=True, permute=True, **kwargs)
+        val_ds = SequentialMNIST(root, train=False, permute=True, **kwargs)
+        info = {"T": 784 * kwargs.get("n_repeat", 1), "D": 1, "classes": 10,
+                "description": "Permuted Sequential MNIST - shuffled pixel order, tests long-range memory"}
+
+    elif dataset_name == "cifar10":
+        T = kwargs.pop("T", 100)
+        train_ds = RateCodingCIFAR10(root, train=True, T=T, **kwargs)
+        val_ds = RateCodingCIFAR10(root, train=False, T=T, **kwargs)
+        info = {"T": T, "D": 3072, "classes": 10,
+                "description": "CIFAR-10 with rate coding"}
+
+    elif dataset_name == "dvs_cifar10":
+        train_ds = DVSCIFAR10(root, train=True, **kwargs)
+        val_ds = DVSCIFAR10(root, train=False, **kwargs)
+        info = {"T": kwargs.get("T", 100), "D": 2 * 128 * 128, "classes": 10,
+                "description": "DVS-CIFAR10 neuromorphic dataset"}
+
+    else:
+        raise ValueError(f"Unknown dataset: {dataset_name}. "
+                         f"Options: smnist, psmnist, cifar10, dvs_cifar10")
+
+    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
+    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4)
+
+    return train_loader, val_loader, info
+
+
+# Quick test
+if __name__ == "__main__":
+    print("Testing benchmark datasets...\n")
+
+    # Test sMNIST
+    print("1. Sequential MNIST")
+    try:
+        train_loader, val_loader, info = get_benchmark_dataloader(
+            "smnist", batch_size=32, n_repeat=1, spike_encoding="direct"
+        )
+        x, y = next(iter(train_loader))
+        print(f"   Shape: {x.shape}, Labels: {y.shape}")
+        print(f"   Info: {info}")
+    except Exception as e:
+        print(f"   Error: {e}")
+
+    # Test psMNIST
+    print("\n2. Permuted Sequential MNIST")
+    try:
+        train_loader, val_loader, info = get_benchmark_dataloader(
+            "psmnist", batch_size=32, n_repeat=1, spike_encoding="direct"
+        )
+        x, y = next(iter(train_loader))
+        print(f"   Shape: {x.shape}, Labels: {y.shape}")
+        print(f"   Info: {info}")
+    except Exception as e:
+        print(f"   Error: {e}")
+
+    # Test CIFAR-10
+    print("\n3. CIFAR-10 (rate coded)")
+    try:
+        train_loader, val_loader, info = get_benchmark_dataloader(
+            "cifar10", batch_size=32, T=50
+        )
+        x, y = next(iter(train_loader))
+        print(f"   Shape: {x.shape}, Labels: {y.shape}")
+        print(f"   Info: {info}")
+    except Exception as e:
+        print(f"   Error: {e}")
+
+    print("\nDone!")
diff --git a/files/data_io/configs/__init__.py b/files/data_io/configs/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/files/data_io/configs/__init__.py
diff --git a/files/data_io/configs/dvs.yaml b/files/data_io/configs/dvs.yaml
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/files/data_io/configs/dvs.yaml
diff --git a/files/data_io/configs/shd.yaml b/files/data_io/configs/shd.yaml
new file mode 100644
index 0000000..c786934
--- /dev/null
+++ b/files/data_io/configs/shd.yaml
@@ -0,0 +1,11 @@
+dataset: SHD
+data_dir: /u/yurenh2/ml-projects/snn-training/files/data
+shuffle: true
+encoder:
+  type: poisson
+  max_rate: 50
+  dt_ms: 1.0
+  seed: 42  # optional; defaults to 42
+transforms:
+  normalize: true
+  spike_jitter: 0.01
diff --git a/files/data_io/configs/ssc.yaml b/files/data_io/configs/ssc.yaml
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/files/data_io/configs/ssc.yaml
diff --git a/files/data_io/dataset_loader.py b/files/data_io/dataset_loader.py
new file mode 100644
index 0000000..983b67b
--- /dev/null
+++ b/files/data_io/dataset_loader.py
@@ -0,0 +1,281 @@
+from typing import Optional
+import yaml
+import torch
+from torch.utils.data import Dataset, DataLoader
+import os
+import h5py
+import numpy as np
+from .encoders import poisson_encoder, latency_encoder, rank_order_encoder
+from .transforms import normalization, spike_augmentation
+# from .utils import file_utils  # currently unused; can stay commented out for now
+
+
+
+# -----
+# Base synthetic dataset placeholder
+# -----
+class BaseDataset(Dataset):
+    """
+    Abstract base class for datasets.
+    Each subclass must implement __getitem__ and __len__.
+
+    This base implementation uses synthetic placeholders for quick smoke tests.
+    """
+
+    def __init__(self, data_dir, encoder, transforms=None):
+        self.data_dir = data_dir
+        self.encoder = encoder
+        self.transforms = transforms or []
+        self.samples = list(range(200))  # placeholder synthetic samples
+
+    def __getitem__(self, idx):
+        # synthetic static input -> encode to (T, D) via encoder
+        raw_data = torch.rand(128)  # 128-dim intensity
+        label = torch.randint(0, 10, (1,)).item()
+        encoded = self.encoder.encode(raw_data.numpy()) if self.encoder is not None else raw_data
+        x = torch.as_tensor(encoded)
+        for t in self.transforms:
+            x = t(x)
+        return x, label
+
+    def __len__(self):
+        return len(self.samples)
+
+
+# -----
+# SHD dataset: true event data, read H5 per-sample; fixed global T; adaptive D
+# -----
+import h5py  # noqa: E402
+
+
+class SHDDataset(Dataset):
+    """
+    SHD Dataset Loader (.h5) with time-adaptive binning and fixed global T.
+
+    H5 structure (Zenke Lab convention):
+      - f["labels"][i]          -> scalar label for sample i
+      - f["spikes"]["times"][i] -> 1D array of spike times (ms) for sample i
+      - f["spikes"]["units"][i] -> 1D array of channel ids for sample i
+
+    We:
+      (1) scan once to determine global T = ceil(max_time / dt_ms)
+      (2) decide D from max unit id (fallback to default_D=700)
+      (3) in __getitem__, open H5, read ragged arrays for that sample, and bin to (T, D)
+    """
+
+    def __init__(
+        self,
+        data_dir: str,
+        encoder=None,          # ignored for SHD (already spiking events)
+        transforms=None,
+        split: str = "train",
+        dt_ms: float = 1.0,
+        seed: Optional[int] = None,
+        default_D: int = 700
+    ):
+        super().__init__()
+        self.data_dir = data_dir
+        self.transforms = transforms or []
+        self.dt_ms = float(dt_ms)
+        self.seed = 42 if seed is None else int(seed)
+        self.encoder = None  # IMPORTANT: do not apply intensity encoders to event data
+        self.default_D = int(default_D)
+
+        fname = f"shd_{split}.h5"
+        self.path = os.path.join(self.data_dir, fname)
+        if not os.path.exists(self.path):
+            raise FileNotFoundError(f"SHD file not found: {self.path}")
+
+        with h5py.File(self.path, "r") as f:
+            # labels is dense array
+            self.labels = np.array(f["labels"], dtype=np.int64)
+            self.N = int(self.labels.shape[0])
+
+            # ragged datasets for events
+            times_ds = f["spikes"]["times"]
+            units_ds = f["spikes"]["units"]
+
+            # scan once to compute global T and adaptive D
+            t_max_global = 0.0
+            max_unit = -1
+            for i in range(self.N):
+                ti = times_ds[i]
+                ui = units_ds[i]
+                if ti.size > 0:
+                    last_t = float(ti[-1])  # ms
+                    if last_t > t_max_global:
+                        t_max_global = last_t
+                if ui.size > 0:
+                    uimax = int(ui.max())
+                    if uimax > max_unit:
+                        max_unit = uimax
+
+        # decide D
+        if max_unit >= 0:
+            self.D = max(max_unit + 1, self.default_D)
+        else:
+            self.D = self.default_D
+
+        # decide T from global max time
+        self.T = int(np.ceil(t_max_global / self.dt_ms)) if t_max_global > 0 else 1
+
+        # rng in case transforms need it
+        self._rng = np.random.default_rng(self.seed)
+
+    def __len__(self):
+        return self.N
+
+    def __getitem__(self, idx: int):
+        # open file per-sample for worker safety
+        with h5py.File(self.path, "r") as f:
+            ti = f["spikes"]["times"][idx][:]
+            ui = f["spikes"]["units"][idx][:]
+            y = int(f["labels"][idx])
+
+        # bin events to (T, D)
+        spikes = np.zeros((self.T, self.D), dtype=np.float32)
+        if ti.size > 0:
+            bins = (ti / self.dt_ms).astype(np.int64)
+            bins = np.clip(bins, 0, self.T - 1)
+            ui = np.clip(ui.astype(np.int64), 0, self.D - 1)
+            spikes[bins, ui] = 1.0  # presence; if you prefer counts: +=1 then clip to 1
+
+        x = torch.from_numpy(spikes)
+
+        # apply transforms (on torch tensor)
+        for tr in self.transforms:
+            x = tr(x)
+
+        return x, y
+
+
+# -----
+# SSC / DVS placeholders (still synthetic; implement real readers later)
+# -----
+class SSCDataset(BaseDataset):
+    """Placeholder SSC dataset (synthetic)."""
+    pass
+
+
+class DVSDataset(BaseDataset):
+    """Placeholder DVS dataset (synthetic)."""
+    pass
+
+
+# -----
+# Helpers: encoders / transforms / cfg path resolution
+# -----
+def build_encoder(cfg):
+    """
+    Build encoder from config dict.
+
+    Expected schema:
+      encoder:
+        type: poisson | latency | rank_order
+        # Poisson-only optional fields:
+        max_rate: 50
+        T: 64
+        dt_ms: 1.0
+        seed: 123
+    """
+    etype = cfg["type"].lower()
+    if etype == "poisson":
+        return poisson_encoder.PoissonEncoder(
+            max_rate=cfg.get("max_rate", cfg.get("rate", 20)),
+            T=cfg.get("T", 50),
+            dt_ms=cfg.get("dt_ms", 1.0),
+            seed=cfg.get("seed", None),
+        )
+    elif etype == "latency":
+        return latency_encoder.LatencyEncoder()
+    elif etype == "rank_order":
+        return rank_order_encoder.RankOrderEncoder()
+    else:
+        raise ValueError(f"Unknown encoder type: {etype}")
+
+
+def build_transforms(cfg):
+    tlist = []
+    if cfg.get("normalize", False):
+        tlist.append(normalization.Normalize())
+    if cfg.get("spike_jitter", None) is not None:
+        tlist.append(spike_augmentation.SpikeJitter(std=cfg["spike_jitter"]))
+    return tlist
+
+
+def _resolve_cfg_path(cfg_path: str) -> str:
+    """
+    Resolve cfg_path against:
+      1) as-is (absolute or CWD-relative)
+      2) relative to this package directory
+      3) <pkg_dir>/configs/<basename>
+    """
+    if os.path.isabs(cfg_path) and os.path.exists(cfg_path):
+        return cfg_path
+    if os.path.exists(cfg_path):
+        return cfg_path
+    pkg_dir = os.path.dirname(__file__)
+    cand2 = os.path.normpath(os.path.join(pkg_dir, cfg_path))
+    if os.path.exists(cand2):
+        return cand2
+    cand3 = os.path.join(pkg_dir, "configs", os.path.basename(cfg_path))
+    if os.path.exists(cand3):
+        return cand3
+    raise FileNotFoundError(f"Config file not found. Tried: {cfg_path}, {cand2}, {cand3}")
+
+
+# -----
+# Entry: get_dataloader
+# -----
+def get_dataloader(cfg_path):
+    """
+    Create train/val DataLoader from YAML config.
+    Handles SHD as true event dataset (encoder=None), others as synthetic placeholders.
+    """
+    cfg_path_resolved = _resolve_cfg_path(cfg_path)
+    with open(cfg_path_resolved, "r") as f:
+        cfg = yaml.safe_load(f)
+
+    dataset_name = cfg["dataset"].lower()
+    data_dir = cfg["data_dir"]
+    transforms = build_transforms(cfg.get("transforms", {}))
+
+    if dataset_name == "shd":
+        # event dataset: do NOT use intensity encoders here
+        dt_ms = cfg.get("encoder", {}).get("dt_ms", 1.0)
+        seed = cfg.get("encoder", {}).get("seed", 42)
+        ds_train = SHDDataset(
+            data_dir, encoder=None, transforms=transforms, split="train",
+            dt_ms=dt_ms, seed=seed
+        )
+        ds_val = SHDDataset(
+            data_dir, encoder=None, transforms=transforms, split="test",
+            dt_ms=dt_ms, seed=seed
+        )
+    elif dataset_name == "ssc":
+        # placeholder path; later implement true SSC reader
+        encoder = build_encoder(cfg["encoder"])
+        ds_train = SSCDataset(data_dir, encoder, transforms)
+        ds_val = SSCDataset(data_dir, encoder, transforms)
+    elif dataset_name == "dvs":
+        encoder = build_encoder(cfg["encoder"])
+        ds_train = DVSDataset(data_dir, encoder, transforms)
+        ds_val = DVSDataset(data_dir, encoder, transforms)
+    else:
+        raise ValueError(f"Unknown dataset: {dataset_name}")
+
+    train_loader = DataLoader(
+        ds_train,
+        batch_size=cfg.get("batch_size", 16),
+        shuffle=cfg.get("shuffle", True),
+        num_workers=cfg.get("num_workers", 0),
+        pin_memory=cfg.get("pin_memory", False),
+    )
+    val_loader = DataLoader(
+        ds_val,
+        batch_size=cfg.get("batch_size", 16),
+        shuffle=False,
+        num_workers=cfg.get("num_workers", 0),
+        pin_memory=cfg.get("pin_memory", False),
+    )
+    return train_loader, val_loader
\ No newline at end of file
diff --git a/files/data_io/encoders/__init__.py b/files/data_io/encoders/__init__.py
new file mode 100644
index 0000000..b2c5dc3
--- /dev/null
+++ b/files/data_io/encoders/__init__.py
@@ -0,0 +1 @@
+"""Encoder submodule: provides various spike encoding strategies."""
diff --git a/files/data_io/encoders/base_encoder.py b/files/data_io/encoders/base_encoder.py
new file mode 100644
index 0000000..e40451f
--- /dev/null
+++ b/files/data_io/encoders/base_encoder.py
@@ -0,0 +1,10 @@
+import numpy as np
+import torch
+
+class BaseEncoder:
+    """Abstract base class for all encoders."""
+    def encode(self, data: np.ndarray) -> torch.Tensor:
+        """
+        Convert static data (e.g., image, waveform) into spike tensor (T, input_dim).
+        """
+        raise NotImplementedError
diff --git a/files/data_io/encoders/latency_encoder.py b/files/data_io/encoders/latency_encoder.py
new file mode 100644
index 0000000..a7804ae
--- /dev/null
+++ b/files/data_io/encoders/latency_encoder.py
@@ -0,0 +1,13 @@
+import numpy as np
+import torch
+from .base_encoder import BaseEncoder
+
+class LatencyEncoder(BaseEncoder):
+    """Encode input intensity into spike latency."""
+    def __init__(self):
+        pass
+
+    def encode(self, data: np.ndarray) -> torch.Tensor:
+        # TODO: map value -> time delay
+        spikes = torch.zeros(10, data.size)  # placeholder
+        return spikes
diff --git a/files/data_io/encoders/poisson_encoder.py b/files/data_io/encoders/poisson_encoder.py
new file mode 100644
index 0000000..0b404f7
--- /dev/null
+++ b/files/data_io/encoders/poisson_encoder.py
@@ -0,0 +1,91 @@
+from typing import Optional, Union
+import numpy as np
+import torch
+from .base_encoder import BaseEncoder
+
+class PoissonEncoder(BaseEncoder):
+    r"""
+    PoissonEncoder: convert static intensities to spike trains via per-time-step Bernoulli sampling.
+
+    Given a static input vector x \in [0,1]^{D}, we produce a spike tensor S \in {0,1}^{T \times D}
+    by sampling at each time step t and dimension d:
+        S[t, d] ~ Bernoulli( p[d] ),  where p[d] = clip( x[d] * max_rate * (dt_ms / 1000), 0, 1 ).
+
+    Parameters
+    ----------
+    max_rate : float
+        Maximum firing rate under unit intensity (Hz). Typical: 20~200. Default: 20.
+    T : int
+        Number of discrete time steps in the encoded spike train. Default: 50.
+    dt_ms : float
+        Time resolution per step in milliseconds. Default: 1.0 ms.
+        Effective per-step probability uses factor (dt_ms/1000) to convert Hz to per-step probability.
+    seed : int or None
+        Optional RNG seed for reproducibility. If None, uses global RNG state.
+
+    Notes
+    -----
+    - Input `data` is expected to be a NumPy 1D array (shape [D]) or 2D array ([B, D]).
+      If 1D, we return S with shape (T, D).
+      If 2D (batched), we broadcast probabilities across batch and return (T, B, D).
+    - Intensities outside [0,1] will be clipped to [0,1].
+    - Device of returned tensor follows torch default device (CPU) unless you move it later.
+    """
+
+    def __init__(self, max_rate: float = 20.0, T: int = 50, dt_ms: float = 1.0, seed: Optional[int] = None):
+        super().__init__()
+        self.max_rate = float(max_rate)
+        self.T = int(T)
+        self.dt_ms = float(dt_ms)
+        self.seed = seed
+        # local generator for reproducibility if seed is provided
+        self._g = torch.Generator()
+        if seed is not None:
+            self._g.manual_seed(int(seed))
+
+    def _ensure_numpy(self, data: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
+        if isinstance(data, torch.Tensor):
+            data = data.detach().cpu().numpy()
+        return np.asarray(data)
+
+    def encode(self, data: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+        """
+        Convert input intensities to Poisson spike trains.
+
+        Parameters
+        ----------
+        data : np.ndarray or torch.Tensor
+            Shape [D] or [B, D]. Values are assumed in [0,1] (we clip if not).
+
+        Returns
+        -------
+        spikes : torch.Tensor
+            Shape (T, D) or (T, B, D), dtype=torch.float32 with values {0.,1.}.
+        """
+        x = self._ensure_numpy(data)
+        # clip to [0,1]
+        x = np.clip(x, 0.0, 1.0)
+
+        # compute per-step probability
+        p_step = float(self.max_rate) * (self.dt_ms / 1000.0)  # probability per step for unit intensity
+        # probability tensor (broadcast-friendly)
+        probs = x * p_step
+        probs = np.clip(probs, 0.0, 1.0)
+
+        # create Bernoulli samples for T steps
+        # If input is 1D: probs.shape = (D,)   -> output (T, D)
+        # If input is 2D: probs.shape = (B, D) -> output (T, B, D)
+        if probs.ndim == 1:
+            D = probs.shape[0]
+            probs_t = np.broadcast_to(probs, (self.T, D))  # (T, D)
+            probs_t = torch.from_numpy(probs_t.astype(np.float32))
+            spikes = torch.bernoulli(probs_t, generator=self._g)
+            return spikes
+        elif probs.ndim == 2:
+            B, D = probs.shape
+            probs_t = np.broadcast_to(probs, (self.T, B, D))  # (T, B, D)
+            probs_t = torch.from_numpy(probs_t.astype(np.float32))
+            spikes = torch.bernoulli(probs_t, generator=self._g)
+            return spikes
+        else:
+            raise ValueError(f"PoissonEncoder expects data with ndim 1 or 2, got shape {probs.shape}")
diff --git a/files/data_io/encoders/rank_order_encoder.py b/files/data_io/encoders/rank_order_encoder.py
new file mode 100644
index 0000000..9102e90
--- /dev/null
+++ b/files/data_io/encoders/rank_order_encoder.py
@@ -0,0 +1,10 @@
+import numpy as np
+import torch
+from .base_encoder import BaseEncoder
+
+class RankOrderEncoder(BaseEncoder):
+    """Encode by rank order of input features."""
+    def encode(self, data: np.ndarray) -> torch.Tensor:
+        # TODO: implement rank order conversion
+        spikes = torch.zeros(10, data.size)  # placeholder
+        return spikes
diff --git a/files/data_io/transforms/__init__.py b/files/data_io/transforms/__init__.py
new file mode 100644
index 0000000..f6aa496
--- /dev/null
+++ b/files/data_io/transforms/__init__.py
@@ -0,0 +1 @@
+"""Transforms for spike data (normalization, augmentation, etc.)."""
diff --git a/files/data_io/transforms/normalization.py b/files/data_io/transforms/normalization.py
new file mode 100644
index 0000000..c86b8c7
--- /dev/null
+++ b/files/data_io/transforms/normalization.py
@@ -0,0 +1,52 @@
+import torch
+
+class Normalize:
+    """Normalize spike input (z-score or min-max)."""
+    def __init__(self, mode: str = "zscore", eps: float = 1e-6):
+        """
+        Parameters
+        ----------
+        mode : str
+            One of {"zscore", "minmax"}.
+            - "zscore": per-sample, per-channel standardization across time.
+            - "minmax": per-sample, per-channel min-max scaling across time.
+        eps : float
+            Small constant to avoid division by zero.
+        """
+        mode_l = str(mode).lower()
+        if mode_l not in {"zscore", "minmax"}:
+            raise ValueError(f"Normalize mode must be 'zscore' or 'minmax', got {mode}")
+        self.mode = mode_l
+        self.eps = float(eps)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Apply normalization.
+        Accepts tensors of shape (T, D) or (B, T, D).
+        Normalization is computed per-sample and per-channel over the time axis T.
+        """
+        if not isinstance(x, torch.Tensor):
+            raise TypeError("Normalize expects a torch.Tensor")
+        if x.ndim == 2:
+            # (T, D) -> time dim = 0
+            time_dim = 0
+            keep_dims = True
+        elif x.ndim == 3:
+            # (B, T, D) -> time dim = 1
+            time_dim = 1
+            keep_dims = True
+        else:
+            raise ValueError(f"Expected x.ndim in {{2,3}}, got {x.ndim}")
+
+        if self.mode == "zscore":
+            mean_t = x.mean(dim=time_dim, keepdim=keep_dims)
+            # population std (unbiased=False) to avoid NaNs for small T
+            std_t = x.std(dim=time_dim, keepdim=keep_dims, unbiased=False)
+            x_norm = (x - mean_t) / (std_t + self.eps)
+            return x_norm
+        else:  # "minmax"
+            min_t = x.amin(dim=time_dim, keepdim=keep_dims)
+            max_t = x.amax(dim=time_dim, keepdim=keep_dims)
+            denom = (max_t - min_t).clamp_min(self.eps)
+            x_scaled = (x - min_t) / denom
+            return x_scaled
diff --git a/files/data_io/transforms/spike_augmentation.py b/files/data_io/transforms/spike_augmentation.py
new file mode 100644
index 0000000..9b7b687
--- /dev/null
+++ b/files/data_io/transforms/spike_augmentation.py
@@ -0,0 +1,10 @@
+import torch
+
+class SpikeJitter:
+    """Add temporal jitter noise to spikes."""
+    def __init__(self, std=0.01):
+        self.std = std
+
+    def __call__(self, spikes: torch.Tensor) -> torch.Tensor:
+        # TODO: add random jitter to spike timings
+        return spikes
diff --git a/files/data_io/utils/__init__.py b/files/data_io/utils/__init__.py
new file mode 100644
index 0000000..ee3ab2f
--- /dev/null
+++ b/files/data_io/utils/__init__.py
@@ -0,0 +1 @@
+"""Utility functions for file management, spike tools, and visualization."""
diff --git a/files/data_io/utils/file_utils.py b/files/data_io/utils/file_utils.py
new file mode 100644
index 0000000..0a1e846
--- /dev/null
+++ b/files/data_io/utils/file_utils.py
@@ -0,0 +1,15 @@
+import os
+
+def ensure_dir(path: str):
+    """Ensure that a directory exists."""
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+def list_files(root: str, suffix: str):
+    """Recursively list files ending with suffix."""
+    matches = []
+    for dirpath, _, filenames in os.walk(root):
+        for f in filenames:
+            if f.endswith(suffix):
+                matches.append(os.path.join(dirpath, f))
+    return matches
diff --git a/files/data_io/utils/spike_tools.py b/files/data_io/utils/spike_tools.py
new file mode 100644
index 0000000..968ee72
--- /dev/null
+++ b/files/data_io/utils/spike_tools.py
@@ -0,0 +1,10 @@
+import torch
+import numpy as np
+
+def to_raster(spikes: torch.Tensor) -> np.ndarray:
+    """Convert spike tensor (T,B,N) to raster array (T,N)."""
+    return spikes.detach().cpu().numpy().mean(axis=1)
+
+def firing_rate(spikes: torch.Tensor, dt=1.0):
+    """Compute firing rate per neuron."""
+    return spikes.sum(dim=0) / (spikes.shape[0] * dt)
diff --git a/files/data_io/utils/visualize.py b/files/data_io/utils/visualize.py
new file mode 100644
index 0000000..6f0de95
--- /dev/null
+++ b/files/data_io/utils/visualize.py
@@ -0,0 +1,19 @@
+import matplotlib.pyplot as plt
+import torch
+
+def plot_raster(spikes: torch.Tensor, title=None):
+    """
+    Plot raster diagram of spike activity (T,B,N) or (T,N).
+    """
+    s = spikes.detach().cpu()
+    if s.ndim == 3:
+        s = s[:, 0, :]  # take first batch
+    t, n = s.shape
+    for i in range(n):
+        times = torch.nonzero(s[:, i]).flatten()
+        plt.scatter(times.numpy(), torch.full_like(times, i).numpy(), s=2, c='black')
+    plt.xlabel("Time step")
+    plt.ylabel("Neuron index")
+    if title:
+        plt.title(title)
+    plt.show()
