from typing import Optional import yaml import torch from torch.utils.data import Dataset, DataLoader import os import h5py import numpy as np from .encoders import poisson_encoder, latency_encoder, rank_order_encoder from .transforms import normalization, spike_augmentation # from .utils import file_utils # 目前未使用可先注释 # ----- # Base synthetic dataset placeholder # ----- class BaseDataset(Dataset): """ Abstract base class for datasets. Each subclass must implement __getitem__ and __len__. This base implementation uses synthetic placeholders for quick smoke tests. """ def __init__(self, data_dir, encoder, transforms=None): self.data_dir = data_dir self.encoder = encoder self.transforms = transforms or [] self.samples = list(range(200)) # placeholder synthetic samples def __getitem__(self, idx): # synthetic static input -> encode to (T, D) via encoder raw_data = torch.rand(128) # 128-dim intensity label = torch.randint(0, 10, (1,)).item() encoded = self.encoder.encode(raw_data.numpy()) if self.encoder is not None else raw_data x = torch.as_tensor(encoded) for t in self.transforms: x = t(x) return x, label def __len__(self): return len(self.samples) # ----- # SHD dataset: true event data, read H5 per-sample; fixed global T; adaptive D # ----- import h5py # noqa: E402 class SHDDataset(Dataset): """ SHD Dataset Loader (.h5) with time-adaptive binning and fixed global T. H5 structure (Zenke Lab convention): - f["labels"][i] -> scalar label for sample i - f["spikes"]["times"][i] -> 1D array of spike times (ms) for sample i - f["spikes"]["units"][i] -> 1D array of channel ids for sample i We: (1) scan once to determine global T = ceil(max_time / dt_ms) (2) decide D from max unit id (fallback to default_D=700) (3) in __getitem__, open H5, read ragged arrays for that sample, and bin to (T, D) """ def __init__( self, data_dir: str, encoder=None, # ignored for SHD (already spiking events) transforms=None, split: str = "train", dt_ms: float = 1.0, seed: Optional[int] = None, default_D: int = 700 ): super().__init__() self.data_dir = data_dir self.transforms = transforms or [] self.dt_ms = float(dt_ms) self.seed = 42 if seed is None else int(seed) self.encoder = None # IMPORTANT: do not apply intensity encoders to event data self.default_D = int(default_D) fname = f"shd_{split}.h5" self.path = os.path.join(self.data_dir, fname) if not os.path.exists(self.path): raise FileNotFoundError(f"SHD file not found: {self.path}") with h5py.File(self.path, "r") as f: # labels is dense array self.labels = np.array(f["labels"], dtype=np.int64) self.N = int(self.labels.shape[0]) # ragged datasets for events times_ds = f["spikes"]["times"] units_ds = f["spikes"]["units"] # scan once to compute global T and adaptive D t_max_global = 0.0 max_unit = -1 for i in range(self.N): ti = times_ds[i] ui = units_ds[i] if ti.size > 0: last_t = float(ti[-1]) # ms if last_t > t_max_global: t_max_global = last_t if ui.size > 0: uimax = int(ui.max()) if uimax > max_unit: max_unit = uimax # decide D if max_unit >= 0: self.D = max(max_unit + 1, self.default_D) else: self.D = self.default_D # decide T from global max time self.T = int(np.ceil(t_max_global / self.dt_ms)) if t_max_global > 0 else 1 # rng in case transforms need it self._rng = np.random.default_rng(self.seed) def __len__(self): return self.N def __getitem__(self, idx: int): # open file per-sample for worker safety with h5py.File(self.path, "r") as f: ti = f["spikes"]["times"][idx][:] ui = f["spikes"]["units"][idx][:] y = int(f["labels"][idx]) # bin events to (T, D) spikes = np.zeros((self.T, self.D), dtype=np.float32) if ti.size > 0: bins = (ti / self.dt_ms).astype(np.int64) bins = np.clip(bins, 0, self.T - 1) ui = np.clip(ui.astype(np.int64), 0, self.D - 1) spikes[bins, ui] = 1.0 # presence; if you prefer counts: +=1 then clip to 1 x = torch.from_numpy(spikes) # apply transforms (on torch tensor) for tr in self.transforms: x = tr(x) return x, y # ----- # SSC / DVS placeholders (still synthetic; implement real readers later) # ----- class SSCDataset(BaseDataset): """Placeholder SSC dataset (synthetic).""" pass class DVSDataset(BaseDataset): """Placeholder DVS dataset (synthetic).""" pass # ----- # Helpers: encoders / transforms / cfg path resolution # ----- def build_encoder(cfg): """ Build encoder from config dict. Expected schema: encoder: type: poisson | latency | rank_order # Poisson-only optional fields: max_rate: 50 T: 64 dt_ms: 1.0 seed: 123 """ etype = cfg["type"].lower() if etype == "poisson": return poisson_encoder.PoissonEncoder( max_rate=cfg.get("max_rate", cfg.get("rate", 20)), T=cfg.get("T", 50), dt_ms=cfg.get("dt_ms", 1.0), seed=cfg.get("seed", None), ) elif etype == "latency": return latency_encoder.LatencyEncoder() elif etype == "rank_order": return rank_order_encoder.RankOrderEncoder() else: raise ValueError(f"Unknown encoder type: {etype}") def build_transforms(cfg): tlist = [] if cfg.get("normalize", False): tlist.append(normalization.Normalize()) if cfg.get("spike_jitter", None) is not None: tlist.append(spike_augmentation.SpikeJitter(std=cfg["spike_jitter"])) return tlist def _resolve_cfg_path(cfg_path: str) -> str: """ Resolve cfg_path against: 1) as-is (absolute or CWD-relative) 2) relative to this package directory 3) /configs/ """ if os.path.isabs(cfg_path) and os.path.exists(cfg_path): return cfg_path if os.path.exists(cfg_path): return cfg_path pkg_dir = os.path.dirname(__file__) cand2 = os.path.normpath(os.path.join(pkg_dir, cfg_path)) if os.path.exists(cand2): return cand2 cand3 = os.path.join(pkg_dir, "configs", os.path.basename(cfg_path)) if os.path.exists(cand3): return cand3 raise FileNotFoundError(f"Config file not found. Tried: {cfg_path}, {cand2}, {cand3}") # ----- # Entry: get_dataloader # ----- def get_dataloader(cfg_path): """ Create train/val DataLoader from YAML config. Handles SHD as true event dataset (encoder=None), others as synthetic placeholders. """ cfg_path_resolved = _resolve_cfg_path(cfg_path) with open(cfg_path_resolved, "r") as f: cfg = yaml.safe_load(f) dataset_name = cfg["dataset"].lower() data_dir = cfg["data_dir"] transforms = build_transforms(cfg.get("transforms", {})) if dataset_name == "shd": # event dataset: do NOT use intensity encoders here dt_ms = cfg.get("encoder", {}).get("dt_ms", 1.0) seed = cfg.get("encoder", {}).get("seed", 42) ds_train = SHDDataset( data_dir, encoder=None, transforms=transforms, split="train", dt_ms=dt_ms, seed=seed ) ds_val = SHDDataset( data_dir, encoder=None, transforms=transforms, split="test", dt_ms=dt_ms, seed=seed ) elif dataset_name == "ssc": # placeholder path; later implement true SSC reader encoder = build_encoder(cfg["encoder"]) ds_train = SSCDataset(data_dir, encoder, transforms) ds_val = SSCDataset(data_dir, encoder, transforms) elif dataset_name == "dvs": encoder = build_encoder(cfg["encoder"]) ds_train = DVSDataset(data_dir, encoder, transforms) ds_val = DVSDataset(data_dir, encoder, transforms) else: raise ValueError(f"Unknown dataset: {dataset_name}") train_loader = DataLoader( ds_train, batch_size=cfg.get("batch_size", 16), shuffle=cfg.get("shuffle", True), num_workers=cfg.get("num_workers", 0), pin_memory=cfg.get("pin_memory", False), ) val_loader = DataLoader( ds_val, batch_size=cfg.get("batch_size", 16), shuffle=False, num_workers=cfg.get("num_workers", 0), pin_memory=cfg.get("pin_memory", False), ) return train_loader, val_loader