summaryrefslogtreecommitdiff
path: root/files/tests/test_transforms_normalize.py
diff options
context:
space:
mode:
Diffstat (limited to 'files/tests/test_transforms_normalize.py')
-rw-r--r--files/tests/test_transforms_normalize.py53
1 files changed, 53 insertions, 0 deletions
diff --git a/files/tests/test_transforms_normalize.py b/files/tests/test_transforms_normalize.py
new file mode 100644
index 0000000..cc5a3d4
--- /dev/null
+++ b/files/tests/test_transforms_normalize.py
@@ -0,0 +1,53 @@
+import torch
+from files.data_io.transforms.normalization import Normalize
+
+
+def test_normalize_zscore_2d_zero_mean_unitish_var():
+ # (T, D) toy example
+ T, D = 5, 3
+ x = torch.arange(T * D, dtype=torch.float32).reshape(T, D)
+ norm = Normalize(mode="zscore")
+ xz = norm(x)
+ # mean over time per channel should be ~0
+ mean_t = xz.mean(dim=0)
+ assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-5)
+ # std over time per channel should be ~1
+ std_t = xz.std(dim=0, unbiased=False)
+ assert torch.allclose(std_t, torch.ones_like(std_t), atol=1e-5)
+
+
+def test_normalize_zscore_3d_per_sample():
+ # (B, T, D)
+ B, T, D = 2, 6, 4
+ x = torch.randn(B, T, D)
+ norm = Normalize(mode="zscore")
+ xz = norm(x)
+ mean_t = xz.mean(dim=1) # (B, D)
+ std_t = xz.std(dim=1, unbiased=False) # (B, D)
+ assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-5)
+ assert torch.allclose(std_t, torch.ones_like(std_t), atol=1e-4)
+
+
+def test_normalize_minmax_range_01_2d():
+ T, D = 7, 2
+ x = torch.linspace(-3, 3, steps=T).unsqueeze(1).repeat(1, D)
+ norm = Normalize(mode="minmax")
+ xm = norm(x)
+ assert xm.min().item() >= -1e-6
+ assert xm.max().item() <= 1 + 1e-6
+ # Check endpoints map to 0 and 1
+ assert torch.isclose(xm.min(), torch.tensor(0.0), atol=1e-6)
+ assert torch.isclose(xm.max(), torch.tensor(1.0), atol=1e-6)
+
+
+def test_normalize_rejects_bad_ndim():
+ x = torch.ones(1, 2, 3, 4)
+ norm = Normalize()
+ try:
+ _ = norm(x)
+ except ValueError:
+ pass
+ else:
+ raise AssertionError("Normalize should raise on x.ndim not in {2,3}")
+
+