From b7e7dbcfac5bf8907d7d1e06ee6de2597a4c80f0 Mon Sep 17 00:00:00 2001 From: zhang Date: Tue, 6 Sep 2022 07:30:30 +0800 Subject: residual connection --- dl/normalize/mnist_demo.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 dl/normalize/mnist_demo.py (limited to 'dl') diff --git a/dl/normalize/mnist_demo.py b/dl/normalize/mnist_demo.py new file mode 100644 index 0000000..dc9e00c --- /dev/null +++ b/dl/normalize/mnist_demo.py @@ -0,0 +1,33 @@ +from torchvision import transforms +import torchvision +import torch + + +# global, whole training dataset +# x' = (x-mean)/std +# x'*std + mean => x + +# timm.data.IMAGENET_DEFAULT_MEAN: (0.485, 0.456, 0.406) +# timm.data.IMAGENET_DEFAULT_STD: (0.229, 0.224, 0.225) +transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.1307], std=[0.3081]) + ]) + +# MNIST dataset +mnist = torchvision.datasets.MNIST(root='./data/', + train=True, + transform=transform, + download=True) +batch_size = 32 +# Data loader +data_loader = torch.utils.data.DataLoader(dataset=mnist, + batch_size=batch_size, + shuffle=True) + +epochs = 10 + +for epoch in range(epochs): + + for i, (images, t) in enumerate(data_loader): + print(images) \ No newline at end of file -- cgit v1.2.3