# Source: root/dl/bn/np_demo.py (blob 9cd6296702937c0b4fcdd7b41ca2cd46c0b274fe)

import numpy as np
import matplotlib.pyplot as plt


# Small toy tensor, presumably for a batch-normalisation demo (path dl/bn).
# _PLANES[i][c] holds the 3x3 values that end up in batch[i, :, :, c].
# Reading axis 0 as "sample" and the last axis as "channel" is an
# assumption based on how the planes are laid out -- confirm with the author.
_PLANES = (
    (  # batch[0, :, :, 0..2]
        [[2.4, 0.3, 5.6],
         [7.2, 0.0, 2.1],
         [1.9, 7.4, 4.1]],
        [[1.0, 2.1, 3.8],
         [9.6, 2.0, 4.4],
         [9.3, 3.9, 1.4]],
        [[2.6, 8.1, 9.0],
         [2.2, 6.6, 0.2],
         [6.3, 6.0, 0.1]],
    ),
    (  # batch[1, :, :, 0..2]
        [[5.1, 9.0, 6.9],
         [9.7, 2.9, 2.4],
         [6.2, 0.7, 4.3]],
        [[3.3, 7.9, 4.8],
         [3.7, 2.0, 9.4],
         [4.9, 2.1, 7.8]],
        [[2.8, 5.4, 0.0],
         [6.4, 1.8, 6.8],
         [3.7, 5.6, 5.6]],
    ),
)

# Each group's three planes are stacked along a new trailing axis, giving
# (3, 3, 3). The two groups are then stacked along a new leading axis.
# Result: a fresh C-contiguous float64 array of shape (2, 3, 3, 3).
batch = np.stack([np.stack(planes, axis=-1) for planes in _PLANES])

# Mean and (population) variance over all 18 values whose last index is 0.
_first_slice = batch[..., 0]
print(_first_slice.mean())
print(_first_slice.var())