summaryrefslogtreecommitdiff
path: root/learn_torch/tutorials
diff options
context:
space:
mode:
authorzhang <zch921005@126.com>2022-09-13 23:21:25 +0800
committerzhang <zch921005@126.com>2022-09-13 23:21:25 +0800
commit2fef28a07fcc9f43455b24188987f882220c05f2 (patch)
tree75b3808c5daec81e25b4d27cee83432855deb8ad /learn_torch/tutorials
parentb7e7dbcfac5bf8907d7d1e06ee6de2597a4c80f0 (diff)
bn train vs. eval
Diffstat (limited to 'learn_torch/tutorials')
-rw-r--r--learn_torch/tutorials/bn_train_eval.ipynb659
1 files changed, 659 insertions, 0 deletions
diff --git a/learn_torch/tutorials/bn_train_eval.ipynb b/learn_torch/tutorials/bn_train_eval.ipynb
new file mode 100644
index 0000000..c450e74
--- /dev/null
+++ b/learn_torch/tutorials/bn_train_eval.ipynb
@@ -0,0 +1,659 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:12:41.751525Z",
+ "start_time": "2022-09-13T15:12:39.748638Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torch import nn\n",
+ "import numpy as np\n",
+ "from copy import deepcopy"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1. Create the `BatchNorm1d` module"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:12:43.503523Z",
+ "start_time": "2022-09-13T15:12:43.498924Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "m = nn.BatchNorm1d(3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:12:46.134859Z",
+ "start_time": "2022-09-13T15:12:46.117802Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "m"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.1 Train mode: first forward pass, `m(x1)`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:14:00.620025Z",
+ "start_time": "2022-09-13T15:14:00.617068Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "x1 = torch.randint(0, 5, (2, 3), dtype=torch.float)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:14:04.022933Z",
+ "start_time": "2022-09-13T15:14:04.016599Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[2., 3., 1.],\n",
+ " [1., 3., 4.]])"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:15:03.155867Z",
+ "start_time": "2022-09-13T15:15:03.149579Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(tensor([1.5000, 3.0000, 2.5000]), tensor([0.2500, 0.0000, 2.2500]))"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x1.mean(dim=0), x1.var(dim=0, unbiased=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:16:50.746057Z",
+ "start_time": "2022-09-13T15:16:50.740149Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[ 1.0000, 0.0000, -1.0000],\n",
+ " [-1.0000, 0.0000, 1.0000]])"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Train-mode normalization uses the biased batch variance (unbiased=False).\n",
+ "(x1 - x1.mean(dim=0))/torch.sqrt(x1.var(dim=0, unbiased=False) + 1e-5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:15:14.087721Z",
+ "start_time": "2022-09-13T15:15:14.080391Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[ 1.0000, 0.0000, -1.0000],\n",
+ " [-1.0000, 0.0000, 1.0000]], grad_fn=<NativeBatchNormBackward0>)"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "m(x1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:15:40.407708Z",
+ "start_time": "2022-09-13T15:15:40.404523Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "last_mean, last_var = deepcopy(m.running_mean), deepcopy(m.running_var)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:15:43.888563Z",
+ "start_time": "2022-09-13T15:15:43.883410Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(tensor([0.1500, 0.3000, 0.2500]), tensor([0.9500, 0.9000, 1.3500]))"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "last_mean, last_var"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:16:11.611791Z",
+ "start_time": "2022-09-13T15:16:11.606570Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([0.1500, 0.3000, 0.2500])"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "(1-0.1)*0 + 0.1*x1.mean(dim=0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:16:58.637472Z",
+ "start_time": "2022-09-13T15:16:58.632528Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([0.9500, 0.9000, 1.3500])"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# The running_var update uses the unbiased batch variance (the default, unbiased=True).\n",
+ "(1-0.1)*torch.ones(3) + 0.1*x1.var(dim=0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.2 Train mode: second forward pass, `m(x2)`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:17:11.997946Z",
+ "start_time": "2022-09-13T15:17:11.995243Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "x2 = torch.randint(0, 5, (2, 3), dtype=torch.float)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:17:12.954075Z",
+ "start_time": "2022-09-13T15:17:12.949510Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[0., 3., 0.],\n",
+ " [3., 2., 2.]])"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:17:21.705867Z",
+ "start_time": "2022-09-13T15:17:21.701285Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(tensor([1.5000, 2.5000, 1.0000]), tensor([4.5000, 0.5000, 2.0000]))"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x2.mean(dim=0), x2.var(dim=0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:17:26.797073Z",
+ "start_time": "2022-09-13T15:17:26.791778Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[-1.0000, 1.0000, -1.0000],\n",
+ " [ 1.0000, -1.0000, 1.0000]])"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "(x2 - x2.mean(dim=0)) / torch.sqrt(x2.var(dim=0, unbiased=False)+1e-05)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:17:29.424105Z",
+ "start_time": "2022-09-13T15:17:29.418592Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[-1.0000, 1.0000, -1.0000],\n",
+ " [ 1.0000, -1.0000, 1.0000]], grad_fn=<NativeBatchNormBackward0>)"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "m(x2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:17:49.273927Z",
+ "start_time": "2022-09-13T15:17:49.268708Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(tensor([0.2850, 0.5200, 0.3250]), tensor([1.3050, 0.8600, 1.4150]))"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "m.running_mean, m.running_var"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:18:03.331985Z",
+ "start_time": "2022-09-13T15:18:03.326575Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([0.2850, 0.5200, 0.3250])"
+ ]
+ },
+ "execution_count": 22,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "(1-0.1)*last_mean + 0.1*x2.mean(dim=0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:18:07.494252Z",
+ "start_time": "2022-09-13T15:18:07.487036Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([1.3050, 0.8600, 1.4150])"
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "(1-0.1)*last_var + 0.1*x2.var(dim=0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 3. Eval mode: normalize with the running statistics"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:18:18.641717Z",
+ "start_time": "2022-09-13T15:18:18.639009Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "x3 = torch.randint(0, 5, (2, 3), dtype=torch.float)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:18:20.564723Z",
+ "start_time": "2022-09-13T15:18:20.560541Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[1., 3., 3.],\n",
+ " [2., 0., 3.]])"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x3"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:18:23.072850Z",
+ "start_time": "2022-09-13T15:18:23.069084Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "m.eval()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:18:32.993600Z",
+ "start_time": "2022-09-13T15:18:32.987854Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[ 0.6259, 2.6742, 2.2488],\n",
+ " [ 1.5013, -0.5607, 2.2488]], grad_fn=<NativeBatchNormBackward0>)"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "m(x3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:18:41.861910Z",
+ "start_time": "2022-09-13T15:18:41.856128Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[-1.0000, 1.0000, 0.0000],\n",
+ " [ 1.0000, -1.0000, 0.0000]])"
+ ]
+ },
+ "execution_count": 28,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "(x3 - x3.mean(dim=0))/torch.sqrt(x3.var(dim=0, unbiased=False) + 1e-5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-13T15:19:06.793626Z",
+ "start_time": "2022-09-13T15:19:06.788040Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[ 0.6259, 2.6742, 2.2488],\n",
+ " [ 1.5013, -0.5607, 2.2488]])"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "(x3 - m.running_mean)/torch.sqrt(m.running_var+1e-5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}