author:    YurenHao0426 <blackhao0426@gmail.com>  2026-01-13 23:50:59 -0600
committer: YurenHao0426 <blackhao0426@gmail.com>  2026-01-13 23:50:59 -0600
commit:    00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch)
tree:      77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files/data_io
parent:    c53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff)
parent:    cd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff)
Merge master into main
Diffstat (limited to 'files/data_io')
-rw-r--r--  files/data_io/__init__.py                        3
-rw-r--r--  files/data_io/benchmark_datasets.py            360
-rw-r--r--  files/data_io/configs/__init__.py                0
-rw-r--r--  files/data_io/configs/dvs.yaml                   0
-rw-r--r--  files/data_io/configs/shd.yaml                  11
-rw-r--r--  files/data_io/configs/ssc.yaml                   0
-rw-r--r--  files/data_io/dataset_loader.py                281
-rw-r--r--  files/data_io/encoders/__init__.py               1
-rw-r--r--  files/data_io/encoders/base_encoder.py          10
-rw-r--r--  files/data_io/encoders/latency_encoder.py       13
-rw-r--r--  files/data_io/encoders/poisson_encoder.py       91
-rw-r--r--  files/data_io/encoders/rank_order_encoder.py    10
-rw-r--r--  files/data_io/transforms/__init__.py             1
-rw-r--r--  files/data_io/transforms/normalization.py       52
-rw-r--r--  files/data_io/transforms/spike_augmentation.py  10
-rw-r--r--  files/data_io/utils/__init__.py                  1
-rw-r--r--  files/data_io/utils/file_utils.py               15
-rw-r--r--  files/data_io/utils/spike_tools.py              10
-rw-r--r--  files/data_io/utils/visualize.py                19
19 files changed, 888 insertions, 0 deletions
diff --git a/files/data_io/__init__.py b/files/data_io/__init__.py
new file mode 100644
index 0000000..de423db
--- /dev/null
+++ b/files/data_io/__init__.py
@@ -0,0 +1,3 @@
+"""
+data_io package: unified data loading, encoding, and preprocessing for spiking neural networks.
+"""
diff --git a/files/data_io/benchmark_datasets.py b/files/data_io/benchmark_datasets.py
new file mode 100644
index 0000000..302ed51
--- /dev/null
+++ b/files/data_io/benchmark_datasets.py
@@ -0,0 +1,360 @@
+"""
+Challenging benchmark datasets for deep SNN evaluation.
+
+Datasets:
+1. Sequential MNIST (sMNIST) - pixel-by-pixel, 784 timesteps
+2. Permuted Sequential MNIST (psMNIST) - shuffled pixel order
+3. CIFAR-10 with rate coding
+4. DVS-CIFAR10 (requires tonic library)
+
+These benchmarks are harder than SHD and benefit from deeper networks.
+"""
+
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader
+
+
+class SequentialMNIST(Dataset):
+ """
+ Sequential MNIST - feed pixels one at a time.
+
+ Each 28x28 image becomes a sequence of 784 timesteps,
+ each with a single pixel intensity converted to spike probability.
+
+ This is MUCH harder than standard MNIST because:
+ - Network must remember information across 784 timesteps
+ - Tests long-range temporal dependencies
+ - Shallow networks fail due to vanishing gradients
+
+ Args:
+ root: Data directory
+ train: Train or test split
+ permute: If True, use fixed random permutation (psMNIST)
+        spike_encoding: 'rate' (Poisson), 'latency', or 'direct' (raw intensity)
+ max_rate: Maximum firing rate for rate coding
+ seed: Random seed for permutation
+ """
+
+ def __init__(
+ self,
+ root: str = "./data",
+ train: bool = True,
+ permute: bool = False,
+ spike_encoding: str = "rate",
+ max_rate: float = 100.0,
+ n_repeat: int = 1, # Repeat each pixel n times for more spikes
+ seed: int = 42,
+ download: bool = True,
+ ):
+ try:
+ from torchvision import datasets, transforms
+ except ImportError:
+ raise ImportError("torchvision required: pip install torchvision")
+
+ self.train = train
+ self.permute = permute
+ self.spike_encoding = spike_encoding
+ self.max_rate = max_rate
+ self.n_repeat = n_repeat
+
+ # Load MNIST
+ self.mnist = datasets.MNIST(
+ root=root,
+ train=train,
+ download=download,
+ transform=transforms.ToTensor(),
+ )
+
+ # Create fixed permutation for psMNIST
+ if permute:
+ rng = np.random.RandomState(seed)
+ self.perm = torch.from_numpy(rng.permutation(784))
+ else:
+ self.perm = None
+
+ def __len__(self):
+ return len(self.mnist)
+
+ def __getitem__(self, idx):
+ img, label = self.mnist[idx]
+
+ # Flatten to (784,)
+ pixels = img.view(-1)
+
+ # Apply permutation
+ if self.perm is not None:
+ pixels = pixels[self.perm]
+
+ # Convert to spike sequence (T, 1) where T = 784 * n_repeat
+ T = 784 * self.n_repeat
+
+ if self.spike_encoding == "direct":
+ # Direct intensity: repeat each pixel n_repeat times
+ spikes = pixels.unsqueeze(1).repeat(1, self.n_repeat).view(T, 1)
+
+ elif self.spike_encoding == "rate":
+ # Rate coding: Poisson spikes based on intensity
+ probs = pixels * (self.max_rate / 1000.0) # Assuming 1ms bins
+ probs = probs.clamp(0, 1)
+ # Repeat and sample
+ probs_expanded = probs.unsqueeze(1).repeat(1, self.n_repeat).view(T, 1)
+ spikes = torch.bernoulli(probs_expanded)
+
+ elif self.spike_encoding == "latency":
+ # Latency coding: spike time proportional to intensity
+ # High intensity = early spike, low = late spike
+ spikes = torch.zeros(T, 1)
+ for i, p in enumerate(pixels):
+ if p > 0.1: # Threshold for spiking
+ # Spike time: higher intensity = earlier
+ spike_time = int((1 - p) * (self.n_repeat - 1))
+ t = i * self.n_repeat + spike_time
+ spikes[t, 0] = 1.0
+
+ return spikes, label
+
+
+class RateCodingCIFAR10(Dataset):
+ """
+ CIFAR-10 with rate coding for SNNs.
+
+ Converts 32x32x3 images to spike trains:
+ - Each pixel channel becomes a Poisson spike train
+ - Total input dimension: 32*32*3 = 3072
+ - Sequence length: T timesteps
+
+ Args:
+ root: Data directory
+ train: Train or test split
+ T: Number of timesteps
+ max_rate: Maximum firing rate (Hz)
+ flatten: If True, flatten spatial dimensions
+ """
+
+ def __init__(
+ self,
+ root: str = "./data",
+ train: bool = True,
+ T: int = 100,
+ max_rate: float = 200.0,
+ flatten: bool = True,
+ download: bool = True,
+ ):
+ try:
+ from torchvision import datasets, transforms
+ except ImportError:
+ raise ImportError("torchvision required: pip install torchvision")
+
+ self.T = T
+ self.max_rate = max_rate
+ self.flatten = flatten
+
+ # Normalize to [0, 1]
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ ])
+
+ self.cifar = datasets.CIFAR10(
+ root=root,
+ train=train,
+ download=download,
+ transform=transform,
+ )
+
+ def __len__(self):
+ return len(self.cifar)
+
+ def __getitem__(self, idx):
+ img, label = self.cifar[idx] # (3, 32, 32)
+
+ if self.flatten:
+ img = img.view(-1) # (3072,)
+
+ # Rate coding
+ prob_per_step = img * (self.max_rate / 1000.0) # Assuming 1ms steps
+ prob_per_step = prob_per_step.clamp(0, 1)
+
+ # Generate spikes for T timesteps
+ if self.flatten:
+ probs = prob_per_step.unsqueeze(0).expand(self.T, -1) # (T, 3072)
+ else:
+ probs = prob_per_step.unsqueeze(0).expand(self.T, -1, -1, -1) # (T, 3, 32, 32)
+
+ spikes = torch.bernoulli(probs)
+
+ return spikes, label
+
+
+class DVSCIFAR10(Dataset):
+ """
+ DVS-CIFAR10 dataset wrapper.
+
+ Requires the 'tonic' library for neuromorphic datasets:
+ pip install tonic
+
+ DVS-CIFAR10 is recorded from a Dynamic Vision Sensor watching
+ CIFAR-10 images on a monitor. It's a standard neuromorphic benchmark.
+
+    Args:
+        root: Data directory
+        train: Train or test split (split support may vary across tonic versions)
+        T: Number of time bins
+        dt_ms: Width of each time bin in milliseconds
+    """
+
+ def __init__(
+ self,
+ root: str = "./data",
+ train: bool = True,
+ T: int = 100,
+ dt_ms: float = 10.0,
+ download: bool = True,
+ ):
+ try:
+ import tonic
+ from tonic import transforms as tonic_transforms
+ except ImportError:
+ raise ImportError(
+ "tonic library required for DVS datasets: pip install tonic"
+ )
+
+ self.T = T
+
+ # Time binning transform
+ sensor_size = tonic.datasets.CIFAR10DVS.sensor_size
+ frame_transform = tonic_transforms.ToFrame(
+ sensor_size=sensor_size,
+ time_window=dt_ms * 1000, # Convert to microseconds
+ )
+
+ self.dataset = tonic.datasets.CIFAR10DVS(
+ save_to=root,
+ train=train,
+ transform=frame_transform,
+ )
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, idx):
+ frames, label = self.dataset[idx] # (T', 2, H, W) - 2 polarities
+
+        # Convert to float tensor (spatial flattening happens below)
+ frames = torch.from_numpy(frames).float()
+
+ # Adjust to target T
+ T_actual = frames.shape[0]
+ if T_actual > self.T:
+ # Subsample
+ indices = torch.linspace(0, T_actual - 1, self.T).long()
+ frames = frames[indices]
+ elif T_actual < self.T:
+ # Pad with zeros
+ pad = torch.zeros(self.T - T_actual, *frames.shape[1:])
+ frames = torch.cat([frames, pad], dim=0)
+
+ # Flatten: (T, 2, H, W) -> (T, 2*H*W)
+ frames = frames.view(self.T, -1)
+
+ return frames, label
+
+
+def get_benchmark_dataloader(
+ dataset_name: str,
+ batch_size: int = 64,
+ root: str = "./data",
+ **kwargs,
+) -> Tuple[DataLoader, DataLoader, dict]:
+ """
+ Get train and validation dataloaders for a benchmark dataset.
+
+ Args:
+ dataset_name: One of 'smnist', 'psmnist', 'cifar10', 'dvs_cifar10'
+ batch_size: Batch size
+ root: Data directory
+ **kwargs: Additional arguments passed to dataset
+
+ Returns:
+ train_loader, val_loader, info_dict
+ """
+
+ if dataset_name == "smnist":
+ train_ds = SequentialMNIST(root, train=True, permute=False, **kwargs)
+ val_ds = SequentialMNIST(root, train=False, permute=False, **kwargs)
+ info = {"T": 784 * kwargs.get("n_repeat", 1), "D": 1, "classes": 10,
+ "description": "Sequential MNIST - 784 timesteps, 1 pixel at a time"}
+
+ elif dataset_name == "psmnist":
+ train_ds = SequentialMNIST(root, train=True, permute=True, **kwargs)
+ val_ds = SequentialMNIST(root, train=False, permute=True, **kwargs)
+ info = {"T": 784 * kwargs.get("n_repeat", 1), "D": 1, "classes": 10,
+ "description": "Permuted Sequential MNIST - shuffled pixel order, tests long-range memory"}
+
+ elif dataset_name == "cifar10":
+ T = kwargs.pop("T", 100)
+ train_ds = RateCodingCIFAR10(root, train=True, T=T, **kwargs)
+ val_ds = RateCodingCIFAR10(root, train=False, T=T, **kwargs)
+ info = {"T": T, "D": 3072, "classes": 10,
+ "description": "CIFAR-10 with rate coding"}
+
+ elif dataset_name == "dvs_cifar10":
+ train_ds = DVSCIFAR10(root, train=True, **kwargs)
+ val_ds = DVSCIFAR10(root, train=False, **kwargs)
+ info = {"T": kwargs.get("T", 100), "D": 2 * 128 * 128, "classes": 10,
+ "description": "DVS-CIFAR10 neuromorphic dataset"}
+
+ else:
+ raise ValueError(f"Unknown dataset: {dataset_name}. "
+ f"Options: smnist, psmnist, cifar10, dvs_cifar10")
+
+ train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
+ val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4)
+
+ return train_loader, val_loader, info
+
+
+# Quick test
+if __name__ == "__main__":
+ print("Testing benchmark datasets...\n")
+
+ # Test sMNIST
+ print("1. Sequential MNIST")
+ try:
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "smnist", batch_size=32, n_repeat=1, spike_encoding="direct"
+ )
+ x, y = next(iter(train_loader))
+ print(f" Shape: {x.shape}, Labels: {y.shape}")
+ print(f" Info: {info}")
+ except Exception as e:
+ print(f" Error: {e}")
+
+ # Test psMNIST
+ print("\n2. Permuted Sequential MNIST")
+ try:
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "psmnist", batch_size=32, n_repeat=1, spike_encoding="direct"
+ )
+ x, y = next(iter(train_loader))
+ print(f" Shape: {x.shape}, Labels: {y.shape}")
+ print(f" Info: {info}")
+ except Exception as e:
+ print(f" Error: {e}")
+
+ # Test CIFAR-10
+ print("\n3. CIFAR-10 (rate coded)")
+ try:
+ train_loader, val_loader, info = get_benchmark_dataloader(
+ "cifar10", batch_size=32, T=50
+ )
+ x, y = next(iter(train_loader))
+ print(f" Shape: {x.shape}, Labels: {y.shape}")
+ print(f" Info: {info}")
+ except Exception as e:
+ print(f" Error: {e}")
+
+ print("\nDone!")
diff --git a/files/data_io/configs/__init__.py b/files/data_io/configs/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/files/data_io/configs/__init__.py
diff --git a/files/data_io/configs/dvs.yaml b/files/data_io/configs/dvs.yaml
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/files/data_io/configs/dvs.yaml
diff --git a/files/data_io/configs/shd.yaml b/files/data_io/configs/shd.yaml
new file mode 100644
index 0000000..c786934
--- /dev/null
+++ b/files/data_io/configs/shd.yaml
@@ -0,0 +1,11 @@
+dataset: SHD
+data_dir: /u/yurenh2/ml-projects/snn-training/files/data
+shuffle: true
+encoder:
+ type: poisson
+ max_rate: 50
+ dt_ms: 1.0
+  seed: 42          # optional; defaults to 42
+transforms:
+ normalize: true
+ spike_jitter: 0.01
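+# Optional loader keys read by get_dataloader (the code defaults are shown;
+# uncomment to override):
+# batch_size: 16
+# num_workers: 0
+# pin_memory: false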
diff --git a/files/data_io/configs/ssc.yaml b/files/data_io/configs/ssc.yaml
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/files/data_io/configs/ssc.yaml
diff --git a/files/data_io/dataset_loader.py b/files/data_io/dataset_loader.py
new file mode 100644
index 0000000..983b67b
--- /dev/null
+++ b/files/data_io/dataset_loader.py
@@ -0,0 +1,281 @@
+from typing import Optional
+import yaml
+import torch
+from torch.utils.data import Dataset, DataLoader
+import os
+import h5py
+import numpy as np
+from .encoders import poisson_encoder, latency_encoder, rank_order_encoder
+from .transforms import normalization, spike_augmentation
+# from .utils import file_utils  # currently unused; re-enable when needed
+
+
+
+# -----
+# Base synthetic dataset placeholder
+# -----
+class BaseDataset(Dataset):
+ """
+ Abstract base class for datasets.
+ Each subclass must implement __getitem__ and __len__.
+
+ This base implementation uses synthetic placeholders for quick smoke tests.
+ """
+
+ def __init__(self, data_dir, encoder, transforms=None):
+ self.data_dir = data_dir
+ self.encoder = encoder
+ self.transforms = transforms or []
+ self.samples = list(range(200)) # placeholder synthetic samples
+
+ def __getitem__(self, idx):
+ # synthetic static input -> encode to (T, D) via encoder
+ raw_data = torch.rand(128) # 128-dim intensity
+ label = torch.randint(0, 10, (1,)).item()
+ encoded = self.encoder.encode(raw_data.numpy()) if self.encoder is not None else raw_data
+ x = torch.as_tensor(encoded)
+ for t in self.transforms:
+ x = t(x)
+ return x, label
+
+ def __len__(self):
+ return len(self.samples)
+
+
+# -----
+# SHD dataset: true event data, read H5 per-sample; fixed global T; adaptive D
+# -----
+
+
+class SHDDataset(Dataset):
+ """
+ SHD Dataset Loader (.h5) with time-adaptive binning and fixed global T.
+
+ H5 structure (Zenke Lab convention):
+ - f["labels"][i] -> scalar label for sample i
+ - f["spikes"]["times"][i] -> 1D array of spike times (ms) for sample i
+ - f["spikes"]["units"][i] -> 1D array of channel ids for sample i
+
+ We:
+ (1) scan once to determine global T = ceil(max_time / dt_ms)
+ (2) decide D from max unit id (fallback to default_D=700)
+ (3) in __getitem__, open H5, read ragged arrays for that sample, and bin to (T, D)
+ """
+
+ def __init__(
+ self,
+ data_dir: str,
+ encoder=None, # ignored for SHD (already spiking events)
+ transforms=None,
+ split: str = "train",
+ dt_ms: float = 1.0,
+ seed: Optional[int] = None,
+ default_D: int = 700
+ ):
+ super().__init__()
+ self.data_dir = data_dir
+ self.transforms = transforms or []
+ self.dt_ms = float(dt_ms)
+ self.seed = 42 if seed is None else int(seed)
+ self.encoder = None # IMPORTANT: do not apply intensity encoders to event data
+ self.default_D = int(default_D)
+
+ fname = f"shd_{split}.h5"
+ self.path = os.path.join(self.data_dir, fname)
+ if not os.path.exists(self.path):
+ raise FileNotFoundError(f"SHD file not found: {self.path}")
+
+ with h5py.File(self.path, "r") as f:
+ # labels is dense array
+ self.labels = np.array(f["labels"], dtype=np.int64)
+ self.N = int(self.labels.shape[0])
+
+ # ragged datasets for events
+ times_ds = f["spikes"]["times"]
+ units_ds = f["spikes"]["units"]
+
+ # scan once to compute global T and adaptive D
+ t_max_global = 0.0
+ max_unit = -1
+ for i in range(self.N):
+ ti = times_ds[i]
+ ui = units_ds[i]
+                if ti.size > 0:
+                    last_t = float(ti.max())  # do not assume times are sorted
+                    if last_t > t_max_global:
+                        t_max_global = last_t
+ if ui.size > 0:
+ uimax = int(ui.max())
+ if uimax > max_unit:
+ max_unit = uimax
+
+ # decide D
+ if max_unit >= 0:
+ self.D = max(max_unit + 1, self.default_D)
+ else:
+ self.D = self.default_D
+
+ # decide T from global max time
+ self.T = int(np.ceil(t_max_global / self.dt_ms)) if t_max_global > 0 else 1
+
+ # rng in case transforms need it
+ self._rng = np.random.default_rng(self.seed)
+
+ def __len__(self):
+ return self.N
+
+ def __getitem__(self, idx: int):
+ # open file per-sample for worker safety
+ with h5py.File(self.path, "r") as f:
+ ti = f["spikes"]["times"][idx][:]
+ ui = f["spikes"]["units"][idx][:]
+ y = int(f["labels"][idx])
+
+ # bin events to (T, D)
+ spikes = np.zeros((self.T, self.D), dtype=np.float32)
+ if ti.size > 0:
+ bins = (ti / self.dt_ms).astype(np.int64)
+ bins = np.clip(bins, 0, self.T - 1)
+ ui = np.clip(ui.astype(np.int64), 0, self.D - 1)
+            spikes[bins, ui] = 1.0  # presence coding; for counts use np.add.at(spikes, (bins, ui), 1.0)
+
+ x = torch.from_numpy(spikes)
+
+ # apply transforms (on torch tensor)
+ for tr in self.transforms:
+ x = tr(x)
+
+ return x, y
+
+
+# -----
+# SSC / DVS placeholders (still synthetic; implement real readers later)
+# -----
+class SSCDataset(BaseDataset):
+ """Placeholder SSC dataset (synthetic)."""
+ pass
+
+
+class DVSDataset(BaseDataset):
+ """Placeholder DVS dataset (synthetic)."""
+ pass
+
+
+# -----
+# Helpers: encoders / transforms / cfg path resolution
+# -----
+def build_encoder(cfg):
+ """
+ Build encoder from config dict.
+
+ Expected schema:
+ encoder:
+ type: poisson | latency | rank_order
+ # Poisson-only optional fields:
+ max_rate: 50
+ T: 64
+ dt_ms: 1.0
+ seed: 123
+ """
+ etype = cfg["type"].lower()
+ if etype == "poisson":
+ return poisson_encoder.PoissonEncoder(
+ max_rate=cfg.get("max_rate", cfg.get("rate", 20)),
+ T=cfg.get("T", 50),
+ dt_ms=cfg.get("dt_ms", 1.0),
+ seed=cfg.get("seed", None),
+ )
+ elif etype == "latency":
+ return latency_encoder.LatencyEncoder()
+ elif etype == "rank_order":
+ return rank_order_encoder.RankOrderEncoder()
+ else:
+ raise ValueError(f"Unknown encoder type: {etype}")
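+
+
+# Usage sketch (illustrative values; schema mirrors configs/shd.yaml):
+#   enc = build_encoder({"type": "poisson", "max_rate": 50, "dt_ms": 1.0, "seed": 42})
+#   spikes = enc.encode(np.random.rand(700))   # -> (T, 700) spike tensor, T defaults to 50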
+
+
+def build_transforms(cfg):
+ tlist = []
+ if cfg.get("normalize", False):
+ tlist.append(normalization.Normalize())
+ if cfg.get("spike_jitter", None) is not None:
+ tlist.append(spike_augmentation.SpikeJitter(std=cfg["spike_jitter"]))
+ return tlist
+
+
+def _resolve_cfg_path(cfg_path: str) -> str:
+ """
+ Resolve cfg_path against:
+ 1) as-is (absolute or CWD-relative)
+ 2) relative to this package directory
+ 3) <pkg_dir>/configs/<basename>
+ """
+ if os.path.isabs(cfg_path) and os.path.exists(cfg_path):
+ return cfg_path
+ if os.path.exists(cfg_path):
+ return cfg_path
+ pkg_dir = os.path.dirname(__file__)
+ cand2 = os.path.normpath(os.path.join(pkg_dir, cfg_path))
+ if os.path.exists(cand2):
+ return cand2
+ cand3 = os.path.join(pkg_dir, "configs", os.path.basename(cfg_path))
+ if os.path.exists(cand3):
+ return cand3
+ raise FileNotFoundError(f"Config file not found. Tried: {cfg_path}, {cand2}, {cand3}")
+
+
+# -----
+# Entry: get_dataloader
+# -----
+def get_dataloader(cfg_path):
+ """
+ Create train/val DataLoader from YAML config.
+    Handles SHD as a true event dataset (encoder=None); other datasets are synthetic placeholders.
+ """
+ cfg_path_resolved = _resolve_cfg_path(cfg_path)
+ with open(cfg_path_resolved, "r") as f:
+ cfg = yaml.safe_load(f)
+
+ dataset_name = cfg["dataset"].lower()
+ data_dir = cfg["data_dir"]
+ transforms = build_transforms(cfg.get("transforms", {}))
+
+ if dataset_name == "shd":
+ # event dataset: do NOT use intensity encoders here
+ dt_ms = cfg.get("encoder", {}).get("dt_ms", 1.0)
+ seed = cfg.get("encoder", {}).get("seed", 42)
+ ds_train = SHDDataset(
+ data_dir, encoder=None, transforms=transforms, split="train",
+ dt_ms=dt_ms, seed=seed
+ )
+ ds_val = SHDDataset(
+ data_dir, encoder=None, transforms=transforms, split="test",
+ dt_ms=dt_ms, seed=seed
+ )
+ elif dataset_name == "ssc":
+ # placeholder path; later implement true SSC reader
+ encoder = build_encoder(cfg["encoder"])
+ ds_train = SSCDataset(data_dir, encoder, transforms)
+ ds_val = SSCDataset(data_dir, encoder, transforms)
+ elif dataset_name == "dvs":
+ encoder = build_encoder(cfg["encoder"])
+ ds_train = DVSDataset(data_dir, encoder, transforms)
+ ds_val = DVSDataset(data_dir, encoder, transforms)
+ else:
+ raise ValueError(f"Unknown dataset: {dataset_name}")
+
+ train_loader = DataLoader(
+ ds_train,
+ batch_size=cfg.get("batch_size", 16),
+ shuffle=cfg.get("shuffle", True),
+ num_workers=cfg.get("num_workers", 0),
+ pin_memory=cfg.get("pin_memory", False),
+ )
+ val_loader = DataLoader(
+ ds_val,
+ batch_size=cfg.get("batch_size", 16),
+ shuffle=False,
+ num_workers=cfg.get("num_workers", 0),
+ pin_memory=cfg.get("pin_memory", False),
+ )
+    return train_loader, val_loader
\ No newline at end of file
diff --git a/files/data_io/encoders/__init__.py b/files/data_io/encoders/__init__.py
new file mode 100644
index 0000000..b2c5dc3
--- /dev/null
+++ b/files/data_io/encoders/__init__.py
@@ -0,0 +1 @@
+"""Encoder submodule: provides various spike encoding strategies."""
diff --git a/files/data_io/encoders/base_encoder.py b/files/data_io/encoders/base_encoder.py
new file mode 100644
index 0000000..e40451f
--- /dev/null
+++ b/files/data_io/encoders/base_encoder.py
@@ -0,0 +1,10 @@
+import numpy as np
+import torch
+
+class BaseEncoder:
+ """Abstract base class for all encoders."""
+ def encode(self, data: np.ndarray) -> torch.Tensor:
+ """
+ Convert static data (e.g., image, waveform) into spike tensor (T, input_dim).
+ """
+ raise NotImplementedError
diff --git a/files/data_io/encoders/latency_encoder.py b/files/data_io/encoders/latency_encoder.py
new file mode 100644
index 0000000..a7804ae
--- /dev/null
+++ b/files/data_io/encoders/latency_encoder.py
@@ -0,0 +1,13 @@
+import numpy as np
+import torch
+from .base_encoder import BaseEncoder
+
+class LatencyEncoder(BaseEncoder):
+    """Time-to-first-spike coding: higher intensity -> earlier spike."""
+    def __init__(self, T: int = 50, threshold: float = 0.0):
+        self.T = int(T)
+        self.threshold = float(threshold)
+
+    def encode(self, data: np.ndarray) -> torch.Tensor:
+        # Map intensity x in [0, 1] to spike time t = rint((1 - x) * (T - 1));
+        # features at or below `threshold` stay silent.
+        x = np.clip(np.asarray(data, dtype=np.float32).reshape(-1), 0.0, 1.0)
+        spikes = torch.zeros(self.T, x.size)
+        active = np.nonzero(x > self.threshold)[0]
+        t_idx = np.rint((1.0 - x[active]) * (self.T - 1)).astype(np.int64)
+        spikes[torch.from_numpy(t_idx), torch.from_numpy(active)] = 1.0
+        return spikes
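+
+
+if __name__ == "__main__":
+    # Quick sketch (illustrative values): intensity 1.0 fires at t=0,
+    # 0.6 mid-sequence, 0.1 near the end of a 10-step window.
+    enc = LatencyEncoder(T=10)
+    print(enc.encode(np.array([1.0, 0.6, 0.1])).argmax(dim=0))  # tensor([0, 4, 8])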
diff --git a/files/data_io/encoders/poisson_encoder.py b/files/data_io/encoders/poisson_encoder.py
new file mode 100644
index 0000000..0b404f7
--- /dev/null
+++ b/files/data_io/encoders/poisson_encoder.py
@@ -0,0 +1,91 @@
+from typing import Optional, Union
+import numpy as np
+import torch
+from .base_encoder import BaseEncoder
+
+class PoissonEncoder(BaseEncoder):
+ r"""
+ PoissonEncoder: convert static intensities to spike trains via per-time-step Bernoulli sampling.
+
+ Given a static input vector x \in [0,1]^{D}, we produce a spike tensor S \in {0,1}^{T \times D}
+ by sampling at each time step t and dimension d:
+ S[t, d] ~ Bernoulli( p[d] ), where p[d] = clip( x[d] * max_rate * (dt_ms / 1000), 0, 1 ).
+
+ Parameters
+ ----------
+ max_rate : float
+ Maximum firing rate under unit intensity (Hz). Typical: 20~200. Default: 20.
+ T : int
+ Number of discrete time steps in the encoded spike train. Default: 50.
+ dt_ms : float
+ Time resolution per step in milliseconds. Default: 1.0 ms.
+ Effective per-step probability uses factor (dt_ms/1000) to convert Hz to per-step probability.
+ seed : int or None
+ Optional RNG seed for reproducibility. If None, uses global RNG state.
+
+ Notes
+ -----
+ - Input `data` is expected to be a NumPy 1D array (shape [D]) or 2D array ([B, D]).
+ If 1D, we return S with shape (T, D).
+ If 2D (batched), we broadcast probabilities across batch and return (T, B, D).
+ - Intensities outside [0,1] will be clipped to [0,1].
+ - Device of returned tensor follows torch default device (CPU) unless you move it later.
+ """
+
+ def __init__(self, max_rate: float = 20.0, T: int = 50, dt_ms: float = 1.0, seed: Optional[int] = None):
+ super().__init__()
+ self.max_rate = float(max_rate)
+ self.T = int(T)
+ self.dt_ms = float(dt_ms)
+ self.seed = seed
+ # local generator for reproducibility if seed is provided
+ self._g = torch.Generator()
+ if seed is not None:
+ self._g.manual_seed(int(seed))
+
+ def _ensure_numpy(self, data: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
+ if isinstance(data, torch.Tensor):
+ data = data.detach().cpu().numpy()
+ return np.asarray(data)
+
+ def encode(self, data: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ """
+ Convert input intensities to Poisson spike trains.
+
+ Parameters
+ ----------
+ data : np.ndarray or torch.Tensor
+ Shape [D] or [B, D]. Values are assumed in [0,1] (we clip if not).
+
+ Returns
+ -------
+ spikes : torch.Tensor
+ Shape (T, D) or (T, B, D), dtype=torch.float32 with values {0.,1.}.
+ """
+ x = self._ensure_numpy(data)
+ # clip to [0,1]
+ x = np.clip(x, 0.0, 1.0)
+
+ # compute per-step probability
+ p_step = float(self.max_rate) * (self.dt_ms / 1000.0) # probability per step for unit intensity
+ # probability tensor (broadcast-friendly)
+ probs = x * p_step
+ probs = np.clip(probs, 0.0, 1.0)
+
+ # create Bernoulli samples for T steps
+ # If input is 1D: probs.shape = (D,) -> output (T, D)
+ # If input is 2D: probs.shape = (B, D) -> output (T, B, D)
+ if probs.ndim == 1:
+ D = probs.shape[0]
+ probs_t = np.broadcast_to(probs, (self.T, D)) # (T, D)
+ probs_t = torch.from_numpy(probs_t.astype(np.float32))
+ spikes = torch.bernoulli(probs_t, generator=self._g)
+ return spikes
+ elif probs.ndim == 2:
+ B, D = probs.shape
+ probs_t = np.broadcast_to(probs, (self.T, B, D)) # (T, B, D)
+ probs_t = torch.from_numpy(probs_t.astype(np.float32))
+ spikes = torch.bernoulli(probs_t, generator=self._g)
+ return spikes
+ else:
+ raise ValueError(f"PoissonEncoder expects data with ndim 1 or 2, got shape {probs.shape}")
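+
+
+if __name__ == "__main__":
+    # Minimal usage sketch: a random 128-dim intensity vector encoded at
+    # 100 Hz max rate with 1 ms bins for 50 steps.
+    enc = PoissonEncoder(max_rate=100.0, T=50, dt_ms=1.0, seed=0)
+    s = enc.encode(np.random.rand(128))
+    print(s.shape)          # torch.Size([50, 128])
+    print(s.mean().item())  # ~= mean intensity * max_rate * dt_ms / 1000 ~= 0.05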
diff --git a/files/data_io/encoders/rank_order_encoder.py b/files/data_io/encoders/rank_order_encoder.py
new file mode 100644
index 0000000..9102e90
--- /dev/null
+++ b/files/data_io/encoders/rank_order_encoder.py
@@ -0,0 +1,10 @@
+import numpy as np
+import torch
+from .base_encoder import BaseEncoder
+
+class RankOrderEncoder(BaseEncoder):
+    """Rank-order coding: each feature spikes exactly once, strongest first."""
+    def encode(self, data: np.ndarray) -> torch.Tensor:
+        # The d-th strongest feature fires at time step d, so T == D and
+        # every time step contains exactly one spike.
+        x = np.asarray(data, dtype=np.float32).reshape(-1)
+        order = np.argsort(-x)  # feature indices, descending intensity
+        spikes = torch.zeros(x.size, x.size)
+        spikes[torch.arange(x.size), torch.from_numpy(order)] = 1.0
+        return spikes
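+
+
+if __name__ == "__main__":
+    # Sketch: for intensities [0.2, 0.9, 0.5], feature 1 fires first,
+    # then feature 2, then feature 0.
+    print(RankOrderEncoder().encode(np.array([0.2, 0.9, 0.5])))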
diff --git a/files/data_io/transforms/__init__.py b/files/data_io/transforms/__init__.py
new file mode 100644
index 0000000..f6aa496
--- /dev/null
+++ b/files/data_io/transforms/__init__.py
@@ -0,0 +1 @@
+"""Transforms for spike data (normalization, augmentation, etc.)."""
diff --git a/files/data_io/transforms/normalization.py b/files/data_io/transforms/normalization.py
new file mode 100644
index 0000000..c86b8c7
--- /dev/null
+++ b/files/data_io/transforms/normalization.py
@@ -0,0 +1,52 @@
+import torch
+
+class Normalize:
+ """Normalize spike input (z-score or min-max)."""
+ def __init__(self, mode: str = "zscore", eps: float = 1e-6):
+ """
+ Parameters
+ ----------
+ mode : str
+ One of {"zscore", "minmax"}.
+ - "zscore": per-sample, per-channel standardization across time.
+ - "minmax": per-sample, per-channel min-max scaling across time.
+ eps : float
+ Small constant to avoid division by zero.
+ """
+ mode_l = str(mode).lower()
+ if mode_l not in {"zscore", "minmax"}:
+ raise ValueError(f"Normalize mode must be 'zscore' or 'minmax', got {mode}")
+ self.mode = mode_l
+ self.eps = float(eps)
+
+ def __call__(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Apply normalization.
+ Accepts tensors of shape (T, D) or (B, T, D).
+ Normalization is computed per-sample and per-channel over the time axis T.
+ """
+ if not isinstance(x, torch.Tensor):
+ raise TypeError("Normalize expects a torch.Tensor")
+ if x.ndim == 2:
+ # (T, D) -> time dim = 0
+ time_dim = 0
+ keep_dims = True
+ elif x.ndim == 3:
+ # (B, T, D) -> time dim = 1
+ time_dim = 1
+ keep_dims = True
+ else:
+ raise ValueError(f"Expected x.ndim in {{2,3}}, got {x.ndim}")
+
+ if self.mode == "zscore":
+ mean_t = x.mean(dim=time_dim, keepdim=keep_dims)
+ # population std (unbiased=False) to avoid NaNs for small T
+ std_t = x.std(dim=time_dim, keepdim=keep_dims, unbiased=False)
+ x_norm = (x - mean_t) / (std_t + self.eps)
+ return x_norm
+ else: # "minmax"
+ min_t = x.amin(dim=time_dim, keepdim=keep_dims)
+ max_t = x.amax(dim=time_dim, keepdim=keep_dims)
+ denom = (max_t - min_t).clamp_min(self.eps)
+ x_scaled = (x - min_t) / denom
+ return x_scaled
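+
+
+if __name__ == "__main__":
+    # Minimal sketch: z-score a synthetic (T, D) tensor over the time axis.
+    norm = Normalize(mode="zscore")
+    y = norm(torch.rand(100, 8))
+    print(y.mean(dim=0))                 # ~0 per channel
+    print(y.std(dim=0, unbiased=False))  # ~1 per channel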
diff --git a/files/data_io/transforms/spike_augmentation.py b/files/data_io/transforms/spike_augmentation.py
new file mode 100644
index 0000000..9b7b687
--- /dev/null
+++ b/files/data_io/transforms/spike_augmentation.py
@@ -0,0 +1,10 @@
+import torch
+
+class SpikeJitter:
+    """Temporal jitter: randomly shift each spike along the time axis."""
+    def __init__(self, std=0.01):
+        # assumption: `std` is a fraction of the sequence length T
+        self.std = float(std)
+
+    def __call__(self, spikes: torch.Tensor) -> torch.Tensor:
+        if spikes.ndim != 2:
+            return spikes  # only (T, D) tensors are jittered here
+        T = spikes.shape[0]
+        idx = spikes.nonzero(as_tuple=False)       # (n_spikes, 2) rows [t, d]
+        shift = torch.randn(idx.shape[0]) * self.std * T
+        new_t = (idx[:, 0].float() + shift).round().long().clamp(0, T - 1)
+        out = torch.zeros_like(spikes)
+        out[new_t, idx[:, 1]] = 1.0
+        return out
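+
+
+if __name__ == "__main__":
+    # Sketch: jitter a random (T, D) = (100, 5) binary spike tensor.
+    s = torch.bernoulli(torch.full((100, 5), 0.1))
+    j = SpikeJitter(std=0.02)   # shifts of ~2 steps on a 100-step sequence
+    print(s.sum().item(), j(s).sum().item())  # counts may drop if spikes collide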
diff --git a/files/data_io/utils/__init__.py b/files/data_io/utils/__init__.py
new file mode 100644
index 0000000..ee3ab2f
--- /dev/null
+++ b/files/data_io/utils/__init__.py
@@ -0,0 +1 @@
+"""Utility functions for file management, spike tools, and visualization."""
diff --git a/files/data_io/utils/file_utils.py b/files/data_io/utils/file_utils.py
new file mode 100644
index 0000000..0a1e846
--- /dev/null
+++ b/files/data_io/utils/file_utils.py
@@ -0,0 +1,15 @@
+import os
+
+def ensure_dir(path: str):
+    """Ensure that a directory exists (no-op if it already does)."""
+    os.makedirs(path, exist_ok=True)
+
+def list_files(root: str, suffix: str):
+ """Recursively list files ending with suffix."""
+ matches = []
+ for dirpath, _, filenames in os.walk(root):
+ for f in filenames:
+ if f.endswith(suffix):
+ matches.append(os.path.join(dirpath, f))
+ return matches
diff --git a/files/data_io/utils/spike_tools.py b/files/data_io/utils/spike_tools.py
new file mode 100644
index 0000000..968ee72
--- /dev/null
+++ b/files/data_io/utils/spike_tools.py
@@ -0,0 +1,10 @@
+import torch
+import numpy as np
+
+def to_raster(spikes: torch.Tensor) -> np.ndarray:
+    """Average a spike tensor (T, B, N) over the batch dim to a (T, N) activity map."""
+    return spikes.detach().cpu().numpy().mean(axis=1)
+
+def firing_rate(spikes: torch.Tensor, dt=1.0):
+    """Mean firing rate per neuron: spike count over the time axis / (T * dt)."""
+    return spikes.sum(dim=0) / (spikes.shape[0] * dt)
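+
+
+if __name__ == "__main__":
+    # Sketch: 100 steps, batch of 4, 16 neurons, ~20% spike probability.
+    s = torch.bernoulli(torch.full((100, 4, 16), 0.2))
+    print(to_raster(s).shape)            # (100, 16)
+    print(firing_rate(s, dt=1.0).shape)  # torch.Size([4, 16]): per-neuron rates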
diff --git a/files/data_io/utils/visualize.py b/files/data_io/utils/visualize.py
new file mode 100644
index 0000000..6f0de95
--- /dev/null
+++ b/files/data_io/utils/visualize.py
@@ -0,0 +1,19 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+def plot_raster(spikes: torch.Tensor, title=None):
+ """
+ Plot raster diagram of spike activity (T,B,N) or (T,N).
+ """
+ s = spikes.detach().cpu()
+ if s.ndim == 3:
+ s = s[:, 0, :] # take first batch
+ t, n = s.shape
+ for i in range(n):
+        times = torch.nonzero(s[:, i], as_tuple=False).view(-1).numpy()  # view(-1) stays 1-D even for a single spike
+ plt.scatter(times, i * np.ones_like(times), s=2, c='black')
+ plt.xlabel("Time step")
+ plt.ylabel("Neuron index")
+ if title:
+ plt.title(title)
+ plt.show()