diff options
| author | zhang <zch921005@126.com> | 2022-09-13 23:21:25 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2022-09-13 23:21:25 +0800 |
| commit | 2fef28a07fcc9f43455b24188987f882220c05f2 (patch) | |
| tree | 75b3808c5daec81e25b4d27cee83432855deb8ad /learn_torch/basics | |
| parent | b7e7dbcfac5bf8907d7d1e06ee6de2597a4c80f0 (diff) | |
bn train vs. eval
Diffstat (limited to 'learn_torch/basics')
| -rw-r--r-- | learn_torch/basics/bn.py | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/learn_torch/basics/bn.py b/learn_torch/basics/bn.py new file mode 100644 index 0000000..722497e --- /dev/null +++ b/learn_torch/basics/bn.py @@ -0,0 +1,19 @@ + +import torch +from torch import nn + + +if __name__ == '__main__': + m = nn.BatchNorm1d(3, momentum=None) + x1 = torch.randint(0, 5, (2, 3), dtype=torch.float32) + x2 = torch.randint(0, 5, (2, 3), dtype=torch.float32) + + m(x1) + print(m.running_mean, m.running_var) + m(x2) + print(m.running_mean, m.running_var) + + m.eval() + x3 = torch.randint(0, 5, (1, 3), dtype=torch.float32) + m(x3) + |
