summaryrefslogtreecommitdiff
path: root/files/tests/test_data_io.py
blob: 1f2ccd8902b58ee00d2ffc0c9223ff894857bc69 (plain)
1
2
3
4
5
6
7
8
9
10
11
from pathlib import Path

import torch

from files.data_io.dataset_loader import get_dataloader


def test_dataloader_shape():
    """Smoke test: verify the shape of the first training batch.

    Builds the loaders from the ``data_io/configs/shd.yaml`` config via
    ``get_dataloader`` and checks the first training batch ``(x, y)``:
    ``x`` must be a ``torch.Tensor`` of rank 3 and ``y`` must be rank 1.
    The second loader returned by ``get_dataloader`` is not exercised here.
    """
    rel_config = "data_io/configs/shd.yaml"
    # The config path is written relative to the ``files/`` package directory,
    # while the module import ``files.data_io.dataset_loader`` implies pytest is
    # launched from the repository root. Resolving against this test file's
    # package makes the test independent of the working directory. If the
    # file is not found there, the original relative path is passed unchanged.
    # NOTE(review): assumes the config lives at files/data_io/configs/shd.yaml;
    # confirm.
    candidate = Path(__file__).resolve().parent.parent / rel_config
    config_path = str(candidate) if candidate.is_file() else rel_config

    train_loader, _ = get_dataloader(config_path)
    x, y = next(iter(train_loader))

    assert isinstance(x, torch.Tensor), f"expected torch.Tensor input, got {type(x)!r}"
    assert x.ndim == 3, f"expected rank-3 input batch, got shape {tuple(x.shape)}"
    assert y.ndim == 1, f"expected rank-1 target batch, got shape {tuple(y.shape)}"