summaryrefslogtreecommitdiff
path: root/ml_core/tutorials
diff options
context:
space:
mode:
authorchzhang <zch921005@126.com>2023-04-03 22:14:54 +0800
committerchzhang <zch921005@126.com>2023-04-03 22:14:54 +0800
commit2d07469f62b5bc12528d7b8dce6cf7fcb2459325 (patch)
tree72a205f406401dff03e6a82f0270eb756fe405eb /ml_core/tutorials
parent2c5dc53a4c2d64866a3b567cf971a9e9c335923f (diff)
pytorch mle bernoulli
Diffstat (limited to 'ml_core/tutorials')
-rw-r--r--ml_core/tutorials/02_pytorch_mle_for_bernoulli.ipynb390
1 files changed, 390 insertions, 0 deletions
diff --git a/ml_core/tutorials/02_pytorch_mle_for_bernoulli.ipynb b/ml_core/tutorials/02_pytorch_mle_for_bernoulli.ipynb
new file mode 100644
index 0000000..6fc6ff6
--- /dev/null
+++ b/ml_core/tutorials/02_pytorch_mle_for_bernoulli.ipynb
@@ -0,0 +1,390 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "169a2931",
+ "metadata": {},
+ "source": [
+ "## mle for Bernoulli"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c4ec298a",
+ "metadata": {},
+ "source": [
+ "- maximum likelihood == minimize Cross Entropy Loss\n",
+ " - 最大似然估计,估计的是参数(分布的参数),以伯努利分布为例\n",
+ " - $\\{x_i\\}_{1,\\cdots,N}\\sim B(p)$(采样或者观测)\n",
+ " - $p$ 的概率为1,$1-p$的概率为0,$x_i\\in\\{0,1\\}$\n",
+ " - 此时需要基于 $\\{x_i\\}_{1,\\cdots,N}$ 来估计 $p$\n",
+ " - $\\ell(p)=\\Pi_{i=1}^Np^{x_i}(1-p)^{1-x_i}$(joint probability)\n",
+ "\n",
+ "$$\n",
+ "\\begin{split}\n",
+ "\\log\\ell(p)&=\\sum_{i=1}^N\\log(p^{x_i}(1-p)^{1-x_i})\\\\\n",
+ "&=\\sum_{i=1}^Nx_i\\log(p)+(1-x_i)\\log(1-p)\n",
+ "\\end{split}\n",
+ "$$\n",
+ "\n",
+ "\n",
+ "- 导数,求极值\n",
+ " - $\\frac{\\partial \\log\\ell(p)}{\\partial p}=\\frac{\\sum_ix_i}{p}-\\frac{\\sum_i{(1-x_i)}}{1-p}\\overset{\\text{set}}{=}0$\n",
+ " - $p=\\frac{\\sum_ix_i}{N}$(样本均值,sample mean)\n",
+ " - 极大值还是极小值,可以求二阶导:\n",
+ " - $\\dfrac{\\partial^2 \\ell(p)}{\\partial p^2} = \\dfrac{-\\sum_{i=1}^n x_i}{p^2} - \\dfrac{\\sum_{i=1}^n (1-x_i)}{(1-p)^2}$\n",
+ " - 二阶导小于0,为极大值(参考 $-x^2$);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e916340d",
+ "metadata": {},
+ "source": [
+ "## pytorch 梯度下降"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ce2498ae",
+ "metadata": {},
+ "source": [
+ "### sample from Bernoulli distribution"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "id": "eb312f4c",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-04-03T14:03:31.009959Z",
+ "start_time": "2023-04-03T14:03:31.007350Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from scipy import stats\n",
+ "import numpy as np\n",
+ "import torch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "id": "6d6480b9",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-04-03T14:03:31.962431Z",
+ "start_time": "2023-04-03T14:03:31.958478Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "p = 0.43\n",
+ "dist = stats.bernoulli(p)\n",
+ "xs = dist.rvs(3000)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "id": "2d656a36",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-04-03T14:03:32.994882Z",
+ "start_time": "2023-04-03T14:03:32.990429Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([0, 1, 1, ..., 0, 1, 0])"
+ ]
+ },
+ "execution_count": 43,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "xs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "id": "1b3ff299",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-04-03T14:03:48.706449Z",
+ "start_time": "2023-04-03T14:03:48.696408Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0.42733333333333334"
+ ]
+ },
+ "execution_count": 44,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# sample mean\n",
+ "np.mean(xs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e4382e45",
+ "metadata": {},
+ "source": [
+ "### mle by pytorch gradient descent"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "id": "fa2fa633",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-04-03T14:04:11.838928Z",
+ "start_time": "2023-04-03T14:04:11.835809Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "xs_tensor = torch.from_numpy(xs).type(torch.FloatTensor)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "id": "ad7160f4",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-04-03T14:04:13.208113Z",
+ "start_time": "2023-04-03T14:04:13.203274Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([0., 1., 1., ..., 0., 1., 0.])"
+ ]
+ },
+ "execution_count": 46,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "xs_tensor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "id": "abcf6e42",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-04-03T14:04:26.504174Z",
+ "start_time": "2023-04-03T14:04:26.499604Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "p_est = torch.rand(1, requires_grad=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "id": "09aba32b",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-04-03T14:04:27.855980Z",
+ "start_time": "2023-04-03T14:04:27.847966Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([0.2782], requires_grad=True)"
+ ]
+ },
+ "execution_count": 48,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "p_est"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "id": "c005aaf6",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-04-03T14:04:38.955960Z",
+ "start_time": "2023-04-03T14:04:38.953107Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "lr = 2e-5"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "24bdde16",
+ "metadata": {},
+ "source": [
+ "$$\n",
+ "\\begin{split}\n",
+ "\\log\\ell(p)&=\\sum_{i=1}^N\\log(p^{x_i}(1-p)^{1-x_i})\\\\\n",
+ "&=\\sum_{i=1}^Nx_i\\log(p)+(1-x_i)\\log(1-p)\n",
+ "\\end{split}\n",
+ "$$"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a773903c",
+ "metadata": {},
+ "source": [
+ "$$\n",
+ "\\frac{\\partial \\log\\ell(p)}{\\partial p}=\\frac{\\sum_ix_i}{p}-\\frac{\\sum_i{(1-x_i)}}{1-p}\n",
+ "$$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "id": "352d6792",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-04-03T14:08:17.425731Z",
+ "start_time": "2023-04-03T14:08:17.354743Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tensor([0.5192], requires_grad=True)\n",
+ "p_est:[0.5192194], NLL:2098.42724609375, dL/dp: [1104.2644]\n",
+ "\t tensor([-1104.2644])\n",
+ "p_est:[0.4505255], NLL:2050.916748046875, dL/dp: [281.05786]\n",
+ "\t tensor([-281.0579])\n",
+ "p_est:[0.4330856], NLL:2047.8486328125, dL/dp: [70.28589]\n",
+ "\t tensor([-70.2859])\n",
+ "p_est:[0.4287475], NLL:2047.658935546875, dL/dp: [17.321777]\n",
+ "\t tensor([-17.3218])\n",
+ "p_est:[0.42768013], NLL:2047.6468505859375, dL/dp: [4.2504883]\n",
+ "\t tensor([-4.2505])\n",
+ "p_est:[0.42741832], NLL:2047.646728515625, dL/dp: [1.0419922]\n",
+ "\t tensor([-1.0420])\n",
+ "p_est:[0.42735416], NLL:2047.646240234375, dL/dp: [0.2553711]\n",
+ "\t tensor([-0.2554])\n",
+ "p_est:[0.42733842], NLL:2047.646484375, dL/dp: [0.06225586]\n",
+ "\t tensor([-0.0623])\n",
+ "p_est:[0.42733458], NLL:2047.6463623046875, dL/dp: [0.01513672]\n",
+ "\t tensor([-0.0151])\n",
+ "p_est:[0.42733365], NLL:2047.646484375, dL/dp: [0.00390625]\n",
+ "\t tensor([-0.0039])\n",
+ "p_est:[0.4273334], NLL:2047.6463623046875, dL/dp: [0.00097656]\n",
+ "\t tensor([-0.0010])\n",
+ "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n",
+ "\t tensor([-0.0002])\n",
+ "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n",
+ "\t tensor([-0.0002])\n",
+ "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n",
+ "\t tensor([-0.0002])\n",
+ "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n",
+ "\t tensor([-0.0002])\n",
+ "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n",
+ "\t tensor([-0.0002])\n",
+ "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n",
+ "\t tensor([-0.0002])\n",
+ "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n",
+ "\t tensor([-0.0002])\n",
+ "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n",
+ "\t tensor([-0.0002])\n",
+ "p_est:[0.42733338], NLL:2047.646240234375, dL/dp: [0.00024414]\n",
+ "\t tensor([-0.0002])\n"
+ ]
+ }
+ ],
+ "source": [
+ "p_est = torch.rand(1, requires_grad=True)\n",
+ "print(p_est)\n",
+ "lr = 2e-5\n",
+ "for epoch in range(100):\n",
+ " # NLL: negative log likelihood\n",
+ " # minimize NLL\n",
+ " NLL = -torch.sum(xs_tensor * torch.log(p_est) + (1-xs_tensor)*torch.log(1-p_est))\n",
+ " NLL.backward()\n",
+ " \n",
+ " if epoch % 5 == 0:\n",
+ " print(f'p_est:{p_est.data.numpy()}, NLL:{NLL.data.numpy()}, dL/dp: {p_est.grad.data.numpy()}')\n",
+ " print('\\t', torch.sum(xs_tensor)/p_est.data.numpy() - torch.sum(1-xs_tensor)/(1-p_est.data.numpy()))\n",
+ " \n",
+ " p_est.data = p_est.data - lr*p_est.grad.data\n",
+ " p_est.grad.data.zero_()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "22b4d334",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.13"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": true
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}