# dl/tutorials/01_torch_dataset_dataloader.py
# blob 18ad5ad586d328c08c33abfc182c30f60be610d0
import torch
from torch import nn
import torchvision
from torchvision import transforms


def _fashion_mnist_split(train):
    """Build one split of the FashionMNIST dataset.

    Args:
        train: ``True`` for the training split, ``False`` for the test split.

    Returns:
        A ``torchvision.datasets.FashionMNIST`` instance whose images are
        converted to tensors by a fresh ``transforms.ToTensor()``. Data lives
        under the relative path ``'../data'`` (resolved against the current
        working directory) and is downloaded there if it is not present.
    """
    return torchvision.datasets.FashionMNIST(
        root='../data',
        train=train,
        download=True,
        transform=transforms.ToTensor(),
    )


# Training split first, then test split (same order as before).
train_dataset = _fashion_mnist_split(train=True)
test_dataset = _fashion_mnist_split(train=False)

# Mini-batch size, shared by the DataLoader and the last-batch check below so
# the two cannot drift apart.
BATCH_SIZE = 64

train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Print the shapes of the final mini-batch only. With the default
# drop_last=False, that batch may be smaller than BATCH_SIZE (it holds the
# remainder of the dataset), so its shape can differ from the others.
# len(train_dataloader) is the number of batches, so the last index is one
# less than it. Comparing against len(train_dataset) // BATCH_SIZE instead
# would never match when the dataset size is an exact multiple of BATCH_SIZE.
last_batch_index = len(train_dataloader) - 1
for batch_index, (images, labels) in enumerate(train_dataloader):
    if batch_index == last_batch_index:
        print(images.shape, labels.shape)