diff options
Diffstat (limited to 'experiments/cifar_resmlp.py')
| -rw-r--r-- | experiments/cifar_resmlp.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py index 05a355d..435b484 100644 --- a/experiments/cifar_resmlp.py +++ b/experiments/cifar_resmlp.py @@ -47,6 +47,21 @@ def get_data(dataset='cifar10', batch_size=128): testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) input_dim = 32 * 32 * 3 num_classes = 10 + elif dataset == 'cifar100': + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), + ]) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), + ]) + trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train) + testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test) + input_dim = 32 * 32 * 3 + num_classes = 100 elif dataset == 'fashionmnist': transform_train = transforms.Compose([ transforms.RandomCrop(28, padding=2), |
