summaryrefslogtreecommitdiff
path: root/files/tests/test_transforms_normalize.py
blob: cc5a3d4c0864bc2488ff1300e669c8be997460f5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from files.data_io.transforms.normalization import Normalize


def test_normalize_zscore_2d_zero_mean_unitish_var():
    """Z-score a 2-D (T, D) tensor and check per-channel statistics over time.

    After normalization every column (channel) should have mean ~0 and
    population standard deviation ~1 along the time axis (dim 0).
    """
    n_steps, n_chan = 5, 3
    # Deterministic toy input: consecutive integers laid out as (T, D).
    inp = torch.arange(n_steps * n_chan, dtype=torch.float32).view(n_steps, n_chan)
    out = Normalize(mode="zscore")(inp)

    # Each channel should be centred at zero across time.
    per_chan_mean = out.mean(dim=0)
    assert torch.allclose(per_chan_mean, torch.zeros_like(per_chan_mean), atol=1e-5)

    # ...and have unit spread; unbiased=False divides by N (population std).
    per_chan_std = out.std(dim=0, unbiased=False)
    assert torch.allclose(per_chan_std, torch.ones_like(per_chan_std), atol=1e-5)


def test_normalize_zscore_3d_per_sample():
    """Z-score a 3-D (B, T, D) batch and check statistics per sample and channel.

    For every (sample, channel) pair the output should have mean ~0 and
    population standard deviation ~1 along the time axis (dim 1).
    """
    # (B, T, D)
    B, T, D = 2, 6, 4
    # Locally seeded generator: makes the test reproducible without mutating
    # the global RNG state that other tests may depend on.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(B, T, D, generator=gen)
    norm = Normalize(mode="zscore")
    xz = norm(x)
    mean_t = xz.mean(dim=1)  # (B, D)
    std_t = xz.std(dim=1, unbiased=False)  # (B, D)
    assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-5)
    # NOTE(review): looser tolerance than the mean check, presumably to absorb
    # an epsilon in Normalize's denominator — confirm against the implementation.
    assert torch.allclose(std_t, torch.ones_like(std_t), atol=1e-4)


def test_normalize_minmax_range_01_2d():
    """Min-max scale a 2-D (T, D) ramp and check it spans exactly [0, 1]."""
    T, D = 7, 2
    # Every channel is the same increasing ramp from -3 to 3 over time, so the
    # first time step holds the minimum and the last holds the maximum.
    x = torch.linspace(-3, 3, steps=T).unsqueeze(1).repeat(1, D)
    norm = Normalize(mode="minmax")
    xm = norm(x)
    # All outputs stay inside [0, 1] (small slack for float rounding).
    assert xm.min().item() >= -1e-6
    assert xm.max().item() <= 1 + 1e-6
    # Check the ramp's endpoints map to 0 and 1 in every channel, not merely
    # that some element somewhere equals 0 or 1.
    first, last = xm[0], xm[-1]
    assert torch.allclose(first, torch.zeros_like(first), atol=1e-6)
    assert torch.allclose(last, torch.ones_like(last), atol=1e-6)


def test_normalize_rejects_bad_ndim():
    """A 4-D input must be rejected with ValueError (only 2-D and 3-D are accepted)."""
    four_d = torch.ones(1, 2, 3, 4)
    normalizer = Normalize()
    raised = False
    try:
        normalizer(four_d)
    except ValueError:
        raised = True
    # Any other exception type propagates and fails the test on its own.
    if not raised:
        raise AssertionError("Normalize should raise on x.ndim not in {2,3}")