# path: root/dl/tutorials/utils.py
# blob: 8368e9496e78594188cae97e691d6c7435805a6a
import torch
from tqdm import tqdm


def get_mean_and_std(dataset, *, num_workers=2, pooled=False):
    '''Compute the per-channel mean and std of an image dataset.

    ``dataset`` must yield ``(image, target)`` pairs whose images are float
    tensors in the 0-1 range (e.g. the output of ``transforms.ToTensor()``).
    A DataLoader batch of them is indexed as (N, C, H, W).

    Args:
        dataset: Map-style dataset accepted by ``torch.utils.data.DataLoader``.
        num_workers: Number of DataLoader worker processes (default 2).
        pooled: If False (the default), ``mean`` and ``std`` are the averages
            of each image's own per-channel mean and (unbiased) std. Averaging
            per-image stds ignores the variation *between* images, so this
            underestimates the std of the dataset as a whole. If True, both
            statistics are computed over all pixels of all images pooled
            together, which gives the true dataset-wide (unbiased) std.

    Returns:
        (mean, std): 1-D tensors of the default float dtype, one entry per
        channel.

    Raises:
        ValueError: If the dataset yields no samples.
    '''
    # Batch size 1 keeps the default mode's per-image statistics exact.
    # The averages do not depend on sample order, so shuffling would only
    # make the floating-point summation order nondeterministic.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=num_workers)
    print('==> Computing mean and std..')
    batches = tqdm(dataloader)
    if pooled:
        return _pooled_mean_and_std(batches)
    return _per_image_mean_and_std(batches)


def _per_image_mean_and_std(batches):
    '''Average each image's per-channel mean and unbiased std.'''
    mean = std = None
    num_images = 0
    for inputs, _ in batches:
        if mean is None:
            # Take the channel count from the data instead of assuming 3.
            mean = torch.zeros(inputs.size(1))
            std = torch.zeros(inputs.size(1))
        # Reduce over batch (size 1), height and width -> one value per channel.
        mean += inputs.mean(dim=(0, 2, 3))
        std += inputs.std(dim=(0, 2, 3))
        num_images += 1
    if num_images == 0:
        raise ValueError('cannot compute mean and std of an empty dataset')
    return mean.div_(num_images), std.div_(num_images)


def _pooled_mean_and_std(batches):
    '''Per-channel mean and unbiased std over every pixel of every image.'''
    total = total_sq = None
    count = 0  # number of values seen per channel
    for inputs, _ in batches:
        # Accumulate in float64: a dataset-wide sum of squares loses too much
        # precision in float32 for the subtraction below.
        x = inputs.double()
        if total is None:
            total = torch.zeros(x.size(1), dtype=torch.float64)
            total_sq = torch.zeros_like(total)
        total += x.sum(dim=(0, 2, 3))
        total_sq += (x * x).sum(dim=(0, 2, 3))
        count += x.numel() // x.size(1)
    if count == 0:
        raise ValueError('cannot compute mean and std of an empty dataset')
    mean = total / count
    # Bessel-corrected, like torch.std's default. The clamp absorbs tiny
    # negative values that rounding can produce when the variance is ~0.
    var = (total_sq - count * mean * mean) / (count - 1)
    std = var.clamp(min=0).sqrt()
    dtype = torch.get_default_dtype()
    return mean.to(dtype), std.to(dtype)