diff options
| author | zhang <zch921005@126.com> | 2022-09-06 07:30:30 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2022-09-06 07:30:30 +0800 |
| commit | b7e7dbcfac5bf8907d7d1e06ee6de2597a4c80f0 (patch) | |
| tree | 8d26e4d425eca174266764d65b1be2e9b5d58039 /dl | |
| parent | a1930ed563d5a3905c9504dbcff2fb00653233da (diff) | |
residual connection
Diffstat (limited to 'dl')
| -rw-r--r-- | dl/normalize/mnist_demo.py | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/dl/normalize/mnist_demo.py b/dl/normalize/mnist_demo.py new file mode 100644 index 0000000..dc9e00c --- /dev/null +++ b/dl/normalize/mnist_demo.py @@ -0,0 +1,33 @@ +from torchvision import transforms +import torchvision +import torch + + +# global, whole training dataset +# x' = (x-mean)/std +# x'*std + mean => x + +# timm.data.IMAGENET_DEFAULT_MEAN: (0.485, 0.456, 0.406) +# timm.data.IMAGENET_DEFAULT_STD: (0.229, 0.224, 0.225) +transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.1307], std=[0.3081]) + ]) + +# MNIST dataset +mnist = torchvision.datasets.MNIST(root='./data/', + train=True, + transform=transform, + download=True) +batch_size = 32 +# Data loader +data_loader = torch.utils.data.DataLoader(dataset=mnist, + batch_size=batch_size, + shuffle=True) + +epochs = 10 + +for epoch in range(epochs): + + for i, (images, t) in enumerate(data_loader): + print(images)
\ No newline at end of file |
