"""Smoke tests for the SHD dataset and its dataloader.

The data directory and dataloader config default to the paths this module
originally hard-coded. Override them with the ``SHD_DATA_DIR`` and
``SHD_CONFIG`` environment variables to run the tests on another machine.
"""

import os

import torch

from files.data_io.dataset_loader import get_dataloader, SHDDataset

# Machine-specific default kept for backward compatibility; set SHD_DATA_DIR
# to point the dataset test at a different data location.
SHD_DATA_DIR = os.environ.get(
    "SHD_DATA_DIR", "/u/yurenh2/ml-projects/snn-training/files/data"
)
# Relative default; whether it is resolved against the current working
# directory or by get_dataloader itself is not visible here.
# NOTE(review): confirm which directory pytest must be launched from.
SHD_CONFIG = os.environ.get("SHD_CONFIG", "data_io/configs/shd.yaml")
# Passed as default_D and also used as the lower bound asserted on ds.D, so
# the two values cannot drift apart.
DEFAULT_D = 700


def test_shd_dataset_global_T_D_consistency():
    """Check that SHDDataset applies one global (T, D) shape to every sample."""
    ds = SHDDataset(
        data_dir=SHD_DATA_DIR,
        split="train",
        dt_ms=1.0,
        default_D=DEFAULT_D,
    )
    assert ds.T >= 1
    assert ds.D >= DEFAULT_D  # by construction at least default_D

    x0, y0 = ds[0]
    x1, y1 = ds[1]
    expected_shape = (ds.T, ds.D)
    assert isinstance(x0, torch.Tensor) and isinstance(x1, torch.Tensor)
    # Both samples must share the dataset-wide (T, D), not their own lengths.
    assert x0.shape == expected_shape, (
        f"x0.shape={tuple(x0.shape)}, expected {expected_shape}"
    )
    assert x1.shape == expected_shape, (
        f"x1.shape={tuple(x1.shape)}, expected {expected_shape}"
    )
    assert isinstance(y0, int) and isinstance(y1, int)
    # Values should be finite (no NaN/inf).
    assert torch.isfinite(x0).all()
    assert torch.isfinite(x1).all()


def test_dataloader_batch_shapes_and_finiteness():
    """Check the first training batch's shape, finiteness and normalization."""
    train_loader, _ = get_dataloader(SHD_CONFIG)
    xb, yb = next(iter(train_loader))

    assert isinstance(xb, torch.Tensor)
    assert xb.ndim == 3  # (B, T, D)
    assert yb.ndim == 1
    B, T, D = xb.shape
    assert B >= 1 and T >= 1 and D >= 1
    # One label per sample in the batch.
    assert yb.shape[0] == B, f"got {yb.shape[0]} labels for batch size {B}"

    # Finiteness: no NaN/inf values.
    assert torch.isfinite(xb).all()

    # With normalization enabled in the config (NOTE(review): taken from the
    # original comment; the config is not inspected here), each sample's
    # per-channel mean over time should be ~0.
    mean_t = xb.mean(dim=1)  # (B, D)
    assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-3), (
        "per-sample per-channel time mean is not ~0 "
        f"(max |mean| = {mean_t.abs().max().item():.3e})"
    )