1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
|
import torch
from files.data_io.dataset_loader import get_dataloader, SHDDataset
def test_shd_dataset_global_T_D_consistency():
    """SHDDataset should compute a single global T and adaptive D once and
    apply them consistently to every sample it returns.

    The dataset directory defaults to the original hard-coded location but
    can be overridden via the ``SHD_DATA_DIR`` environment variable so the
    test is not tied to one user's filesystem layout.
    """
    import os

    data_dir = os.environ.get(
        "SHD_DATA_DIR", "/u/yurenh2/ml-projects/snn-training/files/data"
    )
    ds = SHDDataset(
        data_dir=data_dir,
        split="train",
        dt_ms=1.0,
        default_D=700,
    )
    # T and D are fixed at construction time; D is floored at default_D
    # by construction, so both bounds must hold.
    assert ds.T >= 1
    assert ds.D >= 700
    x0, y0 = ds[0]
    x1, y1 = ds[1]
    assert isinstance(x0, torch.Tensor) and isinstance(x1, torch.Tensor)
    # Every sample must share the same global (T, D) shape.
    assert x0.shape == (ds.T, ds.D)
    assert x1.shape == (ds.T, ds.D)
    assert isinstance(y0, int) and isinstance(y1, int)
    # Encoded spike tensors must contain no NaN/Inf values.
    assert torch.isfinite(x0).all()
    assert torch.isfinite(x1).all()
def test_dataloader_batch_shapes_and_finiteness():
    """A batch from the configured train loader must be a finite (B, T, D)
    tensor with 1-D labels and, because normalization is enabled in the
    config, a near-zero per-sample per-channel mean over time.

    The config path can be overridden via the ``SHD_CONFIG_PATH`` environment
    variable; the default keeps the original relative path, which assumes the
    test is run from the package root.
    """
    import os

    config_path = os.environ.get("SHD_CONFIG_PATH", "data_io/configs/shd.yaml")
    train_loader, _ = get_dataloader(config_path)
    xb, yb = next(iter(train_loader))
    assert isinstance(xb, torch.Tensor)
    assert xb.ndim == 3  # (B, T, D)
    assert yb.ndim == 1
    B, T, D = xb.shape
    assert B >= 1 and T >= 1 and D >= 1
    # No NaN/Inf anywhere in the batch.
    assert torch.isfinite(xb).all()
    # Normalization is enabled in the config, so the mean over the time axis
    # should be ~0 for each (sample, channel) pair.
    mean_t = xb.mean(dim=1)  # (B, D)
    assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-3)
|