summaryrefslogtreecommitdiff
path: root/files/tests/test_shd_loader_properties.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-13 23:49:05 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-13 23:49:05 -0600
commitcd99d6b874d9d09b3bb87b8485cc787885af71f1 (patch)
tree59a233959932ca0e4f12f196275e07fcf443b33f /files/tests/test_shd_loader_properties.py
init commit
Diffstat (limited to 'files/tests/test_shd_loader_properties.py')
-rw-r--r--files/tests/test_shd_loader_properties.py40
1 files changed, 40 insertions, 0 deletions
diff --git a/files/tests/test_shd_loader_properties.py b/files/tests/test_shd_loader_properties.py
new file mode 100644
index 0000000..237740c
--- /dev/null
+++ b/files/tests/test_shd_loader_properties.py
@@ -0,0 +1,40 @@
+import torch
+from files.data_io.dataset_loader import get_dataloader, SHDDataset
+
+
+def test_shd_dataset_global_T_D_consistency():
+ # Ensure SHDDataset computes global T and adaptive D once and applies consistently
+ ds = SHDDataset(
+ data_dir="/u/yurenh2/ml-projects/snn-training/files/data",
+ split="train",
+ dt_ms=1.0,
+ default_D=700,
+ )
+ assert ds.T >= 1
+ assert ds.D >= 700 # by construction at least default_D
+ x0, y0 = ds[0]
+ x1, y1 = ds[1]
+ assert isinstance(x0, torch.Tensor) and isinstance(x1, torch.Tensor)
+ assert x0.shape == (ds.T, ds.D)
+ assert x1.shape == (ds.T, ds.D)
+ assert isinstance(y0, int) and isinstance(y1, int)
+ # values should be finite
+ assert torch.isfinite(x0).all()
+ assert torch.isfinite(x1).all()
+
+
+def test_dataloader_batch_shapes_and_finiteness():
+ train_loader, _ = get_dataloader("data_io/configs/shd.yaml")
+ xb, yb = next(iter(train_loader))
+ assert isinstance(xb, torch.Tensor)
+ assert xb.ndim == 3 # (B, T, D)
+ assert yb.ndim == 1
+ B, T, D = xb.shape
+ assert B >= 1 and T >= 1 and D >= 1
+ # finiteness and no NaNs
+ assert torch.isfinite(xb).all()
+ # After normalization (enabled in config), per-sample per-channel mean over time should be ~0
+ mean_t = xb.mean(dim=1) # (B, D)
+ assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-3)
+
+