blob: 1f2ccd8902b58ee00d2ffc0c9223ff894857bc69 (
plain)
1
2
3
4
5
6
7
8
9
10
11
|
import torch
from files.data_io.dataset_loader import get_dataloader
def test_dataloader_shape():
    """Smoke-test the training dataloader built from the SHD config.

    Draws a single batch from the training loader returned by
    ``get_dataloader`` and checks that the batch inputs are a 3-D
    ``torch.Tensor`` and the batch targets are 1-D.

    NOTE(review): the config path is relative to the current working
    directory, so this test presumably must be run from the directory that
    contains ``data_io/`` — confirm against how the test suite is invoked.
    """
    # get_dataloader returns a pair of loaders; only the first (training)
    # loader is exercised here.
    train_dl, _unused_loader = get_dataloader("data_io/configs/shd.yaml")

    # Each batch is expected to unpack into exactly (inputs, targets).
    batch_iter = iter(train_dl)
    inputs, targets = next(batch_iter)

    # The type check must come first: .ndim is only meaningful once the
    # inputs are known to be a tensor.
    assert isinstance(inputs, torch.Tensor)
    assert inputs.ndim == 3
    assert targets.ndim == 1
|