diff options
Diffstat (limited to 'ml_core/tutorials')
| -rw-r--r-- | ml_core/tutorials/02_pytorch_mle_for_bernoulli.ipynb | 390 |
1 files changed, 390 insertions, 0 deletions
diff --git a/ml_core/tutorials/02_pytorch_mle_for_bernoulli.ipynb b/ml_core/tutorials/02_pytorch_mle_for_bernoulli.ipynb new file mode 100644 index 0000000..6fc6ff6 --- /dev/null +++ b/ml_core/tutorials/02_pytorch_mle_for_bernoulli.ipynb @@ -0,0 +1,390 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "169a2931", + "metadata": {}, + "source": [ + "## mle for Bernoulli" + ] + }, + { + "cell_type": "markdown", + "id": "c4ec298a", + "metadata": {}, + "source": [ + "- maximum likelihood == minimize Cross Entropy Loss\n", + " - 最大似然估计,估计的是参数(分布的参数),以伯努利分布为例\n", + " - $\\{x_i\\}_{1,\\cdots,N}\\sim B(p)$(采样或者观测)\n", + " - $p$ 的概率为1,$1-p$的概率为0,$x_i\\in\\{0,1\\}$\n", + " - 此时需要基于 $\\{x_i\\}_{1,\\cdots,N}$ 来估计 $p$\n", + " - $\\ell(p)=\\Pi_{i=1}^Np^{x_i}(1-p)^{1-x_i}$(joint probability)\n", + "\n", + "$$\n", + "\\begin{split}\n", + "\\log\\ell(p)&=\\sum_{i=1}^N\\log(p^{x_i}(1-p)^{1-x_i})\\\\\n", + "&=\\sum_{i=1}^Nx_i\\log(p)+(1-x_i)\\log(1-p)\n", + "\\end{split}\n", + "$$\n", + "\n", + "\n", + "- 导数,求极值\n", + " - $\\frac{\\partial \\log\\ell(p)}{\\partial p}=\\frac{\\sum_ix_i}{p}-\\frac{\\sum_i{(1-x_i)}}{1-p}\\overset{\\text{set}}{=}0$\n", + " - $p=\\frac{\\sum_ix_i}{N}$(样本均值,sample mean)\n", + " - 极大值还是极小值,可以求二阶导:\n", + " - $\\dfrac{\\partial^2 \\ell(p)}{\\partial p^2} = \\dfrac{-\\sum_{i=1}^n x_i}{p^2} - \\dfrac{\\sum_{i=1}^n (1-x_i)}{(1-p)^2}$\n", + " - 二阶导小于0,为极大值(参考 $-x^2$);" + ] + }, + { + "cell_type": "markdown", + "id": "e916340d", + "metadata": {}, + "source": [ + "## pytorch 梯度下降" + ] + }, + { + "cell_type": "markdown", + "id": "ce2498ae", + "metadata": {}, + "source": [ + "### sample from Bernoulli distribution" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "eb312f4c", + "metadata": { + "ExecuteTime": { + "end_time": "2023-04-03T14:03:31.009959Z", + "start_time": "2023-04-03T14:03:31.007350Z" + } + }, + "outputs": [], + "source": [ + "from scipy import stats\n", + "import numpy as np\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "6d6480b9", + "metadata": { + "ExecuteTime": { + "end_time": "2023-04-03T14:03:31.962431Z", + "start_time": "2023-04-03T14:03:31.958478Z" + } + }, + "outputs": [], + "source": [ + "p = 0.43\n", + "dist = stats.bernoulli(p)\n", + "xs = dist.rvs(3000)" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "2d656a36", + "metadata": { + "ExecuteTime": { + "end_time": "2023-04-03T14:03:32.994882Z", + "start_time": "2023-04-03T14:03:32.990429Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 1, 1, ..., 0, 1, 0])" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xs" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "1b3ff299", + "metadata": { + "ExecuteTime": { + "end_time": "2023-04-03T14:03:48.706449Z", + "start_time": "2023-04-03T14:03:48.696408Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.42733333333333334" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# sample mean\n", + "np.mean(xs)" + ] + }, + { + "cell_type": "markdown", + "id": "e4382e45", + "metadata": {}, + "source": [ + "### mle by pytorch gradient descent" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "fa2fa633", + "metadata": { + "ExecuteTime": { + "end_time": "2023-04-03T14:04:11.838928Z", + "start_time": "2023-04-03T14:04:11.835809Z" + } + }, + "outputs": [], + "source": [ + "xs_tensor = torch.from_numpy(xs).type(torch.FloatTensor)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "ad7160f4", + "metadata": { + "ExecuteTime": { + "end_time": "2023-04-03T14:04:13.208113Z", + "start_time": "2023-04-03T14:04:13.203274Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0., 1., 1., ..., 0., 1., 0.])" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xs_tensor" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "abcf6e42", + "metadata": { + "ExecuteTime": { + "end_time": "2023-04-03T14:04:26.504174Z", + "start_time": "2023-04-03T14:04:26.499604Z" + } + }, + "outputs": [], + "source": [ + "p_est = torch.rand(1, requires_grad=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "09aba32b", + "metadata": { + "ExecuteTime": { + "end_time": "2023-04-03T14:04:27.855980Z", + "start_time": "2023-04-03T14:04:27.847966Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.2782], requires_grad=True)" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "p_est" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "c005aaf6", + "metadata": { + "ExecuteTime": { + "end_time": "2023-04-03T14:04:38.955960Z", + "start_time": "2023-04-03T14:04:38.953107Z" + } + }, + "outputs": [], + "source": [ + "lr = 2e-5" + ] + }, + { + "cell_type": "markdown", + "id": "24bdde16", + "metadata": {}, + "source": [ + "$$\n", + "\\begin{split}\n", + "\\log\\ell(p)&=\\sum_{i=1}^N\\log(p^{x_i}(1-p)^{1-x_i})\\\\\n", + "&=\\sum_{i=1}^Nx_i\\log(p)+(1-x_i)\\log(1-p)\n", + "\\end{split}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "a773903c", + "metadata": {}, + "source": [ + "$$\n", + "\\frac{\\partial \\log\\ell(p)}{\\partial p}=\\frac{\\sum_ix_i}{p}-\\frac{\\sum_i{(1-x_i)}}{1-p}\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "352d6792", + "metadata": { + "ExecuteTime": { + "end_time": "2023-04-03T14:08:17.425731Z", + "start_time": "2023-04-03T14:08:17.354743Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([0.5192], requires_grad=True)\n", + "p_est:[0.5192194], NLL:2098.42724609375, dL/dp: [1104.2644]\n", + "\t tensor([-1104.2644])\n", + "p_est:[0.4505255], NLL:2050.916748046875, dL/dp: [281.05786]\n", + "\t tensor([-281.0579])\n", + "p_est:[0.4330856], NLL:2047.8486328125, dL/dp: [70.28589]\n", + "\t tensor([-70.2859])\n", + "p_est:[0.4287475], NLL:2047.658935546875, dL/dp: [17.321777]\n", + "\t tensor([-17.3218])\n", + "p_est:[0.42768013], NLL:2047.6468505859375, dL/dp: [4.2504883]\n", + "\t tensor([-4.2505])\n", + "p_est:[0.42741832], NLL:2047.646728515625, dL/dp: [1.0419922]\n", + "\t tensor([-1.0420])\n", + "p_est:[0.42735416], NLL:2047.646240234375, dL/dp: [0.2553711]\n", + "\t tensor([-0.2554])\n", + "p_est:[0.42733842], NLL:2047.646484375, dL/dp: [0.06225586]\n", + "\t tensor([-0.0623])\n", + "p_est:[0.42733458], NLL:2047.6463623046875, dL/dp: [0.01513672]\n", + "\t tensor([-0.0151])\n", + "p_est:[0.42733365], NLL:2047.646484375, dL/dp: [0.00390625]\n", + "\t tensor([-0.0039])\n", + "p_est:[0.4273334], NLL:2047.6463623046875, dL/dp: [0.00097656]\n", + "\t tensor([-0.0010])\n", + "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n", + "\t tensor([-0.0002])\n", + "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n", + "\t tensor([-0.0002])\n", + "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n", + "\t tensor([-0.0002])\n", + "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n", + "\t tensor([-0.0002])\n", + "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n", + "\t tensor([-0.0002])\n", + "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n", + "\t tensor([-0.0002])\n", + "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n", + "\t tensor([-0.0002])\n", + "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n", + "\t tensor([-0.0002])\n", + "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n", + "\t tensor([-0.0002])\n" + ] + } + ], + "source": [ + "p_est = torch.rand(1, requires_grad=True)\n", + "print(p_est)\n", + "lr = 2e-5\n", + "for epoch in range(100):\n", + " # NLL: negative log likelihood\n", + " # minimize NLL\n", + " NLL = -torch.sum(xs_tensor * torch.log(p_est) + (1-xs_tensor)*torch.log(1-p_est))\n", + " NLL.backward()\n", + " \n", + " if epoch % 5 == 0:\n", + " print(f'p_est:{p_est.data.numpy()}, NLL:{NLL.data.numpy()}, dL/dp: {p_est.grad.data.numpy()}')\n", + " print('\\t', torch.sum(xs_tensor)/p_est.data.numpy() - torch.sum(1-xs_tensor)/(1-p_est.data.numpy()))\n", + " \n", + " p_est.data = p_est.data - lr*p_est.grad.data\n", + " p_est.grad.data.zero_()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22b4d334", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} |
