| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:49:05 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-13 23:49:05 -0600 |
| commit | cd99d6b874d9d09b3bb87b8485cc787885af71f1 (patch) | |
| tree | 59a233959932ca0e4f12f196275e07fcf443b33f /files/data_io/dataset_loader.py | |
init commit
Diffstat (limited to 'files/data_io/dataset_loader.py')
| -rw-r--r-- | files/data_io/dataset_loader.py | 281 |
1 files changed, 281 insertions, 0 deletions
```diff
diff --git a/files/data_io/dataset_loader.py b/files/data_io/dataset_loader.py
new file mode 100644
index 0000000..983b67b
--- /dev/null
+++ b/files/data_io/dataset_loader.py
@@ -0,0 +1,281 @@
+from typing import Optional
+import yaml
+import torch
+from torch.utils.data import Dataset, DataLoader
+import os
+import h5py
+import numpy as np
+from .encoders import poisson_encoder, latency_encoder, rank_order_encoder
+from .transforms import normalization, spike_augmentation
+# from .utils import file_utils  # currently unused; keep commented out for now
+
+
+# -----
+# Base synthetic dataset placeholder
+# -----
+class BaseDataset(Dataset):
+    """
+    Abstract base class for datasets.
+    Each subclass must implement __getitem__ and __len__.
+
+    This base implementation uses synthetic placeholders for quick smoke tests.
+    """
+
+    def __init__(self, data_dir, encoder, transforms=None):
+        self.data_dir = data_dir
+        self.encoder = encoder
+        self.transforms = transforms or []
+        self.samples = list(range(200))  # placeholder synthetic samples
+
+    def __getitem__(self, idx):
+        # synthetic static input -> encode to (T, D) via encoder
+        raw_data = torch.rand(128)  # 128-dim intensity
+        label = torch.randint(0, 10, (1,)).item()
+        encoded = self.encoder.encode(raw_data.numpy()) if self.encoder is not None else raw_data
+        x = torch.as_tensor(encoded)
+        for t in self.transforms:
+            x = t(x)
+        return x, label
+
+    def __len__(self):
+        return len(self.samples)
+
+
+# -----
+# SHD dataset: true event data, read H5 per-sample; fixed global T; adaptive D
+# -----
+class SHDDataset(Dataset):
+    """
+    SHD Dataset Loader (.h5) with time-adaptive binning and fixed global T.
+
+    H5 structure (Zenke Lab convention):
+      - f["labels"][i]           -> scalar label for sample i
+      - f["spikes"]["times"][i]  -> 1D array of spike times (ms) for sample i
+      - f["spikes"]["units"][i]  -> 1D array of channel ids for sample i
+
+    We:
+      (1) scan once to determine global T = ceil(max_time / dt_ms)
+      (2) decide D from max unit id (fallback to default_D=700)
+      (3) in __getitem__, open H5, read ragged arrays for that sample, and bin to (T, D)
+    """
+
+    def __init__(
+        self,
+        data_dir: str,
+        encoder=None,  # ignored for SHD (already spiking events)
+        transforms=None,
+        split: str = "train",
+        dt_ms: float = 1.0,
+        seed: Optional[int] = None,
+        default_D: int = 700,
+    ):
+        super().__init__()
+        self.data_dir = data_dir
+        self.transforms = transforms or []
+        self.dt_ms = float(dt_ms)
+        self.seed = 42 if seed is None else int(seed)
+        self.encoder = None  # IMPORTANT: do not apply intensity encoders to event data
+        self.default_D = int(default_D)
+
+        fname = f"shd_{split}.h5"
+        self.path = os.path.join(self.data_dir, fname)
+        if not os.path.exists(self.path):
+            raise FileNotFoundError(f"SHD file not found: {self.path}")
+
+        with h5py.File(self.path, "r") as f:
+            # labels is a dense array
+            self.labels = np.array(f["labels"], dtype=np.int64)
+            self.N = int(self.labels.shape[0])
+
+            # ragged datasets for events
+            times_ds = f["spikes"]["times"]
+            units_ds = f["spikes"]["units"]
+
+            # scan once to compute global T and adaptive D
+            t_max_global = 0.0
+            max_unit = -1
+            for i in range(self.N):
+                ti = times_ds[i]
+                ui = units_ds[i]
+                if ti.size > 0:
+                    last_t = float(ti[-1])  # ms
+                    if last_t > t_max_global:
+                        t_max_global = last_t
+                if ui.size > 0:
+                    uimax = int(ui.max())
+                    if uimax > max_unit:
+                        max_unit = uimax
+
+        # decide D
+        if max_unit >= 0:
+            self.D = max(max_unit + 1, self.default_D)
+        else:
+            self.D = self.default_D
+
+        # decide T from global max time
+        self.T = int(np.ceil(t_max_global / self.dt_ms)) if t_max_global > 0 else 1
+
+        # rng in case transforms need it
+        self._rng = np.random.default_rng(self.seed)
+
+    def __len__(self):
+        return self.N
+
+    def __getitem__(self, idx: int):
+        # open file per-sample for worker safety
+        with h5py.File(self.path, "r") as f:
+            ti = f["spikes"]["times"][idx][:]
+            ui = f["spikes"]["units"][idx][:]
+            y = int(f["labels"][idx])
+
+        # bin events to (T, D)
+        spikes = np.zeros((self.T, self.D), dtype=np.float32)
+        if ti.size > 0:
+            bins = (ti / self.dt_ms).astype(np.int64)
+            bins = np.clip(bins, 0, self.T - 1)
+            ui = np.clip(ui.astype(np.int64), 0, self.D - 1)
+            spikes[bins, ui] = 1.0  # presence; if you prefer counts: += 1 then clip to 1
+
+        x = torch.from_numpy(spikes)
+
+        # apply transforms (on torch tensor)
+        for tr in self.transforms:
+            x = tr(x)
+
+        return x, y
+
+
+# -----
+# SSC / DVS placeholders (still synthetic; implement real readers later)
+# -----
+class SSCDataset(BaseDataset):
+    """Placeholder SSC dataset (synthetic)."""
+    pass
+
+
+class DVSDataset(BaseDataset):
+    """Placeholder DVS dataset (synthetic)."""
+    pass
+
+
+# -----
+# Helpers: encoders / transforms / cfg path resolution
+# -----
+def build_encoder(cfg):
+    """
+    Build encoder from config dict.
+
+    Expected schema:
+      encoder:
+        type: poisson | latency | rank_order
+        # Poisson-only optional fields:
+        max_rate: 50
+        T: 64
+        dt_ms: 1.0
+        seed: 123
+    """
+    etype = cfg["type"].lower()
+    if etype == "poisson":
+        return poisson_encoder.PoissonEncoder(
+            max_rate=cfg.get("max_rate", cfg.get("rate", 20)),
+            T=cfg.get("T", 50),
+            dt_ms=cfg.get("dt_ms", 1.0),
+            seed=cfg.get("seed", None),
+        )
+    elif etype == "latency":
+        return latency_encoder.LatencyEncoder()
+    elif etype == "rank_order":
+        return rank_order_encoder.RankOrderEncoder()
+    else:
+        raise ValueError(f"Unknown encoder type: {etype}")
+
+
+def build_transforms(cfg):
+    tlist = []
+    if cfg.get("normalize", False):
+        tlist.append(normalization.Normalize())
+    if cfg.get("spike_jitter", None) is not None:
+        tlist.append(spike_augmentation.SpikeJitter(std=cfg["spike_jitter"]))
+    return tlist
+
+
+def _resolve_cfg_path(cfg_path: str) -> str:
+    """
+    Resolve cfg_path against:
+      1) as-is (absolute or CWD-relative)
+      2) relative to this package directory
+      3) <pkg_dir>/configs/<basename>
+    """
+    if os.path.isabs(cfg_path) and os.path.exists(cfg_path):
+        return cfg_path
+    if os.path.exists(cfg_path):
+        return cfg_path
+    pkg_dir = os.path.dirname(__file__)
+    cand2 = os.path.normpath(os.path.join(pkg_dir, cfg_path))
+    if os.path.exists(cand2):
+        return cand2
+    cand3 = os.path.join(pkg_dir, "configs", os.path.basename(cfg_path))
+    if os.path.exists(cand3):
+        return cand3
+    raise FileNotFoundError(f"Config file not found. Tried: {cfg_path}, {cand2}, {cand3}")
+
+
+# -----
+# Entry: get_dataloader
+# -----
+def get_dataloader(cfg_path):
+    """
+    Create train/val DataLoader from YAML config.
+    Handles SHD as a true event dataset (encoder=None), others as synthetic placeholders.
+    """
+    cfg_path_resolved = _resolve_cfg_path(cfg_path)
+    with open(cfg_path_resolved, "r") as f:
+        cfg = yaml.safe_load(f)
+
+    dataset_name = cfg["dataset"].lower()
+    data_dir = cfg["data_dir"]
+    transforms = build_transforms(cfg.get("transforms", {}))
+
+    if dataset_name == "shd":
+        # event dataset: do NOT use intensity encoders here
+        dt_ms = cfg.get("encoder", {}).get("dt_ms", 1.0)
+        seed = cfg.get("encoder", {}).get("seed", 42)
+        ds_train = SHDDataset(
+            data_dir, encoder=None, transforms=transforms, split="train",
+            dt_ms=dt_ms, seed=seed,
+        )
+        ds_val = SHDDataset(
+            data_dir, encoder=None, transforms=transforms, split="test",
+            dt_ms=dt_ms, seed=seed,
+        )
+    elif dataset_name == "ssc":
+        # placeholder path; later implement a true SSC reader
+        encoder = build_encoder(cfg["encoder"])
+        ds_train = SSCDataset(data_dir, encoder, transforms)
+        ds_val = SSCDataset(data_dir, encoder, transforms)
+    elif dataset_name == "dvs":
+        encoder = build_encoder(cfg["encoder"])
+        ds_train = DVSDataset(data_dir, encoder, transforms)
+        ds_val = DVSDataset(data_dir, encoder, transforms)
+    else:
+        raise ValueError(f"Unknown dataset: {dataset_name}")
+
+    train_loader = DataLoader(
+        ds_train,
+        batch_size=cfg.get("batch_size", 16),
+        shuffle=cfg.get("shuffle", True),
+        num_workers=cfg.get("num_workers", 0),
+        pin_memory=cfg.get("pin_memory", False),
+    )
+    val_loader = DataLoader(
+        ds_val,
+        batch_size=cfg.get("batch_size", 16),
+        shuffle=False,
+        num_workers=cfg.get("num_workers", 0),
+        pin_memory=cfg.get("pin_memory", False),
+    )
+    return train_loader, val_loader
\ No newline at end of file
```
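
The load-bearing step in `SHDDataset.__getitem__` is binning ragged event arrays into a dense `(T, D)` grid. A minimal standalone sketch of that same logic with toy event data (the helper name `bin_events` and the toy values are illustrative, not part of this commit):

```python
import numpy as np
import torch

def bin_events(times_ms, units, T, D, dt_ms=1.0):
    """Bin ragged (times, units) event arrays into a dense (T, D) spike grid."""
    spikes = np.zeros((T, D), dtype=np.float32)
    if times_ms.size > 0:
        bins = np.clip((times_ms / dt_ms).astype(np.int64), 0, T - 1)
        chans = np.clip(units.astype(np.int64), 0, D - 1)
        spikes[bins, chans] = 1.0  # presence, not counts, matching the loader
    return torch.from_numpy(spikes)

# toy sample: events at 0.4 ms, 1.7 ms, 2.2 ms on channels 5, 5, 9
x = bin_events(np.array([0.4, 1.7, 2.2]), np.array([5, 5, 9]), T=4, D=16)
print(x.shape, int(x.sum()))  # torch.Size([4, 16]) 3
```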
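`build_encoder` constructs `poisson_encoder.PoissonEncoder(max_rate, T, dt_ms, seed)`, and `BaseDataset` calls `encoder.encode(...)` on a `(D,)` intensity array, so the loader implicitly assumes an encoder that returns a `(T, D)` spike array. The `.encoders` module is not part of this commit; the sketch below is only a plausible stand-in for that assumed interface, not the project's actual implementation:

```python
import numpy as np

class PoissonEncoder:
    """Hypothetical sketch of the interface the loader assumes; not the real .encoders module."""

    def __init__(self, max_rate=20, T=50, dt_ms=1.0, seed=None):
        self.max_rate = max_rate        # peak firing rate (Hz) at intensity 1.0
        self.T = T                      # number of time steps
        self.dt_s = dt_ms / 1000.0      # step length in seconds
        self.rng = np.random.default_rng(seed)

    def encode(self, intensities):
        """Map a (D,) array in [0, 1] to a (T, D) binary spike train."""
        p = np.clip(intensities * self.max_rate * self.dt_s, 0.0, 1.0)  # per-step spike prob
        return (self.rng.random((self.T, intensities.shape[0])) < p).astype(np.float32)

enc = PoissonEncoder(max_rate=50, T=64, seed=123)
print(enc.encode(np.random.rand(128)).shape)  # (64, 128)
```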
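End-to-end usage, assuming the package is importable as `files.data_io.dataset_loader` and that `shd_train.h5` / `shd_test.h5` sit under `data_dir` (the config path and every value below are illustrative, not taken from this commit):

```python
# Illustrative config; keys match what get_dataloader reads.
cfg_text = """\
dataset: shd
data_dir: /data/shd          # assumed location of shd_train.h5 / shd_test.h5
batch_size: 32
num_workers: 4
encoder:
  dt_ms: 1.0
  seed: 42
transforms:
  normalize: false
"""
with open("shd.yaml", "w") as f:
    f.write(cfg_text)

from files.data_io.dataset_loader import get_dataloader

train_loader, val_loader = get_dataloader("shd.yaml")
x, y = next(iter(train_loader))
# x: (batch_size, T, D), where T is set by the global max spike time and D >= 700
print(x.shape, y.shape)
```

Opening the H5 file inside `__getitem__` rather than caching a handle on the dataset is what makes `num_workers > 0` safe here: h5py file handles do not survive being shared across forked worker processes.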
