diff options
Diffstat (limited to 'files/tests')
| -rw-r--r-- | files/tests/conftest.py | 11 | ||||
| -rw-r--r-- | files/tests/test_data_io.py | 11 | ||||
| -rw-r--r-- | files/tests/test_shd_loader_properties.py | 40 | ||||
| -rw-r--r-- | files/tests/test_train_smoke.py | 33 | ||||
| -rw-r--r-- | files/tests/test_transforms_normalize.py | 53 |
5 files changed, 148 insertions, 0 deletions
diff --git a/files/tests/conftest.py b/files/tests/conftest.py new file mode 100644 index 0000000..a63c259 --- /dev/null +++ b/files/tests/conftest.py @@ -0,0 +1,11 @@ +import os +import sys + +# Ensure project root is importable so `import files...` works when tests are under files/tests +_HERE = os.path.dirname(__file__) +_FILES_DIR = os.path.dirname(_HERE) +_ROOT = os.path.dirname(_FILES_DIR) +if _ROOT not in sys.path: + sys.path.insert(0, _ROOT) + + diff --git a/files/tests/test_data_io.py b/files/tests/test_data_io.py new file mode 100644 index 0000000..1f2ccd8 --- /dev/null +++ b/files/tests/test_data_io.py @@ -0,0 +1,11 @@ +import torch +from files.data_io.dataset_loader import get_dataloader + + +def test_dataloader_shape(): + """Smoke test: verify dataloader output shape.""" + train_loader, _ = get_dataloader("data_io/configs/shd.yaml") + x, y = next(iter(train_loader)) + assert isinstance(x, torch.Tensor) + assert x.ndim == 3 + assert y.ndim == 1 diff --git a/files/tests/test_shd_loader_properties.py b/files/tests/test_shd_loader_properties.py new file mode 100644 index 0000000..237740c --- /dev/null +++ b/files/tests/test_shd_loader_properties.py @@ -0,0 +1,40 @@ +import torch +from files.data_io.dataset_loader import get_dataloader, SHDDataset + + +def test_shd_dataset_global_T_D_consistency(): + # Ensure SHDDataset computes global T and adaptive D once and applies consistently + ds = SHDDataset( + data_dir="/u/yurenh2/ml-projects/snn-training/files/data", + split="train", + dt_ms=1.0, + default_D=700, + ) + assert ds.T >= 1 + assert ds.D >= 700 # by construction at least default_D + x0, y0 = ds[0] + x1, y1 = ds[1] + assert isinstance(x0, torch.Tensor) and isinstance(x1, torch.Tensor) + assert x0.shape == (ds.T, ds.D) + assert x1.shape == (ds.T, ds.D) + assert isinstance(y0, int) and isinstance(y1, int) + # values should be finite + assert torch.isfinite(x0).all() + assert torch.isfinite(x1).all() + + +def test_dataloader_batch_shapes_and_finiteness(): + train_loader, _ = get_dataloader("data_io/configs/shd.yaml") + xb, yb = next(iter(train_loader)) + assert isinstance(xb, torch.Tensor) + assert xb.ndim == 3 # (B, T, D) + assert yb.ndim == 1 + B, T, D = xb.shape + assert B >= 1 and T >= 1 and D >= 1 + # finiteness and no NaNs + assert torch.isfinite(xb).all() + # After normalization (enabled in config), per-sample per-channel mean over time should be ~0 + mean_t = xb.mean(dim=1) # (B, D) + assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-3) + + diff --git a/files/tests/test_train_smoke.py b/files/tests/test_train_smoke.py new file mode 100644 index 0000000..e04e40b --- /dev/null +++ b/files/tests/test_train_smoke.py @@ -0,0 +1,33 @@ +import torch +from files.models.snn import SimpleSNN +from files.data_io.dataset_loader import get_dataloader +import torch.nn as nn +import torch.optim as optim + + +def _one_step(model, xb, yb, lyapunov=False): + model.train() + opt = optim.Adam(model.parameters(), lr=1e-3) + ce = nn.CrossEntropyLoss() + opt.zero_grad(set_to_none=True) + logits, lyap = model(xb, compute_lyapunov=lyapunov) + loss = ce(logits, yb) + if lyapunov and lyap is not None: + loss = loss + 0.1 * (lyap - 0.0) ** 2 + loss.backward() + opt.step() + assert torch.isfinite(loss).all() + + +def test_train_step_baseline_and_lyapunov(): + train_loader, _ = get_dataloader("data_io/configs/shd.yaml") + xb, yb = next(iter(train_loader)) + B, T, D = xb.shape + C = 20 + model = SimpleSNN(input_dim=D, hidden_dim=64, num_classes=C) + # baseline + _one_step(model, xb, yb, lyapunov=False) + # lyapunov-regularized + _one_step(model, xb, yb, lyapunov=True) + + diff --git a/files/tests/test_transforms_normalize.py b/files/tests/test_transforms_normalize.py new file mode 100644 index 0000000..cc5a3d4 --- /dev/null +++ b/files/tests/test_transforms_normalize.py @@ -0,0 +1,53 @@ +import torch +from files.data_io.transforms.normalization import Normalize + + +def test_normalize_zscore_2d_zero_mean_unitish_var(): + # (T, D) toy example + T, D = 5, 3 + x = torch.arange(T * D, dtype=torch.float32).reshape(T, D) + norm = Normalize(mode="zscore") + xz = norm(x) + # mean over time per channel should be ~0 + mean_t = xz.mean(dim=0) + assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-5) + # std over time per channel should be ~1 + std_t = xz.std(dim=0, unbiased=False) + assert torch.allclose(std_t, torch.ones_like(std_t), atol=1e-5) + + +def test_normalize_zscore_3d_per_sample(): + # (B, T, D) + B, T, D = 2, 6, 4 + x = torch.randn(B, T, D) + norm = Normalize(mode="zscore") + xz = norm(x) + mean_t = xz.mean(dim=1) # (B, D) + std_t = xz.std(dim=1, unbiased=False) # (B, D) + assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-5) + assert torch.allclose(std_t, torch.ones_like(std_t), atol=1e-4) + + +def test_normalize_minmax_range_01_2d(): + T, D = 7, 2 + x = torch.linspace(-3, 3, steps=T).unsqueeze(1).repeat(1, D) + norm = Normalize(mode="minmax") + xm = norm(x) + assert xm.min().item() >= -1e-6 + assert xm.max().item() <= 1 + 1e-6 + # Check endpoints map to 0 and 1 + assert torch.isclose(xm.min(), torch.tensor(0.0), atol=1e-6) + assert torch.isclose(xm.max(), torch.tensor(1.0), atol=1e-6) + + +def test_normalize_rejects_bad_ndim(): + x = torch.ones(1, 2, 3, 4) + norm = Normalize() + try: + _ = norm(x) + except ValueError: + pass + else: + raise AssertionError("Normalize should raise on x.ndim not in {2,3}") + + |
