diff options
| author | zhang <zch921005@126.com> | 2023-02-22 00:22:36 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2023-02-22 00:22:36 +0800 |
| commit | 12ac2c65cb05b15461592333e338feaf98cbe7cb (patch) | |
| tree | 2245481b2b85dff6643d207d05a1cc32ec572047 /dl/tutorials/utils.py | |
| parent | 37506402a98eba9bf9d06760a1010fa17adb39e4 (diff) | |
vgg
Diffstat (limited to 'dl/tutorials/utils.py')
| -rw-r--r-- | dl/tutorials/utils.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/dl/tutorials/utils.py b/dl/tutorials/utils.py new file mode 100644 index 0000000..8368e94 --- /dev/null +++ b/dl/tutorials/utils.py @@ -0,0 +1,18 @@ +import torch +from tqdm import tqdm + + +def get_mean_and_std(dataset): + '''Compute the mean and std value of dataset.''' + '''dataset: 0-1 range (ToTensor())''' + dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) + mean = torch.zeros(3) + std = torch.zeros(3) + print('==> Computing mean and std..') + for inputs, targets in tqdm(dataloader): + for i in range(3): + mean[i] += inputs[:, i, :, :].mean() + std[i] += inputs[:, i, :, :].std() + mean.div_(len(dataset)) + std.div_(len(dataset)) + return mean, std |
