summaryrefslogtreecommitdiff
path: root/projs/01-fashion-mnist/00_dataset_dataloader.py
diff options
context:
space:
mode:
authorzhang <zch921005@126.com>2022-07-31 13:07:02 +0800
committerzhang <zch921005@126.com>2022-07-31 13:07:02 +0800
commitfd4e40ae2ae58c06226cc9eb4c2ae9bdcfb677fd (patch)
treecba17079e81ba7ed99f818cfb3c2f30aceacf6f0 /projs/01-fashion-mnist/00_dataset_dataloader.py
parent92d3bc06bad13095df6515111bba45e73f701018 (diff)
wordpiece
Diffstat (limited to 'projs/01-fashion-mnist/00_dataset_dataloader.py')
-rw-r--r--projs/01-fashion-mnist/00_dataset_dataloader.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/projs/01-fashion-mnist/00_dataset_dataloader.py b/projs/01-fashion-mnist/00_dataset_dataloader.py
new file mode 100644
index 0000000..e967821
--- /dev/null
+++ b/projs/01-fashion-mnist/00_dataset_dataloader.py
@@ -0,0 +1,20 @@
+
+from torch.utils.data import Dataset
+from torchvision import datasets
+from torchvision import transforms as T
+import torch
+
+training_dataset = datasets.FashionMNIST(root='./data', train=True, transform=T.ToTensor(), download=True)
+test_dataset = datasets.FashionMNIST(root='./data', train=False, transform=T.ToTensor(), download=True)
+
+
+print(training_dataset.classes)
+
+training_loader = torch.utils.data.DataLoader(training_dataset, batch_size=4, shuffle=True, num_workers=0)
+validation_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=0)
+
+# next(iter(training_loader))
+
+for i, data in enumerate(training_loader):
+ batch_images, batch_labels = data
+ break