diff options
Diffstat (limited to 'files/tests/test_transforms_normalize.py')
| -rw-r--r-- | files/tests/test_transforms_normalize.py | 53 |
1 files changed, 53 insertions, 0 deletions
diff --git a/files/tests/test_transforms_normalize.py b/files/tests/test_transforms_normalize.py new file mode 100644 index 0000000..cc5a3d4 --- /dev/null +++ b/files/tests/test_transforms_normalize.py @@ -0,0 +1,53 @@ +import torch +from files.data_io.transforms.normalization import Normalize + + +def test_normalize_zscore_2d_zero_mean_unitish_var(): + # (T, D) toy example + T, D = 5, 3 + x = torch.arange(T * D, dtype=torch.float32).reshape(T, D) + norm = Normalize(mode="zscore") + xz = norm(x) + # mean over time per channel should be ~0 + mean_t = xz.mean(dim=0) + assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-5) + # std over time per channel should be ~1 + std_t = xz.std(dim=0, unbiased=False) + assert torch.allclose(std_t, torch.ones_like(std_t), atol=1e-5) + + +def test_normalize_zscore_3d_per_sample(): + # (B, T, D) + B, T, D = 2, 6, 4 + x = torch.randn(B, T, D) + norm = Normalize(mode="zscore") + xz = norm(x) + mean_t = xz.mean(dim=1) # (B, D) + std_t = xz.std(dim=1, unbiased=False) # (B, D) + assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-5) + assert torch.allclose(std_t, torch.ones_like(std_t), atol=1e-4) + + +def test_normalize_minmax_range_01_2d(): + T, D = 7, 2 + x = torch.linspace(-3, 3, steps=T).unsqueeze(1).repeat(1, D) + norm = Normalize(mode="minmax") + xm = norm(x) + assert xm.min().item() >= -1e-6 + assert xm.max().item() <= 1 + 1e-6 + # Check endpoints map to 0 and 1 + assert torch.isclose(xm.min(), torch.tensor(0.0), atol=1e-6) + assert torch.isclose(xm.max(), torch.tensor(1.0), atol=1e-6) + + +def test_normalize_rejects_bad_ndim(): + x = torch.ones(1, 2, 3, 4) + norm = Normalize() + try: + _ = norm(x) + except ValueError: + pass + else: + raise AssertionError("Normalize should raise on x.ndim not in {2,3}") + + |
