summaryrefslogtreecommitdiff
path: root/experiments/cifar_resmlp.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/cifar_resmlp.py')
-rw-r--r--experiments/cifar_resmlp.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py
index 05a355d..435b484 100644
--- a/experiments/cifar_resmlp.py
+++ b/experiments/cifar_resmlp.py
@@ -47,6 +47,21 @@ def get_data(dataset='cifar10', batch_size=128):
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
input_dim = 32 * 32 * 3
num_classes = 10
+ elif dataset == 'cifar100':
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+ ])
+ trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
+ input_dim = 32 * 32 * 3
+ num_classes = 100
elif dataset == 'fashionmnist':
transform_train = transforms.Compose([
transforms.RandomCrop(28, padding=2),