blob: e9678211ad7817cf682e4e8ad2576e4f10ba0548 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision import transforms as T
import torch
training_dataset = datasets.FashionMNIST(root='./data', train=True, transform=T.ToTensor(), download=True)
test_dataset = datasets.FashionMNIST(root='./data', train=False, transform=T.ToTensor(), download=True)
print(training_dataset.classes)
training_loader = torch.utils.data.DataLoader(training_dataset, batch_size=4, shuffle=True, num_workers=0)
validation_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=0)
# next(iter(training_loader))
for i, data in enumerate(training_loader):
batch_images, batch_labels = data
break
|