path: root/files/data_io/dataset_loader.py
author     YurenHao0426 <blackhao0426@gmail.com>  2026-01-13 23:50:59 -0600
committer  YurenHao0426 <blackhao0426@gmail.com>  2026-01-13 23:50:59 -0600
commit     00cf667cee7ffacb144d5805fc7e0ef443f3583a (patch)
tree       77d20a3adaecf96bf3aff0612bdd3b5fa1a7dc7e /files/data_io/dataset_loader.py
parent     c53c04aa1d6ff75cb478a9498c370baa929c74b6 (diff)
parent     cd99d6b874d9d09b3bb87b8485cc787885af71f1 (diff)
Merge master into main
Diffstat (limited to 'files/data_io/dataset_loader.py')
-rw-r--r--  files/data_io/dataset_loader.py  281
1 files changed, 281 insertions, 0 deletions
diff --git a/files/data_io/dataset_loader.py b/files/data_io/dataset_loader.py
new file mode 100644
index 0000000..983b67b
--- /dev/null
+++ b/files/data_io/dataset_loader.py
@@ -0,0 +1,281 @@
+from typing import Optional
+import yaml
+import torch
+from torch.utils.data import Dataset, DataLoader
+import os
+import h5py
+import numpy as np
+from .encoders import poisson_encoder, latency_encoder, rank_order_encoder
+from .transforms import normalization, spike_augmentation
+# from .utils import file_utils  # currently unused; keep commented out for now
+
+
+
+# -----
+# Base synthetic dataset placeholder
+# -----
+class BaseDataset(Dataset):
+    """
+    Abstract base class for datasets.
+    Each subclass must implement __getitem__ and __len__.
+
+    This base implementation uses synthetic placeholders for quick smoke tests.
+    """
+
+    def __init__(self, data_dir, encoder, transforms=None):
+        self.data_dir = data_dir
+        self.encoder = encoder
+        self.transforms = transforms or []
+        self.samples = list(range(200))  # placeholder synthetic samples
+
+    def __getitem__(self, idx):
+        # synthetic static input -> encode to (T, D) via encoder
+        raw_data = torch.rand(128)  # 128-dim intensity
+        label = torch.randint(0, 10, (1,)).item()
+        encoded = self.encoder.encode(raw_data.numpy()) if self.encoder is not None else raw_data
+        x = torch.as_tensor(encoded)
+        for t in self.transforms:
+            x = t(x)
+        return x, label
+
+    def __len__(self):
+        return len(self.samples)
+
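+# Minimal usage sketch for the synthetic base dataset. The PoissonEncoder kwargs below
+# mirror the ones build_encoder (further down) passes and are illustrative only; the
+# encoder is assumed to map a 128-dim intensity vector to a (T, 128) spike array.
+#   ds = BaseDataset("/unused", encoder=poisson_encoder.PoissonEncoder(max_rate=50, T=64))
+#   x, y = ds[0]   # x: encoded spike tensor, y: int label in [0, 10)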
+
+# -----
+# SHD dataset: true event data, read H5 per-sample; fixed global T; adaptive D
+# -----
+
+
+class SHDDataset(Dataset):
+    """
+    SHD Dataset Loader (.h5) with time-adaptive binning and fixed global T.
+
+    H5 structure (Zenke Lab convention):
+        - f["labels"][i]          -> scalar label for sample i
+        - f["spikes"]["times"][i] -> 1D array of spike times (ms) for sample i
+        - f["spikes"]["units"][i] -> 1D array of channel ids for sample i
+
+    We:
+        (1) scan once to determine global T = ceil(max_time / dt_ms)
+        (2) decide D from the max unit id (fallback to default_D=700)
+        (3) in __getitem__, open the H5 file, read the ragged arrays for that sample, and bin to (T, D)
+    """
+
+    def __init__(
+        self,
+        data_dir: str,
+        encoder=None,  # ignored for SHD (already spiking events)
+        transforms=None,
+        split: str = "train",
+        dt_ms: float = 1.0,
+        seed: Optional[int] = None,
+        default_D: int = 700,
+    ):
+        super().__init__()
+        self.data_dir = data_dir
+        self.transforms = transforms or []
+        self.dt_ms = float(dt_ms)
+        self.seed = 42 if seed is None else int(seed)
+        self.encoder = None  # IMPORTANT: do not apply intensity encoders to event data
+        self.default_D = int(default_D)
+
+        fname = f"shd_{split}.h5"
+        self.path = os.path.join(self.data_dir, fname)
+        if not os.path.exists(self.path):
+            raise FileNotFoundError(f"SHD file not found: {self.path}")
+
+        with h5py.File(self.path, "r") as f:
+            # labels is a dense array
+            self.labels = np.array(f["labels"], dtype=np.int64)
+            self.N = int(self.labels.shape[0])
+
+            # ragged datasets for events
+            times_ds = f["spikes"]["times"]
+            units_ds = f["spikes"]["units"]
+
+            # scan once to compute global T and adaptive D
+            t_max_global = 0.0
+            max_unit = -1
+            for i in range(self.N):
+                ti = times_ds[i]
+                ui = units_ds[i]
+                if ti.size > 0:
+                    last_t = float(ti[-1])  # ms; assumes spike times are stored in ascending order
+                    if last_t > t_max_global:
+                        t_max_global = last_t
+                if ui.size > 0:
+                    uimax = int(ui.max())
+                    if uimax > max_unit:
+                        max_unit = uimax
+
+        # decide D
+        if max_unit >= 0:
+            self.D = max(max_unit + 1, self.default_D)
+        else:
+            self.D = self.default_D
+
+        # decide T from global max time
+        self.T = int(np.ceil(t_max_global / self.dt_ms)) if t_max_global > 0 else 1
+
+        # rng in case transforms need it
+        self._rng = np.random.default_rng(self.seed)
+
+    def __len__(self):
+        return self.N
+
+    def __getitem__(self, idx: int):
+        # open the file per sample for worker safety
+        with h5py.File(self.path, "r") as f:
+            ti = f["spikes"]["times"][idx][:]
+            ui = f["spikes"]["units"][idx][:]
+            y = int(f["labels"][idx])
+
+        # bin events to (T, D)
+        spikes = np.zeros((self.T, self.D), dtype=np.float32)
+        if ti.size > 0:
+            bins = (ti / self.dt_ms).astype(np.int64)
+            bins = np.clip(bins, 0, self.T - 1)
+            ui = np.clip(ui.astype(np.int64), 0, self.D - 1)
+            spikes[bins, ui] = 1.0  # binary presence; for per-bin counts use np.add.at(spikes, (bins, ui), 1.0)
+
+        x = torch.from_numpy(spikes)
+
+        # apply transforms (on the torch tensor)
+        for tr in self.transforms:
+            x = tr(x)
+
+        return x, y
+
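+# Minimal usage sketch (the directory path is a placeholder and is assumed to hold
+# shd_train.h5 / shd_test.h5 in the Zenke Lab layout described in the docstring above):
+#   ds = SHDDataset("/path/to/shd", split="train", dt_ms=1.0)
+#   x, y = ds[0]                      # x: (T, D) float32 spike tensor, y: int label
+#   loader = DataLoader(ds, batch_size=16, shuffle=True)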
+
+# -----
+# SSC / DVS placeholders (still synthetic; implement real readers later)
+# -----
+class SSCDataset(BaseDataset):
+    """Placeholder SSC dataset (synthetic)."""
+    pass
+
+
+class DVSDataset(BaseDataset):
+    """Placeholder DVS dataset (synthetic)."""
+    pass
+
+
+# -----
+# Helpers: encoders / transforms / cfg path resolution
+# -----
+def build_encoder(cfg):
+    """
+    Build encoder from config dict.
+
+    Expected schema:
+        encoder:
+          type: poisson | latency | rank_order
+          # Poisson-only optional fields:
+          max_rate: 50
+          T: 64
+          dt_ms: 1.0
+          seed: 123
+    """
+    etype = cfg["type"].lower()
+    if etype == "poisson":
+        return poisson_encoder.PoissonEncoder(
+            max_rate=cfg.get("max_rate", cfg.get("rate", 20)),
+            T=cfg.get("T", 50),
+            dt_ms=cfg.get("dt_ms", 1.0),
+            seed=cfg.get("seed", None),
+        )
+    elif etype == "latency":
+        return latency_encoder.LatencyEncoder()
+    elif etype == "rank_order":
+        return rank_order_encoder.RankOrderEncoder()
+    else:
+        raise ValueError(f"Unknown encoder type: {etype}")
+
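+# Sketch of how a config fragment maps to an encoder. The field values are illustrative,
+# and encode() is assumed to take a 1D intensity array and return a (T, D) spike array,
+# matching how BaseDataset.__getitem__ uses it:
+#   enc = build_encoder({"type": "poisson", "max_rate": 50, "T": 64, "dt_ms": 1.0, "seed": 123})
+#   spikes = enc.encode(np.random.rand(128))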
+
+def build_transforms(cfg):
+    tlist = []
+    if cfg.get("normalize", False):
+        tlist.append(normalization.Normalize())
+    if cfg.get("spike_jitter", None) is not None:
+        tlist.append(spike_augmentation.SpikeJitter(std=cfg["spike_jitter"]))
+    return tlist
+
+
+def _resolve_cfg_path(cfg_path: str) -> str:
+    """
+    Resolve cfg_path against:
+        1) as-is (absolute or CWD-relative)
+        2) relative to this package directory
+        3) <pkg_dir>/configs/<basename>
+    """
+    if os.path.isabs(cfg_path) and os.path.exists(cfg_path):
+        return cfg_path
+    if os.path.exists(cfg_path):
+        return cfg_path
+    pkg_dir = os.path.dirname(__file__)
+    cand2 = os.path.normpath(os.path.join(pkg_dir, cfg_path))
+    if os.path.exists(cand2):
+        return cand2
+    cand3 = os.path.join(pkg_dir, "configs", os.path.basename(cfg_path))
+    if os.path.exists(cand3):
+        return cand3
+    raise FileNotFoundError(f"Config file not found. Tried: {cfg_path}, {cand2}, {cand3}")
+
+
+# -----
+# Entry: get_dataloader
+# -----
+def get_dataloader(cfg_path):
+    """
+    Create train/val DataLoaders from a YAML config.
+    Handles SHD as a true event dataset (encoder=None), others as synthetic placeholders.
+    """
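+    # Minimal example config (sketch): the field names below are exactly the keys this
+    # function and its helpers read; the values and the data_dir path are illustrative.
+    #
+    #   dataset: shd
+    #   data_dir: /path/to/shd
+    #   batch_size: 16
+    #   shuffle: true
+    #   num_workers: 0
+    #   pin_memory: false
+    #   encoder:
+    #     type: poisson      # for shd only dt_ms / seed are used
+    #     dt_ms: 1.0
+    #     seed: 42
+    #   transforms:
+    #     normalize: true
+    #     spike_jitter: 0.1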
+    cfg_path_resolved = _resolve_cfg_path(cfg_path)
+    with open(cfg_path_resolved, "r") as f:
+        cfg = yaml.safe_load(f)
+
+    dataset_name = cfg["dataset"].lower()
+    data_dir = cfg["data_dir"]
+    transforms = build_transforms(cfg.get("transforms", {}))
+
+    if dataset_name == "shd":
+        # event dataset: do NOT use intensity encoders here
+        dt_ms = cfg.get("encoder", {}).get("dt_ms", 1.0)
+        seed = cfg.get("encoder", {}).get("seed", 42)
+        ds_train = SHDDataset(
+            data_dir, encoder=None, transforms=transforms, split="train",
+            dt_ms=dt_ms, seed=seed
+        )
+        ds_val = SHDDataset(
+            data_dir, encoder=None, transforms=transforms, split="test",
+            dt_ms=dt_ms, seed=seed
+        )
+    elif dataset_name == "ssc":
+        # placeholder path; later implement true SSC reader
+        encoder = build_encoder(cfg["encoder"])
+        ds_train = SSCDataset(data_dir, encoder, transforms)
+        ds_val = SSCDataset(data_dir, encoder, transforms)
+    elif dataset_name == "dvs":
+        encoder = build_encoder(cfg["encoder"])
+        ds_train = DVSDataset(data_dir, encoder, transforms)
+        ds_val = DVSDataset(data_dir, encoder, transforms)
+    else:
+        raise ValueError(f"Unknown dataset: {dataset_name}")
+
+    train_loader = DataLoader(
+        ds_train,
+        batch_size=cfg.get("batch_size", 16),
+        shuffle=cfg.get("shuffle", True),
+        num_workers=cfg.get("num_workers", 0),
+        pin_memory=cfg.get("pin_memory", False),
+    )
+    val_loader = DataLoader(
+        ds_val,
+        batch_size=cfg.get("batch_size", 16),
+        shuffle=False,
+        num_workers=cfg.get("num_workers", 0),
+        pin_memory=cfg.get("pin_memory", False),
+    )
+    return train_loader, val_loader
\ No newline at end of file