diff options
Diffstat (limited to 'learn_torch/loss')
| -rw-r--r-- | learn_torch/loss/02_kl_divergence_cross_entropy.ipynb | 601 |
1 files changed, 601 insertions, 0 deletions
diff --git a/learn_torch/loss/02_kl_divergence_cross_entropy.ipynb b/learn_torch/loss/02_kl_divergence_cross_entropy.ipynb new file mode 100644 index 0000000..c30c932 --- /dev/null +++ b/learn_torch/loss/02_kl_divergence_cross_entropy.ipynb @@ -0,0 +1,601 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "352af752", + "metadata": {}, + "source": [ + "## kl divergence & cross entropy" + ] + }, + { + "cell_type": "markdown", + "id": "e74d2853", + "metadata": {}, + "source": [ + "### 定义及计算" + ] + }, + { + "cell_type": "markdown", + "id": "6c0bdd63", + "metadata": {}, + "source": [ + "- entropy:\n", + "$$\n", + "H(p)=-\\sum_{x\\in \\mathcal X}p(x)\\log p(x)=\\sum_{x\\in\\mathcal X}p(x)\\log\\frac1{p(x)}\n", + "$$\n", + "\n", + "- cross entropy(交叉熵):\n", + "$$\n", + "H(p,q)=-\\sum_{x\\in\\mathcal X}p(x)\\log q(x)=\\sum_{x\\in\\mathcal X}p(x)\\log\\frac1{q(x)}\n", + "$$\n", + "\n", + "- kl divergence(相对熵):\n", + "$$\n", + "\\begin{align}\n", + "D_\\text{KL}(p\\parallel q) &=-\\sum_{x\\in\\mathcal X}p(x)\\log\\frac{q(x)}{p(x)}\\\\\n", + "&=-\\sum_{x\\in \\mathcal X}\\left(p(x)\\log q(x)-p(x)\\log p(x)\\right)\\\\\n", + "&=-\\sum_{x\\mathcal X}p(x)\\log q(x)-\\left(-\\sum_{x\\mathcal X}p(x)\\log p(x)\\right)\\\\\n", + "&=H(p,q)-H(p)\n", + "\\end{align}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "58230a01", + "metadata": {}, + "source": [ + "### 性质" + ] + }, + { + "cell_type": "markdown", + "id": "7a1a9981", + "metadata": {}, + "source": [ + "- $H(p)\\geq 0$\n", + "- $H(p,q)\\geq 0$\n", + " - 非对称\n", + " - $H(p,p)=H(p)$\n", + "- $D_{\\text{KL}}(p\\parallel q)\\geq 0$(gibbs inequality)\n", + " - $H(p,q)\\geq H(p)$\n", + " - $D_{\\text{KL}} = 0 $ 当 $p == q$ 时;\n", + " - 非对称,$D_{\\text{KL}}(p\\parallel q) \\neq D_{\\text{KL}}(q\\parallel p)$\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "a30c0514", + "metadata": {}, + "source": [ + "## pytorch api" + ] + }, + { + "cell_type": "markdown", + "id": "c6e48572", + "metadata": {}, + "source": [ + "- $p$: true(target) distribution,label值;\n", + " - $p(y|x)$:one hot,一种特殊的概率分布\n", + "- $q$: predict(input) distribution,模型的预测值(分布);\n", + " - $q(y|x)$:dense prob distribution\n", + "- $H(p,q)=-\\sum p(x)\\log q(x)$ (退化成 log loss)\n", + " - $H(q,p)=-\\sum q(x)\\log p(x)$" + ] + }, + { + "cell_type": "markdown", + "id": "acf57c28", + "metadata": {}, + "source": [ + "### wiki" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "bdda1223", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:32:54.946377Z", + "start_time": "2023-03-20T15:32:54.941360Z" + } + }, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn\n", + "import torch.nn.functional as F" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "21531b41", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:33:10.712094Z", + "start_time": "2023-03-20T15:33:10.705862Z" + } + }, + "outputs": [], + "source": [ + "kl_loss = nn.KLDivLoss(reduction=\"batchmean\")" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "d67a2c13", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:33:39.814970Z", + "start_time": "2023-03-20T15:33:39.805253Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 3]) tensor([[-1.0986, -1.0986, -1.0986]])\n", + "torch.Size([1, 3])\n" + ] + } + ], + "source": [ + "input = torch.log(torch.FloatTensor([[1/3, 1/3, 1/3]]))\n", + "print(input.shape, input)\n", + "target = torch.FloatTensor([[9/25, 12/25, 4/25]])\n", + "print(target.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "bb8b4638", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:33:40.891404Z", + "start_time": "2023-03-20T15:33:40.885385Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.0853)" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "kl_loss(input, target)" + ] + }, + { + "cell_type": "markdown", + "id": "42044238", + "metadata": {}, + "source": [ + "### kl loss vs. ce loss" + ] + }, + { + "cell_type": "markdown", + "id": "72961f24", + "metadata": {}, + "source": [ + "- kl loss:\n", + " - `input.shape == target.shape`" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "ad5a5c32", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:34:47.656453Z", + "start_time": "2023-03-20T15:34:47.653460Z" + } + }, + "outputs": [], + "source": [ + "kl_loss = nn.KLDivLoss(reduction=\"batchmean\")\n", + "ce_loss = nn.CrossEntropyLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "cdf1e537", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:34:56.714500Z", + "start_time": "2023-03-20T15:34:56.706054Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 0.3286, 0.4247, 1.5101, -0.4628, -0.6365],\n", + " [-0.1079, -0.6032, -1.1911, -0.1564, 0.9218],\n", + " [ 1.5526, -0.2029, 0.7564, -2.0878, -1.1310]], requires_grad=True)" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input = torch.randn(3, 5, requires_grad=True)\n", + "input" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "4c19e7ba", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:35:51.836659Z", + "start_time": "2023-03-20T15:35:51.829638Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0, 2, 0])" + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "target = torch.empty(3, dtype=torch.long).random_(5)\n", + "target" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "fb11bab0", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:35:53.779703Z", + "start_time": "2023-03-20T15:35:53.771265Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(1.7296, grad_fn=<NllLossBackward0>)" + ] + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ce_value = ce_loss(input, target)\n", + "ce_value" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "3a8261e2", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:36:46.037218Z", + "start_time": "2023-03-20T15:36:46.031049Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[-1.8236, -1.7275, -0.6421, -2.6150, -2.7887],\n", + " [-1.7406, -2.2359, -2.8238, -1.7891, -0.7109],\n", + " [-0.5414, -2.2969, -1.3376, -4.1818, -3.2250]],\n", + " grad_fn=<LogSoftmaxBackward0>)" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_log_softmax = F.log_softmax(input, dim=1)\n", + "input_log_softmax" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "79ee5e65", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:39:48.653273Z", + "start_time": "2023-03-20T15:39:48.649184Z" + } + }, + "outputs": [], + "source": [ + "target_shaped = torch.FloatTensor([[1, 0, 0, 0, 0], \n", + " [0, 0, 1, 0, 0], \n", + " [1, 0, 0, 0, 0]])" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "63306655", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:38:01.920546Z", + "start_time": "2023-03-20T15:38:01.915454Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(1.7296, grad_fn=<DivBackward0>)" + ] + }, + "execution_count": 62, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "kl_loss(input_log_softmax, target_shaped)" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "id": "de636da2", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:38:53.633383Z", + "start_time": "2023-03-20T15:38:53.623851Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(1.7296, grad_fn=<DivBackward1>)" + ] + }, + "execution_count": 63, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ce_loss(input, target_shaped)" + ] + }, + { + "cell_type": "markdown", + "id": "bc5f92ec", + "metadata": {}, + "source": [ + "- $H(p)==0$" + ] + }, + { + "cell_type": "markdown", + "id": "df51123e", + "metadata": {}, + "source": [ + "### $H(p)\\neq 0$" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "id": "6fcf9969", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:40:59.553869Z", + "start_time": "2023-03-20T15:40:59.549352Z" + } + }, + "outputs": [], + "source": [ + "kl_loss = torch.nn.KLDivLoss(reduction='none')\n", + "ce_loss = torch.nn.CrossEntropyLoss(reduction='none')" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "id": "1bb3351c", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:41:27.535234Z", + "start_time": "2023-03-20T15:41:27.531335Z" + } + }, + "outputs": [], + "source": [ + "# q(x)\n", + "input = torch.Tensor([[-0.1, 0.2, -0.4, 0.3]])\n", + "# p(x)\n", + "target = torch.Tensor([[-0.7, 0.1, -0.1, 0.1]])" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "id": "4b86a65c", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:42:28.584559Z", + "start_time": "2023-03-20T15:42:28.574631Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[-0.0635, 0.0116, 0.1097, -0.0190]])\n", + "tensor(0.0097)\n", + "tensor(0.0389)\n" + ] + } + ], + "source": [ + "# kl(p, q)\n", + "kl_output = kl_loss(F.log_softmax(input, dim=1), torch.softmax(target, dim=1))\n", + "print(kl_output)\n", + "print(kl_output.mean())\n", + "print(kl_output.sum())" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "id": "12c3141f", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:42:54.035273Z", + "start_time": "2023-03-20T15:42:54.027528Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([1.3832])" + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# H(p, q)\n", + "ce_loss(input, torch.softmax(target, dim=1))" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "id": "143110aa", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:43:20.831440Z", + "start_time": "2023-03-20T15:43:20.824927Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([1.3443])" + ] + }, + "execution_count": 76, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# H(p) == H(p, p)\n", + "ce_loss(target, torch.softmax(target, dim=1))" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "id": "73ab9d88", + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-20T15:46:07.758263Z", + "start_time": "2023-03-20T15:46:07.753794Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.038899999999999935" + ] + }, + "execution_count": 78, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# H(p, q) - H(p) == kl(p, q)\n", + "1.3832 - 1.3443" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f0068c3", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} |
