diff options
Diffstat (limited to 'dl/tutorials/utils.py')
| -rw-r--r-- | dl/tutorials/utils.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/dl/tutorials/utils.py b/dl/tutorials/utils.py new file mode 100644 index 0000000..8368e94 --- /dev/null +++ b/dl/tutorials/utils.py @@ -0,0 +1,18 @@ +import torch +from tqdm import tqdm + + +def get_mean_and_std(dataset): + '''Compute the mean and std value of dataset.''' + '''dataset: 0-1 range (ToTensor())''' + dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) + mean = torch.zeros(3) + std = torch.zeros(3) + print('==> Computing mean and std..') + for inputs, targets in tqdm(dataloader): + for i in range(3): + mean[i] += inputs[:, i, :, :].mean() + std[i] += inputs[:, i, :, :].std() + mean.div_(len(dataset)) + std.div_(len(dataset)) + return mean, std |
