{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `nn.BatchNorm1d`: batch statistics vs. running statistics\n",
    "\n",
    "Summary of what the cells below verify by hand:\n",
    "\n",
    "- **train mode**: the output is normalized with the mean and the *biased* variance of the current batch;\n",
    "- **train mode**: `running_mean` / `running_var` are updated as `(1 - momentum) * old + momentum * batch_stat`, where the batch variance is the *unbiased* one;\n",
    "- **eval mode**: the output is normalized with `running_mean` / `running_var` instead of the batch statistics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:12:41.751525Z",
     "start_time": "2022-09-13T15:12:39.748638Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "from copy import deepcopy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:12:43.503523Z",
     "start_time": "2022-09-13T15:12:43.498924Z"
    }
   },
   "outputs": [],
   "source": [
    "m = nn.BatchNorm1d(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:12:46.134859Z",
     "start_time": "2022-09-13T15:12:46.117802Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.1 m(x1)\n",
    "\n",
    "First forward pass in train mode. Every `m(...)` call in train mode updates the running buffers, so the forward cells must be run exactly once and in order."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:14:00.620025Z",
     "start_time": "2022-09-13T15:14:00.617068Z"
    }
   },
   "outputs": [],
   "source": [
    "# Fixed sample (first drawn with torch.randint(0, 5, (2, 3), dtype=torch.float));\n",
    "# hard-coding it keeps every recorded output reproducible on a fresh kernel.\n",
    "x1 = torch.tensor([[2., 3., 1.],\n",
    "                   [1., 3., 4.]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:14:04.022933Z",
     "start_time": "2022-09-13T15:14:04.016599Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[2., 3., 1.],\n",
       "        [1., 3., 4.]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:15:03.155867Z",
     "start_time": "2022-09-13T15:15:03.149579Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([1.5000, 3.0000, 2.5000]), tensor([0.2500, 0.0000, 2.2500]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1.mean(dim=0), x1.var(dim=0, unbiased=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:16:50.746057Z",
     "start_time": "2022-09-13T15:16:50.740149Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.0000,  0.0000, -1.0000],\n",
       "        [-1.0000,  0.0000,  1.0000]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# biased (unbiased = False)\n",
    "(x1 - x1.mean(dim=0)) / torch.sqrt(x1.var(dim=0, unbiased=False) + m.eps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:15:14.087721Z",
     "start_time": "2022-09-13T15:15:14.080391Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.0000,  0.0000, -1.0000],\n",
       "        [-1.0000,  0.0000,  1.0000]], grad_fn=<NativeBatchNormBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m(x1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:15:40.407708Z",
     "start_time": "2022-09-13T15:15:40.404523Z"
    }
   },
   "outputs": [],
   "source": [
    "# Snapshot the buffers: the next forward pass in train mode overwrites them in place.\n",
    "last_mean, last_var = deepcopy(m.running_mean), deepcopy(m.running_var)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:15:43.888563Z",
     "start_time": "2022-09-13T15:15:43.883410Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([0.1500, 0.3000, 0.2500]), tensor([0.9500, 0.9000, 1.3500]))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "last_mean, last_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:16:11.611791Z",
     "start_time": "2022-09-13T15:16:11.606570Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.1500, 0.3000, 0.2500])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# running_mean starts at zeros\n",
    "(1 - m.momentum) * torch.zeros(3) + m.momentum * x1.mean(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:16:58.637472Z",
     "start_time": "2022-09-13T15:16:58.632528Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.9500, 0.9000, 1.3500])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# unbiased = True; running_var starts at ones\n",
    "(1 - m.momentum) * torch.ones(3) + m.momentum * x1.var(dim=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 m(x2)\n",
    "\n",
    "Second forward pass in train mode: the buffers saved in `last_mean` / `last_var` are now the starting point of the moving average."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:17:11.997946Z",
     "start_time": "2022-09-13T15:17:11.995243Z"
    }
   },
   "outputs": [],
   "source": [
    "# Fixed sample (first drawn with torch.randint(0, 5, (2, 3), dtype=torch.float)).\n",
    "x2 = torch.tensor([[0., 3., 0.],\n",
    "                   [3., 2., 2.]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:17:12.954075Z",
     "start_time": "2022-09-13T15:17:12.949510Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 3., 0.],\n",
       "        [3., 2., 2.]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:17:21.705867Z",
     "start_time": "2022-09-13T15:17:21.701285Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([1.5000, 2.5000, 1.0000]), tensor([4.5000, 0.5000, 2.0000]))"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x2.mean(dim=0), x2.var(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:17:26.797073Z",
     "start_time": "2022-09-13T15:17:26.791778Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.0000,  1.0000, -1.0000],\n",
       "        [ 1.0000, -1.0000,  1.0000]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(x2 - x2.mean(dim=0)) / torch.sqrt(x2.var(dim=0, unbiased=False) + m.eps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:17:29.424105Z",
     "start_time": "2022-09-13T15:17:29.418592Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.0000,  1.0000, -1.0000],\n",
       "        [ 1.0000, -1.0000,  1.0000]], grad_fn=<NativeBatchNormBackward0>)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m(x2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:17:49.273927Z",
     "start_time": "2022-09-13T15:17:49.268708Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([0.2850, 0.5200, 0.3250]), tensor([1.3050, 0.8600, 1.4150]))"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m.running_mean, m.running_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:18:03.331985Z",
     "start_time": "2022-09-13T15:18:03.326575Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.2850, 0.5200, 0.3250])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(1 - m.momentum) * last_mean + m.momentum * x2.mean(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:18:07.494252Z",
     "start_time": "2022-09-13T15:18:07.487036Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1.3050, 0.8600, 1.4150])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(1 - m.momentum) * last_var + m.momentum * x2.var(dim=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. eval mode\n",
    "\n",
    "After `m.eval()` the layer normalizes with `running_mean` / `running_var`; the statistics of the incoming batch are not used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:18:18.641717Z",
     "start_time": "2022-09-13T15:18:18.639009Z"
    }
   },
   "outputs": [],
   "source": [
    "# Fixed sample (first drawn with torch.randint(0, 5, (2, 3), dtype=torch.float)).\n",
    "x3 = torch.tensor([[1., 3., 3.],\n",
    "                   [2., 0., 3.]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:18:20.564723Z",
     "start_time": "2022-09-13T15:18:20.560541Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 3., 3.],\n",
       "        [2., 0., 3.]])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:18:23.072850Z",
     "start_time": "2022-09-13T15:18:23.069084Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:18:32.993600Z",
     "start_time": "2022-09-13T15:18:32.987854Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.6259,  2.6742,  2.2488],\n",
       "        [ 1.5013, -0.5607,  2.2488]], grad_fn=<NativeBatchNormBackward0>)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m(x3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:18:41.861910Z",
     "start_time": "2022-09-13T15:18:41.856128Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.0000,  1.0000,  0.0000],\n",
       "        [ 1.0000, -1.0000,  0.0000]])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Normalizing with the batch statistics does NOT reproduce m(x3) in eval mode ...\n",
    "(x3 - x3.mean(dim=0)) / torch.sqrt(x3.var(dim=0, unbiased=False) + m.eps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-13T15:19:06.793626Z",
     "start_time": "2022-09-13T15:19:06.788040Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.6259,  2.6742,  2.2488],\n",
       "        [ 1.5013, -0.5607,  2.2488]])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# ... normalizing with the running statistics does.\n",
    "(x3 - m.running_mean) / torch.sqrt(m.running_var + m.eps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. sanity checks\n",
    "\n",
    "The visual comparisons above, restated as assertions. `m` is in eval mode here, so calling it again does not touch the running buffers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert not m.training\n",
    "\n",
    "# train-mode buffer updates (unbiased batch variance)\n",
    "assert torch.allclose(last_mean, m.momentum * x1.mean(dim=0))\n",
    "assert torch.allclose(last_var, (1 - m.momentum) * torch.ones(3) + m.momentum * x1.var(dim=0))\n",
    "assert torch.allclose(m.running_mean, (1 - m.momentum) * last_mean + m.momentum * x2.mean(dim=0))\n",
    "assert torch.allclose(m.running_var, (1 - m.momentum) * last_var + m.momentum * x2.var(dim=0))\n",
    "\n",
    "# eval-mode output uses the running buffers\n",
    "with torch.no_grad():\n",
    "    assert torch.allclose(m(x3), (x3 - m.running_mean) / torch.sqrt(m.running_var + m.eps))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}