summaryrefslogtreecommitdiff
path: root/files/tests/test_shd_loader_properties.py
blob: 237740c04d3e5fe8cc200633b8900d19bc79a3ca (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from pathlib import Path

import torch

from files.data_io.dataset_loader import get_dataloader, SHDDataset


def test_shd_dataset_global_T_D_consistency():
    """Check that SHDDataset applies one global T and adaptive D to every sample.

    The data directory is resolved relative to this test file (``files/data``)
    rather than a user-specific absolute path, so the test runs on any checkout.
    """
    data_dir = Path(__file__).resolve().parents[1] / "data"
    # Fail with a readable message instead of an opaque loader error.
    assert data_dir.is_dir(), f"SHD data directory not found: {data_dir}"
    ds = SHDDataset(
        data_dir=str(data_dir),
        split="train",
        dt_ms=1.0,
        default_D=700,
    )
    assert ds.T >= 1
    # Expectation carried over from the original test: adaptive D is at least default_D.
    assert ds.D >= 700
    x0, y0 = ds[0]
    x1, y1 = ds[1]
    assert isinstance(x0, torch.Tensor) and isinstance(x1, torch.Tensor)
    # Every sample must share the dataset-wide (T, D) shape.
    assert x0.shape == (ds.T, ds.D)
    assert x1.shape == (ds.T, ds.D)
    assert isinstance(y0, int) and isinstance(y1, int)
    # Values should be finite (no NaN/Inf).
    assert torch.isfinite(x0).all()
    assert torch.isfinite(x1).all()


def test_dataloader_batch_shapes_and_finiteness():
    """Check that the train loader yields finite (B, T, D) batches, normalized over time.

    The config path is resolved relative to this test file, not the current
    working directory. When pytest is run from ``files/``, this is the same
    file the original relative path pointed to.
    """
    config_path = (
        Path(__file__).resolve().parents[1] / "data_io" / "configs" / "shd.yaml"
    )
    train_loader, _ = get_dataloader(str(config_path))
    xb, yb = next(iter(train_loader))
    assert isinstance(xb, torch.Tensor)
    assert xb.ndim == 3  # (B, T, D)
    assert yb.ndim == 1
    B, T, D = xb.shape
    assert B >= 1 and T >= 1 and D >= 1
    # Exactly one label per sample in the batch.
    assert yb.shape[0] == B
    # Finiteness: no NaN or Inf values.
    assert torch.isfinite(xb).all()
    # Assumes normalization is enabled in shd.yaml (as the original comment
    # states): the per-sample, per-channel mean over time should then be ~0.
    mean_t = xb.mean(dim=1)  # (B, D)
    assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-3)