diff options
Diffstat (limited to 'dl')
| -rw-r--r-- | dl/normalize/mnist_demo.py | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/dl/normalize/mnist_demo.py b/dl/normalize/mnist_demo.py new file mode 100644 index 0000000..dc9e00c --- /dev/null +++ b/dl/normalize/mnist_demo.py @@ -0,0 +1,33 @@ +from torchvision import transforms +import torchvision +import torch + + +# global, whole training dataset +# x' = (x-mean)/std +# x'*std + mean => x + +# timm.data.IMAGENET_DEFAULT_MEAN: (0.485, 0.456, 0.406) +# timm.data.IMAGENET_DEFAULT_STD: (0.229, 0.224, 0.225) +transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.1307], std=[0.3081]) + ]) + +# MNIST dataset +mnist = torchvision.datasets.MNIST(root='./data/', + train=True, + transform=transform, + download=True) +batch_size = 32 +# Data loader +data_loader = torch.utils.data.DataLoader(dataset=mnist, + batch_size=batch_size, + shuffle=True) + +epochs = 10 + +for epoch in range(epochs): + + for i, (images, t) in enumerate(data_loader): + print(images)
\ No newline at end of file |
