From 2fef28a07fcc9f43455b24188987f882220c05f2 Mon Sep 17 00:00:00 2001 From: zhang Date: Tue, 13 Sep 2022 23:21:25 +0800 Subject: bn train vs. eval --- learn_torch/basics/bn.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 learn_torch/basics/bn.py (limited to 'learn_torch/basics') diff --git a/learn_torch/basics/bn.py b/learn_torch/basics/bn.py new file mode 100644 index 0000000..722497e --- /dev/null +++ b/learn_torch/basics/bn.py @@ -0,0 +1,19 @@ + +import torch +from torch import nn + + +if __name__ == '__main__': + m = nn.BatchNorm1d(3, momentum=None) + x1 = torch.randint(0, 5, (2, 3), dtype=torch.float32) + x2 = torch.randint(0, 5, (2, 3), dtype=torch.float32) + + m(x1) + print(m.running_mean, m.running_var) + m(x2) + print(m.running_mean, m.running_var) + + m.eval() + x3 = torch.randint(0, 5, (1, 3), dtype=torch.float32) + m(x3) + -- cgit v1.2.3