In [1]:
import torch
from torch import nn
import numpy as np
from copy import deepcopy

### 1. Create a `BatchNorm1d` module

In [2]:
# Batch-norm layer over 3 features, constructed with all other arguments at
# their defaults (the repr below shows eps, momentum, affine, track_running_stats).
m = nn.BatchNorm1d(num_features=3)

In [3]:
# Display the module; its repr lists the settings in effect
# (eps=1e-05, momentum=0.1, affine=True, track_running_stats=True).
m

BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

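### 2.1 Training mode: first forward pass `m(x1)`

In training mode the output is normalized with the batch mean and the *biased* variance, while the running statistics are updated with the *unbiased* variance: $\hat\mu \leftarrow (1-\text{momentum})\,\hat\mu + \text{momentum}\cdot\mu_B$ (same form for $\hat\sigma^2$).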

In [4]:
# First input batch: 2 samples x 3 features (matching BatchNorm1d(3)).
# These are the values originally drawn with an unseeded
# torch.randint(0, 5, (2, 3)); writing them out explicitly makes
# Restart & Run All reproduce every recorded output below.
x1 = torch.tensor([[2., 3., 1.],
                   [1., 3., 4.]], dtype=torch.float)

In [5]:
# Show the first batch: rows are samples, columns are the 3 features that
# BatchNorm1d(3) normalizes (the statistics below are taken over dim=0).
x1

tensor([[2., 3., 1.],
        [1., 3., 4.]])

In [7]:
# Per-feature batch statistics: mean and *biased* variance (divide by N),
# the pair that training-mode normalization uses.
torch.mean(x1, dim=0), torch.var(x1, dim=0, unbiased=False)

(tensor([1.5000, 3.0000, 2.5000]), tensor([0.2500, 0.0000, 2.2500]))

In [14]:
# Manual training-mode normalization: batch mean and *biased* variance
# (unbiased = False). eps is read from the module instead of being
# hard-coded, so this check follows m's configuration.
(x1 - x1.mean(dim=0)) / torch.sqrt(x1.var(dim=0, unbiased=False) + m.eps)

tensor([[ 1.0000,  0.0000, -1.0000],
        [-1.0000,  0.0000,  1.0000]])

In [9]:
# Forward pass in training mode (the module has not been switched to eval yet):
# the output matches the manual biased-variance normalization above, and the
# call also updates m.running_mean / m.running_var (inspected next).
m(x1)

tensor([[ 1.0000,  0.0000, -1.0000],
        [-1.0000,  0.0000,  1.0000]], grad_fn=<NativeBatchNormBackward0>)

In [10]:
# Snapshot the running statistics after the first step. .clone() makes
# independent copies, so the next forward pass m(x2), which updates m's
# buffers, leaves this snapshot untouched.
last_mean = m.running_mean.clone()
last_var = m.running_var.clone()

In [11]:
# Running statistics after one training-mode forward pass.
last_mean, last_var

(tensor([0.1500, 0.3000, 0.2500]), tensor([0.9500, 0.9000, 1.3500]))

In [12]:
# First running-mean update:
#   running_mean = (1 - momentum) * running_mean_init + momentum * batch_mean
# running_mean_init is zeros, and momentum is read from the module (0.1 here;
# this formula assumes momentum is not None) instead of hard-coding both.
(1 - m.momentum) * torch.zeros(m.num_features) + m.momentum * x1.mean(dim=0)

tensor([0.1500, 0.3000, 0.2500])

In [15]:
# First running-var update, starting from ones. Unlike the forward-pass
# normalization (biased), the running variance uses the *unbiased* batch
# variance, so unbiased=True is spelled out. momentum comes from the module.
(1 - m.momentum) * torch.ones(m.num_features) + m.momentum * x1.var(dim=0, unbiased=True)

tensor([0.9500, 0.9000, 1.3500])

### 2.2 Training mode: second forward pass `m(x2)`

In [16]:
# Second input batch, same shape as x1 (2 samples x 3 features).
# These are the values originally drawn with an unseeded
# torch.randint(0, 5, (2, 3)); writing them out explicitly keeps the
# recorded outputs reproducible.
x2 = torch.tensor([[0., 3., 0.],
                   [3., 2., 2.]], dtype=torch.float)

In [17]:
# Show the second batch (2 samples x 3 features, like x1).
x2

tensor([[0., 3., 0.],
        [3., 2., 2.]])

In [18]:
# Per-feature mean and *unbiased* variance (divide by N-1), the variance the
# running_var update uses. Section 2.1 showed the biased variance instead;
# the normalization in the next cell still uses the biased one.
torch.mean(x2, dim=0), torch.var(x2, dim=0, unbiased=True)

(tensor([1.5000, 2.5000, 1.0000]), tensor([4.5000, 0.5000, 2.0000]))

In [19]:
# Manual training-mode normalization of x2: batch mean and *biased* variance.
# eps is read from the module instead of being hard-coded.
(x2 - x2.mean(dim=0)) / torch.sqrt(x2.var(dim=0, unbiased=False) + m.eps)

tensor([[-1.0000,  1.0000, -1.0000],
        [ 1.0000, -1.0000,  1.0000]])

In [20]:
# Second training-mode forward pass: normalized with x2's own batch statistics
# (matches the manual computation above), and the running stats are updated again.
m(x2)

tensor([[-1.0000,  1.0000, -1.0000],
        [ 1.0000, -1.0000,  1.0000]], grad_fn=<NativeBatchNormBackward0>)

In [21]:
# Running statistics after two training-mode forward passes.
m.running_mean, m.running_var

(tensor([0.2850, 0.5200, 0.3250]), tensor([1.3050, 0.8600, 1.4150]))

In [22]:
# Second running-mean update: blend the snapshot taken after m(x1) with
# x2's batch mean, using the module's momentum rather than a literal 0.1.
(1 - m.momentum) * last_mean + m.momentum * x2.mean(dim=0)

tensor([0.2850, 0.5200, 0.3250])

In [23]:
# Second running-var update: blend the snapshot with x2's *unbiased* batch
# variance (unbiased=True spelled out), using the module's momentum.
(1 - m.momentum) * last_var + m.momentum * x2.var(dim=0, unbiased=True)

tensor([1.3050, 0.8600, 1.4150])

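### 3. Eval mode: normalize with running statistics

After `m.eval()`, the forward pass uses `running_mean` / `running_var` instead of the batch's own statistics.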

In [24]:
# Third input batch (2 samples x 3 features), used for the eval-mode demo.
# These are the values originally drawn with an unseeded
# torch.randint(0, 5, (2, 3)); writing them out explicitly keeps the
# recorded outputs reproducible.
x3 = torch.tensor([[1., 3., 3.],
                   [2., 0., 3.]], dtype=torch.float)

In [25]:
# Show the third batch, used to contrast eval-mode output with
# batch-statistics normalization.
x3

tensor([[1., 3., 3.],
        [2., 0., 3.]])

In [26]:
# Switch the module to evaluation mode; the call returns the module itself
# (hence the repr below). The cells after this show that the forward pass now
# normalizes with the stored running statistics.
m.eval()

BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [27]:
# Eval-mode forward pass. The result does NOT match batch-statistics
# normalization of x3, but does match normalization with
# m.running_mean / m.running_var (compare the two cells below).
m(x3)

tensor([[ 0.6259,  2.6742,  2.2488],
        [ 1.5013, -0.5607,  2.2488]], grad_fn=<NativeBatchNormBackward0>)

In [28]:
# For contrast: what training mode would produce for x3 (batch mean, biased
# variance). This differs from the eval-mode output above. eps is read from
# the module instead of being hard-coded.
(x3 - x3.mean(dim=0)) / torch.sqrt(x3.var(dim=0, unbiased=False) + m.eps)

tensor([[-1.0000,  1.0000,  0.0000],
        [ 1.0000, -1.0000,  0.0000]])

In [29]:
# Eval-mode normalization with the stored running statistics; this reproduces
# m(x3) above. eps is read from the module instead of being hard-coded.
# NOTE(review): the affine step (m.weight * ... + m.bias) is omitted because no
# optimizer step ever ran here, so weight/bias are presumably still at their
# initial 1/0. Add it back if training is introduced.
(x3 - m.running_mean) / torch.sqrt(m.running_var + m.eps)

tensor([[ 0.6259,  2.6742,  2.2488],
        [ 1.5013, -0.5607,  2.2488]])