summaryrefslogtreecommitdiff
path: root/dl/tutorials/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'dl/tutorials/utils.py')
-rw-r--r--dl/tutorials/utils.py18
1 files changed, 18 insertions, 0 deletions
diff --git a/dl/tutorials/utils.py b/dl/tutorials/utils.py
new file mode 100644
index 0000000..8368e94
--- /dev/null
+++ b/dl/tutorials/utils.py
@@ -0,0 +1,18 @@
+import torch
+from tqdm import tqdm
+
+
+def get_mean_and_std(dataset):
+ '''Compute the mean and std value of dataset.'''
+ '''dataset: 0-1 range (ToTensor())'''
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
+ mean = torch.zeros(3)
+ std = torch.zeros(3)
+ print('==> Computing mean and std..')
+ for inputs, targets in tqdm(dataloader):
+ for i in range(3):
+ mean[i] += inputs[:, i, :, :].mean()
+ std[i] += inputs[:, i, :, :].std()
+ mean.div_(len(dataset))
+ std.div_(len(dataset))
+ return mean, std