{ "cells": [ { "cell_type": "markdown", "id": "352af752", "metadata": {}, "source": [ "## kl divergence & cross entropy" ] }, { "cell_type": "markdown", "id": "e74d2853", "metadata": {}, "source": [ "### Definitions and computation" ] }, { "cell_type": "markdown", "id": "6c0bdd63", "metadata": {}, "source": [ "- entropy:\n", "$$\n", "H(p)=-\\sum_{x\\in \\mathcal X}p(x)\\log p(x)=\\sum_{x\\in\\mathcal X}p(x)\\log\\frac1{p(x)}\n", "$$\n", "\n", "- cross entropy:\n", "$$\n", "H(p,q)=-\\sum_{x\\in\\mathcal X}p(x)\\log q(x)=\\sum_{x\\in\\mathcal X}p(x)\\log\\frac1{q(x)}\n", "$$\n", "\n", "- kl divergence (relative entropy):\n", "$$\n", "\\begin{align}\n", "D_{\\text{KL}}(p\\parallel q) &=-\\sum_{x\\in\\mathcal X}p(x)\\log\\frac{q(x)}{p(x)}\\\\\n", "&=-\\sum_{x\\in \\mathcal X}\\left(p(x)\\log q(x)-p(x)\\log p(x)\\right)\\\\\n", "&=-\\sum_{x\\in\\mathcal X}p(x)\\log q(x)-\\left(-\\sum_{x\\in\\mathcal X}p(x)\\log p(x)\\right)\\\\\n", "&=H(p,q)-H(p)\n", "\\end{align}\n", "$$" ] }, { "cell_type": "markdown", "id": "58230a01", "metadata": {}, "source": [ "### Properties" ] }, { "cell_type": "markdown", "id": "7a1a9981", "metadata": {}, "source": [ "- $H(p)\\geq 0$\n", "- $H(p,q)\\geq 0$\n", "    - asymmetric: in general $H(p,q)\\neq H(q,p)$\n", "    - $H(p,p)=H(p)$\n", "- $D_{\\text{KL}}(p\\parallel q)\\geq 0$ (Gibbs' inequality)\n", "    - equivalently, $H(p,q)\\geq H(p)$\n", "    - $D_{\\text{KL}}(p\\parallel q) = 0$ if and only if $p = q$\n", "    - asymmetric: in general $D_{\\text{KL}}(p\\parallel q) \\neq D_{\\text{KL}}(q\\parallel p)$" ] }, { "cell_type": "markdown", "id": "a30c0514", "metadata": {}, "source": [ "## pytorch api" ] }, { "cell_type": "markdown", "id": "c6e48572", "metadata": {}, "source": [ "- $p$: the true (target) distribution, i.e. the labels\n", "    - $p(y|x)$: one-hot, a special case of a probability distribution\n", "- $q$: the predicted (input) distribution, i.e. the model's output distribution\n", "    - $q(y|x)$: a dense probability distribution\n", "- $H(p,q)=-\\sum p(x)\\log q(x)$ (reduces to the log loss when $p$ is one-hot)\n", "    - $H(q,p)=-\\sum q(x)\\log p(x)$" ] }, { "cell_type": "markdown", "id": "acf57c28", "metadata": {}, "source": [ "### wiki" ] }, { "cell_type": "code", "execution_count": 49, "id": "bdda1223", 
"metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:32:54.946377Z", "start_time": "2023-03-20T15:32:54.941360Z" } }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "import torch.nn.functional as F" ] }, { "cell_type": "code", "execution_count": 50, "id": "21531b41", "metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:33:10.712094Z", "start_time": "2023-03-20T15:33:10.705862Z" } }, "outputs": [], "source": [ "kl_loss = nn.KLDivLoss(reduction=\"batchmean\")" ] }, { "cell_type": "code", "execution_count": 52, "id": "d67a2c13", "metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:33:39.814970Z", "start_time": "2023-03-20T15:33:39.805253Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 3]) tensor([[-1.0986, -1.0986, -1.0986]])\n", "torch.Size([1, 3])\n" ] } ], "source": [ "input = torch.log(torch.FloatTensor([[1/3, 1/3, 1/3]]))\n", "print(input.shape, input)\n", "target = torch.FloatTensor([[9/25, 12/25, 4/25]])\n", "print(target.shape)" ] }, { "cell_type": "code", "execution_count": 53, "id": "bb8b4638", "metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:33:40.891404Z", "start_time": "2023-03-20T15:33:40.885385Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor(0.0853)" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "kl_loss(input, target)" ] }, { "cell_type": "markdown", "id": "42044238", "metadata": {}, "source": [ "### kl loss vs. 
ce loss" ] }, { "cell_type": "markdown", "id": "72961f24", "metadata": {}, "source": [ "- kl loss:\n", "    - requires `input.shape == target.shape`; `input` holds log-probabilities and `target` holds probabilities\n", "- ce loss:\n", "    - `input` holds raw logits; `target` is either class indices of shape `(N,)` or class probabilities of shape `(N, C)`" ] }, { "cell_type": "code", "execution_count": 54, "id": "ad5a5c32", "metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:34:47.656453Z", "start_time": "2023-03-20T15:34:47.653460Z" } }, "outputs": [], "source": [ "kl_loss = nn.KLDivLoss(reduction=\"batchmean\")\n", "ce_loss = nn.CrossEntropyLoss()" ] }, { "cell_type": "code", "execution_count": 55, "id": "cdf1e537", "metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:34:56.714500Z", "start_time": "2023-03-20T15:34:56.706054Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[ 0.3286, 0.4247, 1.5101, -0.4628, -0.6365],\n", " [-0.1079, -0.6032, -1.1911, -0.1564, 0.9218],\n", " [ 1.5526, -0.2029, 0.7564, -2.0878, -1.1310]], requires_grad=True)" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# raw logits for a batch of 3 samples over 5 classes\n", "input = torch.randn(3, 5, requires_grad=True)\n", "input" ] }, { "cell_type": "code", "execution_count": 57, "id": "4c19e7ba", "metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:35:51.836659Z", "start_time": "2023-03-20T15:35:51.829638Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([0, 2, 0])" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# random class indices in [0, 5), one per sample\n", "target = torch.empty(3, dtype=torch.long).random_(5)\n", "target" ] }, { "cell_type": "code", "execution_count": 58, "id": "fb11bab0", "metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:35:53.779703Z", "start_time": "2023-03-20T15:35:53.771265Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor(1.7296, grad_fn=)" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# cross entropy from logits and class indices\n", "ce_value = ce_loss(input, target)\n", "ce_value" ] }, { "cell_type": "code", "execution_count": 60, "id": "3a8261e2", "metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:36:46.037218Z", "start_time": 
"2023-03-20T15:36:46.031049Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[-1.8236, -1.7275, -0.6421, -2.6150, -2.7887],\n", " [-1.7406, -2.2359, -2.8238, -1.7891, -0.7109],\n", " [-0.5414, -2.2969, -1.3376, -4.1818, -3.2250]],\n", " grad_fn=)" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "input_log_softmax = F.log_softmax(input, dim=1)\n", "input_log_softmax" ] }, { "cell_type": "code", "execution_count": 66, "id": "79ee5e65", "metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:39:48.653273Z", "start_time": "2023-03-20T15:39:48.649184Z" } }, "outputs": [], "source": [ "target_shaped = torch.FloatTensor([[1, 0, 0, 0, 0], \n", " [0, 0, 1, 0, 0], \n", " [1, 0, 0, 0, 0]])" ] }, { "cell_type": "code", "execution_count": 62, "id": "63306655", "metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:38:01.920546Z", "start_time": "2023-03-20T15:38:01.915454Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor(1.7296, grad_fn=)" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "kl_loss(input_log_softmax, target_shaped)" ] }, { "cell_type": "code", "execution_count": 63, "id": "de636da2", "metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:38:53.633383Z", "start_time": "2023-03-20T15:38:53.623851Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor(1.7296, grad_fn=)" ] }, "execution_count": 63, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ce_loss(input, target_shaped)" ] }, { "cell_type": "markdown", "id": "bc5f92ec", "metadata": {}, "source": [ "- $H(p)==0$" ] }, { "cell_type": "markdown", "id": "df51123e", "metadata": {}, "source": [ "### $H(p)\\neq 0$" ] }, { "cell_type": "code", "execution_count": 69, "id": "6fcf9969", "metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:40:59.553869Z", "start_time": "2023-03-20T15:40:59.549352Z" } }, "outputs": [], "source": [ "kl_loss = torch.nn.KLDivLoss(reduction='none')\n", "ce_loss = 
torch.nn.CrossEntropyLoss(reduction='none')" ] }, { "cell_type": "code", "execution_count": 71, "id": "1bb3351c", "metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:41:27.535234Z", "start_time": "2023-03-20T15:41:27.531335Z" } }, "outputs": [], "source": [ "# logits of q(x); softmax turns them into the predicted distribution\n", "input = torch.tensor([[-0.1, 0.2, -0.4, 0.3]])\n", "# logits of p(x); softmax turns them into a soft target distribution\n", "target = torch.tensor([[-0.7, 0.1, -0.1, 0.1]])" ] }, { "cell_type": "code", "execution_count": 73, "id": "4b86a65c", "metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:42:28.584559Z", "start_time": "2023-03-20T15:42:28.574631Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-0.0635, 0.0116, 0.1097, -0.0190]])\n", "tensor(0.0097)\n", "tensor(0.0389)\n" ] } ], "source": [ "# pointwise terms p * (log p - log q) of KL(p || q)\n", "kl_output = kl_loss(F.log_softmax(input, dim=1), torch.softmax(target, dim=1))\n", "print(kl_output)\n", "print(kl_output.mean())  # average over the 4 elements, not the KL value\n", "print(kl_output.sum())  # sum over classes gives KL(p || q)" ] }, { "cell_type": "code", "execution_count": 74, "id": "12c3141f", "metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:42:54.035273Z", "start_time": "2023-03-20T15:42:54.027528Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([1.3832])" ] }, "execution_count": 74, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# H(p, q): cross entropy with soft (probability) targets\n", "ce_loss(input, torch.softmax(target, dim=1))" ] }, { "cell_type": "code", "execution_count": 76, "id": "143110aa", "metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:43:20.831440Z", "start_time": "2023-03-20T15:43:20.824927Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([1.3443])" ] }, "execution_count": 76, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# H(p) == H(p, p): use the logits of p as the prediction as well\n", "ce_loss(target, torch.softmax(target, dim=1))" ] }, { "cell_type": "code", "execution_count": 78, "id": "73ab9d88", "metadata": { "ExecuteTime": { "end_time": "2023-03-20T15:46:07.758263Z", "start_time": "2023-03-20T15:46:07.753794Z" } }, "outputs": [ { "data": { "text/plain": [ 
"0.038899999999999935" ] }, "execution_count": 78, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# H(p, q) - H(p) == kl(p, q)\n", "1.3832 - 1.3443" ] }, { "cell_type": "code", "execution_count": null, "id": "4f0068c3", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 5 }