blob: 18ad5ad586d328c08c33abfc182c30f60be610d0 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
import torch
from torch import nn
import torchvision
from torchvision import transforms
# Batch size shared by the DataLoader and the "last batch" check below, so the
# two cannot drift apart.
BATCH_SIZE = 64

# FashionMNIST train/test splits, stored under ../data (downloaded if missing)
# and converted to tensors with ToTensor().
train_dataset = torchvision.datasets.FashionMNIST(
    root='../data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_dataset = torchvision.datasets.FashionMNIST(
    root='../data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True
)

# Print the shapes of the final batch, which is smaller than BATCH_SIZE when the
# dataset size is not a multiple of it. len(train_dataloader) is the number of
# batches (the partial batch is kept because drop_last defaults to False), so the
# last index is len - 1. The old check, len(train_dataset) // 64, missed the last
# batch whenever the dataset size was an exact multiple of the batch size.
last_batch_index = len(train_dataloader) - 1
for batch_index, (images, labels) in enumerate(train_dataloader):
    if batch_index == last_batch_index:
        print(images.shape, labels.shape)
|