1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
|
import torch
from files.data_io.transforms.normalization import Normalize
def test_normalize_zscore_2d_zero_mean_unitish_var():
    """Z-score a small (T, D) ramp and check per-channel moments over time.

    Statistics are taken along dim 0 (time): every column of the output
    should have mean ~0 and population standard deviation ~1.
    """
    n_steps, n_channels = 5, 3
    ramp = torch.arange(n_steps * n_channels, dtype=torch.float32)
    ramp = ramp.reshape(n_steps, n_channels)

    normalized = Normalize(mode="zscore")(ramp)

    per_channel_mean = normalized.mean(dim=0)
    assert torch.allclose(
        per_channel_mean, torch.zeros_like(per_channel_mean), atol=1e-5
    )

    # unbiased=False -> population std (divide by T, not T - 1).
    per_channel_std = normalized.std(dim=0, unbiased=False)
    assert torch.allclose(
        per_channel_std, torch.ones_like(per_channel_std), atol=1e-5
    )
def test_normalize_zscore_3d_per_sample():
    """Z-score a batched (B, T, D) tensor and check per-sample channel moments.

    Statistics are taken over the time axis (dim=1), so every (batch, channel)
    pair of the output should have mean ~0 and population std ~1.
    """
    B, T, D = 2, 6, 4
    # A dedicated, seeded generator makes the input reproducible across runs
    # (so any failure can be replayed) without touching the global RNG state
    # that other tests may rely on.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(B, T, D, generator=gen)
    norm = Normalize(mode="zscore")
    xz = norm(x)
    mean_t = xz.mean(dim=1)  # (B, D)
    std_t = xz.std(dim=1, unbiased=False)  # (B, D), population std
    assert torch.allclose(mean_t, torch.zeros_like(mean_t), atol=1e-5)
    # Looser tolerance than the mean check (kept from the original test).
    assert torch.allclose(std_t, torch.ones_like(std_t), atol=1e-4)
def test_normalize_minmax_range_01_2d():
    """Min-max normalize a monotone (T, D) ramp and check it spans [0, 1].

    Every channel holds the same linspace(-3, 3) ramp, so the smallest output
    value should be ~0 and the largest ~1.
    """
    steps, channels = 7, 2
    column = torch.linspace(-3, 3, steps=steps)
    x = column.unsqueeze(1).repeat(1, channels)  # same ramp in every channel

    scaled = Normalize(mode="minmax")(x)
    lo, hi = scaled.min(), scaled.max()

    # Output must stay within [0, 1] up to float tolerance...
    assert lo.item() >= -1e-6
    assert hi.item() <= 1 + 1e-6
    # ...and the ramp's endpoints must land on 0 and 1.
    assert torch.isclose(lo, torch.tensor(0.0), atol=1e-6)
    assert torch.isclose(hi, torch.tensor(1.0), atol=1e-6)
def test_normalize_rejects_bad_ndim():
    """Check that a default-constructed Normalize raises ValueError on 4-D input.

    The failure message states the expected contract: only 2-D and 3-D
    inputs are accepted.
    """
    four_d = torch.ones(1, 2, 3, 4)
    normalizer = Normalize()
    raised = False
    try:
        normalizer(four_d)
    except ValueError:
        raised = True
    # Any exception other than ValueError propagates and fails the test.
    if not raised:
        raise AssertionError("Normalize should raise on x.ndim not in {2,3}")
|