blob: 8368e9496e78594188cae97e691d6c7435805a6a (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
import torch
from tqdm import tqdm
def get_mean_and_std(dataset, num_workers=2):
    """Compute per-channel mean and std statistics of a dataset.

    The dataset is iterated one sample at a time.  For every sample the
    mean and the (unbiased) standard deviation of each channel are
    computed, and these per-sample values are then averaged over all
    samples.

    NOTE: the returned ``std`` is therefore the *average of per-sample
    standard deviations*, not the standard deviation of all values pooled
    across the dataset; the two generally differ.  This matches the
    historical behaviour of this function.

    Args:
        dataset: Dataset yielding ``(input, target)`` pairs, where
            ``input`` is a tensor whose first dimension is the channel
            dimension.  Inputs are expected to be in the 0-1 range
            (e.g. produced by ``ToTensor()``).
        num_workers: Number of worker processes used by the DataLoader.

    Returns:
        A ``(mean, std)`` tuple of 1-D float32 tensors with one entry per
        channel.

    Raises:
        ValueError: If the dataset yields no samples.
    """
    # batch_size must stay 1: the statistics are defined per sample.
    # Shuffling is unnecessary for a sum and would only make the
    # floating-point result depend on the visiting order.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=num_workers
    )
    mean = None
    std = None
    num_samples = 0
    print('==> Computing mean and std..')
    for inputs, _ in tqdm(dataloader):
        # (1, C, ...) -> (C, N): one row of values per channel.  Work in
        # float64 so the running sums stay accurate for large datasets.
        per_channel = inputs.transpose(0, 1).flatten(1).double()
        if mean is None:
            # Channel count is taken from the data instead of assuming 3.
            mean = per_channel.new_zeros(per_channel.size(0))
            std = per_channel.new_zeros(per_channel.size(0))
        mean += per_channel.mean(dim=1)
        std += per_channel.std(dim=1)
        num_samples += 1
    if num_samples == 0:
        raise ValueError('cannot compute mean and std of an empty dataset')
    mean.div_(num_samples)
    std.div_(num_samples)
    return mean.float(), std.float()
|