"""Unit tests for ``Normalize`` from ``files.data_io.transforms.normalization``.

Covers z-score and min-max modes on 2-D ``(T, D)`` and 3-D ``(B, T, D)``
inputs, plus rejection of an input with an unsupported number of dimensions.
"""

import torch

from files.data_io.transforms.normalization import Normalize


def test_normalize_zscore_2d_zero_mean_unitish_var():
    """Z-scoring a ``(T, D)`` tensor yields ~0 mean and ~1 std per channel over time."""
    # (T, D) toy example: deterministic, non-constant values in every column.
    T, D = 5, 3
    x = torch.arange(T * D, dtype=torch.float32).reshape(T, D)

    norm = Normalize(mode="zscore")
    xz = norm(x)

    # Mean over time (dim 0) per channel should be ~0.
    mean_t = xz.mean(dim=0)
    assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-5)

    # Population std (unbiased=False) over time per channel should be ~1.
    # NOTE(review): presumably matches how Normalize computes std — confirm.
    std_t = xz.std(dim=0, unbiased=False)
    assert torch.allclose(std_t, torch.ones_like(std_t), atol=1e-5)


def test_normalize_zscore_3d_per_sample():
    """Z-scoring a ``(B, T, D)`` tensor standardizes each sample's channels over time."""
    B, T, D = 2, 6, 4
    # Seeded local generator: the test is reproducible across runs and does
    # not disturb the global RNG state that other tests may rely on.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(B, T, D, generator=gen)

    norm = Normalize(mode="zscore")
    xz = norm(x)

    # Statistics are taken over time (dim 1) separately for each sample/channel.
    mean_t = xz.mean(dim=1)  # (B, D)
    std_t = xz.std(dim=1, unbiased=False)  # (B, D)
    assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-5)
    assert torch.allclose(std_t, torch.ones_like(std_t), atol=1e-4)


def test_normalize_minmax_range_01_2d():
    """Min-max scaling a ``(T, D)`` ramp maps it into [0, 1] with endpoints at 0 and 1."""
    T, D = 7, 2
    # Strictly increasing ramp from -3 to 3, identical in every column, so
    # row 0 holds the minimum and row -1 the maximum of each column.
    x = torch.linspace(-3, 3, steps=T).unsqueeze(1).repeat(1, D)

    norm = Normalize(mode="minmax")
    xm = norm(x)

    # Every output value lies in [0, 1], up to float tolerance.
    assert xm.min().item() >= -1e-6
    assert xm.max().item() <= 1 + 1e-6

    # The extremes 0 and 1 are actually attained...
    assert torch.isclose(xm.min(), torch.tensor(0.0), atol=1e-6)
    assert torch.isclose(xm.max(), torch.tensor(1.0), atol=1e-6)

    # ...and specifically at the ramp's endpoints, in every column.
    assert torch.allclose(xm[0], torch.zeros_like(xm[0]), atol=1e-6)
    assert torch.allclose(xm[-1], torch.ones_like(xm[-1]), atol=1e-6)


def test_normalize_rejects_bad_ndim():
    """Normalize raises ValueError for an input whose ndim is not 2 or 3."""
    x = torch.ones(1, 2, 3, 4)  # 4-D input: outside the supported ranks
    norm = Normalize()

    try:
        norm(x)
    except ValueError:
        return  # expected rejection
    raise AssertionError("Normalize should raise on x.ndim not in {2,3}")