summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--dl/normalize/mnist_demo.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/dl/normalize/mnist_demo.py b/dl/normalize/mnist_demo.py
new file mode 100644
index 0000000..dc9e00c
--- /dev/null
+++ b/dl/normalize/mnist_demo.py
@@ -0,0 +1,33 @@
+from torchvision import transforms
+import torchvision
+import torch
+
+
+# global, whole training dataset
+# x' = (x-mean)/std
+# x'*std + mean => x
+
+# timm.data.IMAGENET_DEFAULT_MEAN: (0.485, 0.456, 0.406)
+# timm.data.IMAGENET_DEFAULT_STD: (0.229, 0.224, 0.225)
+transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.1307], std=[0.3081])
+ ])
+
+# MNIST dataset
+mnist = torchvision.datasets.MNIST(root='./data/',
+ train=True,
+ transform=transform,
+ download=True)
+batch_size = 32
+# Data loader
+data_loader = torch.utils.data.DataLoader(dataset=mnist,
+ batch_size=batch_size,
+ shuffle=True)
+
+epochs = 10
+
+for epoch in range(epochs):
+
+ for i, (images, t) in enumerate(data_loader):
+ print(images) \ No newline at end of file