From cd99d6b874d9d09b3bb87b8485cc787885af71f1 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Tue, 13 Jan 2026 23:49:05 -0600 Subject: init commit --- files/tests/test_shd_loader_properties.py | 40 +++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 files/tests/test_shd_loader_properties.py (limited to 'files/tests/test_shd_loader_properties.py') diff --git a/files/tests/test_shd_loader_properties.py b/files/tests/test_shd_loader_properties.py new file mode 100644 index 0000000..237740c --- /dev/null +++ b/files/tests/test_shd_loader_properties.py @@ -0,0 +1,40 @@ +import torch +from files.data_io.dataset_loader import get_dataloader, SHDDataset + + +def test_shd_dataset_global_T_D_consistency(): + # Ensure SHDDataset computes global T and adaptive D once and applies consistently + ds = SHDDataset( + data_dir="/u/yurenh2/ml-projects/snn-training/files/data", + split="train", + dt_ms=1.0, + default_D=700, + ) + assert ds.T >= 1 + assert ds.D >= 700 # by construction at least default_D + x0, y0 = ds[0] + x1, y1 = ds[1] + assert isinstance(x0, torch.Tensor) and isinstance(x1, torch.Tensor) + assert x0.shape == (ds.T, ds.D) + assert x1.shape == (ds.T, ds.D) + assert isinstance(y0, int) and isinstance(y1, int) + # values should be finite + assert torch.isfinite(x0).all() + assert torch.isfinite(x1).all() + + +def test_dataloader_batch_shapes_and_finiteness(): + train_loader, _ = get_dataloader("data_io/configs/shd.yaml") + xb, yb = next(iter(train_loader)) + assert isinstance(xb, torch.Tensor) + assert xb.ndim == 3 # (B, T, D) + assert yb.ndim == 1 + B, T, D = xb.shape + assert B >= 1 and T >= 1 and D >= 1 + # finiteness and no NaNs + assert torch.isfinite(xb).all() + # After normalization (enabled in config), per-sample per-channel mean over time should be ~0 + mean_t = xb.mean(dim=1) # (B, D) + assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-3) + + -- cgit v1.2.3