summaryrefslogtreecommitdiff
path: root/learn_torch/basics/bn.py
diff options
context:
space:
mode:
Diffstat (limited to 'learn_torch/basics/bn.py')
-rw-r--r--learn_torch/basics/bn.py19
1 files changed, 19 insertions, 0 deletions
diff --git a/learn_torch/basics/bn.py b/learn_torch/basics/bn.py
new file mode 100644
index 0000000..722497e
--- /dev/null
+++ b/learn_torch/basics/bn.py
@@ -0,0 +1,19 @@
+
+import torch
+from torch import nn
+
+
+if __name__ == '__main__':
+ m = nn.BatchNorm1d(3, momentum=None)
+ x1 = torch.randint(0, 5, (2, 3), dtype=torch.float32)
+ x2 = torch.randint(0, 5, (2, 3), dtype=torch.float32)
+
+ m(x1)
+ print(m.running_mean, m.running_var)
+ m(x2)
+ print(m.running_mean, m.running_var)
+
+ m.eval()
+ x3 = torch.randint(0, 5, (1, 3), dtype=torch.float32)
+ m(x3)
+