import torch

try:
    from tqdm import tqdm
except ImportError:  # tqdm only draws a progress bar; fall back to plain iteration
    def tqdm(iterable, *args, **kwargs):
        return iterable


def get_mean_and_std(dataset, batch_size=1, num_workers=2):
    """Compute the per-channel mean and std of an image dataset.

    Every pixel of every image is weighted equally. The statistics are
    accumulated as running sums of x and x**2 in float64, rather than
    averaging per-image statistics. Averaging per-image stds would
    underestimate the dataset std, because it ignores variation between
    images.

    Args:
        dataset: A map-style dataset whose items are ``(input, target)``
            pairs. ``input`` is a channels-first tensor, e.g. ``(C, H, W)``,
            presumably already scaled to [0, 1] via ``ToTensor()``.
            Only ``input`` is used.
        batch_size: Images per loader batch. It does not affect the result,
            but images in one batch must share a shape (default collate).
        num_workers: Worker processes for the DataLoader.

    Returns:
        ``(mean, std)``: float32 tensors of shape ``(C,)``. ``std`` uses the
        unbiased (Bessel-corrected) estimator, matching ``torch.Tensor.std``'s
        default.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("dataset is empty; cannot compute mean and std")

    # Order is irrelevant for a global statistic, so no shuffling.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    total = None     # per-channel sum of pixel values
    total_sq = None  # per-channel sum of squared pixel values
    count = 0        # pixels seen per channel
    print('==> Computing mean and std..')
    for inputs, _targets in tqdm(dataloader):
        # (B, C, ...) -> (C, B * ...): one row of pixels per channel.
        pixels = inputs.transpose(0, 1).reshape(inputs.shape[1], -1).double()
        if total is None:
            total = torch.zeros(pixels.shape[0], dtype=torch.float64)
            total_sq = torch.zeros_like(total)
        total += pixels.sum(dim=1)
        total_sq += (pixels * pixels).sum(dim=1)
        count += pixels.shape[1]

    mean = total / count
    # sum((x - mean)^2) == sum(x^2) - sum(x) * mean; clamp guards tiny
    # negative values caused by floating-point cancellation.
    var = (total_sq - total * mean) / (count - 1)
    std = var.clamp_min(0).sqrt()
    return mean.float(), std.float()