summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--learn_torch/loss/02_kl_divergence_cross_entropy.ipynb601
1 files changed, 601 insertions, 0 deletions
diff --git a/learn_torch/loss/02_kl_divergence_cross_entropy.ipynb b/learn_torch/loss/02_kl_divergence_cross_entropy.ipynb
new file mode 100644
index 0000000..c30c932
--- /dev/null
+++ b/learn_torch/loss/02_kl_divergence_cross_entropy.ipynb
@@ -0,0 +1,601 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "352af752",
+ "metadata": {},
+ "source": [
+ "## kl divergence & cross entropy"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e74d2853",
+ "metadata": {},
+ "source": [
+ "### 定义及计算"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6c0bdd63",
+ "metadata": {},
+ "source": [
+ "- entropy:\n",
+ "$$\n",
+ "H(p)=-\\sum_{x\\in \\mathcal X}p(x)\\log p(x)=\\sum_{x\\in\\mathcal X}p(x)\\log\\frac1{p(x)}\n",
+ "$$\n",
+ "\n",
+ "- cross entropy(交叉熵):\n",
+ "$$\n",
+ "H(p,q)=-\\sum_{x\\in\\mathcal X}p(x)\\log q(x)=\\sum_{x\\in\\mathcal X}p(x)\\log\\frac1{q(x)}\n",
+ "$$\n",
+ "\n",
+ "- kl divergence(相对熵):\n",
+ "$$\n",
+ "\\begin{align}\n",
+ "D_\\text{KL}(p\\parallel q) &=-\\sum_{x\\in\\mathcal X}p(x)\\log\\frac{q(x)}{p(x)}\\\\\n",
+ "&=-\\sum_{x\\in \\mathcal X}\\left(p(x)\\log q(x)-p(x)\\log p(x)\\right)\\\\\n",
+ "&=-\\sum_{x\\mathcal X}p(x)\\log q(x)-\\left(-\\sum_{x\\mathcal X}p(x)\\log p(x)\\right)\\\\\n",
+ "&=H(p,q)-H(p)\n",
+ "\\end{align}\n",
+ "$$"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "58230a01",
+ "metadata": {},
+ "source": [
+ "### 性质"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7a1a9981",
+ "metadata": {},
+ "source": [
+ "- $H(p)\\geq 0$\n",
+ "- $H(p,q)\\geq 0$\n",
+ " - 非对称\n",
+ " - $H(p,p)=H(p)$\n",
+ "- $D_{\\text{KL}}(p\\parallel q)\\geq 0$(gibbs inequality)\n",
+ " - $H(p,q)\\geq H(p)$\n",
+ " - $D_{\\text{KL}} = 0 $ 当 $p == q$ 时;\n",
+ " - 非对称,$D_{\\text{KL}}(p\\parallel q) \\neq D_{\\text{KL}}(q\\parallel p)$\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a30c0514",
+ "metadata": {},
+ "source": [
+ "## pytorch api"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c6e48572",
+ "metadata": {},
+ "source": [
+ "- $p$: true(target) distribution,label值;\n",
+ " - $p(y|x)$:one hot,一种特殊的概率分布\n",
+ "- $q$: predict(input) distribution,模型的预测值(分布);\n",
+ " - $q(y|x)$:dense prob distribution\n",
+ "- $H(p,q)=-\\sum p(x)\\log q(x)$ (退化成 log loss)\n",
+ " - $H(q,p)=-\\sum q(x)\\log p(x)$"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "acf57c28",
+ "metadata": {},
+ "source": [
+ "### wiki"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "id": "bdda1223",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:32:54.946377Z",
+ "start_time": "2023-03-20T15:32:54.941360Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torch import nn\n",
+ "import torch.nn.functional as F"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "id": "21531b41",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:33:10.712094Z",
+ "start_time": "2023-03-20T15:33:10.705862Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "kl_loss = nn.KLDivLoss(reduction=\"batchmean\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "id": "d67a2c13",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:33:39.814970Z",
+ "start_time": "2023-03-20T15:33:39.805253Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([1, 3]) tensor([[-1.0986, -1.0986, -1.0986]])\n",
+ "torch.Size([1, 3])\n"
+ ]
+ }
+ ],
+ "source": [
+ "input = torch.log(torch.FloatTensor([[1/3, 1/3, 1/3]]))\n",
+ "print(input.shape, input)\n",
+ "target = torch.FloatTensor([[9/25, 12/25, 4/25]])\n",
+ "print(target.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 53,
+ "id": "bb8b4638",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:33:40.891404Z",
+ "start_time": "2023-03-20T15:33:40.885385Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(0.0853)"
+ ]
+ },
+ "execution_count": 53,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "kl_loss(input, target)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "42044238",
+ "metadata": {},
+ "source": [
+ "### kl loss vs. ce loss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "72961f24",
+ "metadata": {},
+ "source": [
+ "- kl loss:\n",
+ " - `input.shape == target.shape`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
+ "id": "ad5a5c32",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:34:47.656453Z",
+ "start_time": "2023-03-20T15:34:47.653460Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "kl_loss = nn.KLDivLoss(reduction=\"batchmean\")\n",
+ "ce_loss = nn.CrossEntropyLoss()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 55,
+ "id": "cdf1e537",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:34:56.714500Z",
+ "start_time": "2023-03-20T15:34:56.706054Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[ 0.3286, 0.4247, 1.5101, -0.4628, -0.6365],\n",
+ " [-0.1079, -0.6032, -1.1911, -0.1564, 0.9218],\n",
+ " [ 1.5526, -0.2029, 0.7564, -2.0878, -1.1310]], requires_grad=True)"
+ ]
+ },
+ "execution_count": 55,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "input = torch.randn(3, 5, requires_grad=True)\n",
+ "input"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 57,
+ "id": "4c19e7ba",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:35:51.836659Z",
+ "start_time": "2023-03-20T15:35:51.829638Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([0, 2, 0])"
+ ]
+ },
+ "execution_count": 57,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "target = torch.empty(3, dtype=torch.long).random_(5)\n",
+ "target"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 58,
+ "id": "fb11bab0",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:35:53.779703Z",
+ "start_time": "2023-03-20T15:35:53.771265Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(1.7296, grad_fn=<NllLossBackward0>)"
+ ]
+ },
+ "execution_count": 58,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ce_value = ce_loss(input, target)\n",
+ "ce_value"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 60,
+ "id": "3a8261e2",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:36:46.037218Z",
+ "start_time": "2023-03-20T15:36:46.031049Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[-1.8236, -1.7275, -0.6421, -2.6150, -2.7887],\n",
+ " [-1.7406, -2.2359, -2.8238, -1.7891, -0.7109],\n",
+ " [-0.5414, -2.2969, -1.3376, -4.1818, -3.2250]],\n",
+ " grad_fn=<LogSoftmaxBackward0>)"
+ ]
+ },
+ "execution_count": 60,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "input_log_softmax = F.log_softmax(input, dim=1)\n",
+ "input_log_softmax"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 66,
+ "id": "79ee5e65",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:39:48.653273Z",
+ "start_time": "2023-03-20T15:39:48.649184Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "target_shaped = torch.FloatTensor([[1, 0, 0, 0, 0], \n",
+ " [0, 0, 1, 0, 0], \n",
+ " [1, 0, 0, 0, 0]])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 62,
+ "id": "63306655",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:38:01.920546Z",
+ "start_time": "2023-03-20T15:38:01.915454Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(1.7296, grad_fn=<DivBackward0>)"
+ ]
+ },
+ "execution_count": 62,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "kl_loss(input_log_softmax, target_shaped)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 63,
+ "id": "de636da2",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:38:53.633383Z",
+ "start_time": "2023-03-20T15:38:53.623851Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(1.7296, grad_fn=<DivBackward1>)"
+ ]
+ },
+ "execution_count": 63,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ce_loss(input, target_shaped)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bc5f92ec",
+ "metadata": {},
+ "source": [
+ "- $H(p)==0$"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "df51123e",
+ "metadata": {},
+ "source": [
+ "### $H(p)\\neq 0$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 69,
+ "id": "6fcf9969",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:40:59.553869Z",
+ "start_time": "2023-03-20T15:40:59.549352Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "kl_loss = torch.nn.KLDivLoss(reduction='none')\n",
+ "ce_loss = torch.nn.CrossEntropyLoss(reduction='none')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 71,
+ "id": "1bb3351c",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:41:27.535234Z",
+ "start_time": "2023-03-20T15:41:27.531335Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# q(x)\n",
+ "input = torch.Tensor([[-0.1, 0.2, -0.4, 0.3]])\n",
+ "# p(x)\n",
+ "target = torch.Tensor([[-0.7, 0.1, -0.1, 0.1]])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 73,
+ "id": "4b86a65c",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:42:28.584559Z",
+ "start_time": "2023-03-20T15:42:28.574631Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tensor([[-0.0635, 0.0116, 0.1097, -0.0190]])\n",
+ "tensor(0.0097)\n",
+ "tensor(0.0389)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# kl(p, q)\n",
+ "kl_output = kl_loss(F.log_softmax(input, dim=1), torch.softmax(target, dim=1))\n",
+ "print(kl_output)\n",
+ "print(kl_output.mean())\n",
+ "print(kl_output.sum())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 74,
+ "id": "12c3141f",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:42:54.035273Z",
+ "start_time": "2023-03-20T15:42:54.027528Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([1.3832])"
+ ]
+ },
+ "execution_count": 74,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# H(p, q)\n",
+ "ce_loss(input, torch.softmax(target, dim=1))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 76,
+ "id": "143110aa",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:43:20.831440Z",
+ "start_time": "2023-03-20T15:43:20.824927Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([1.3443])"
+ ]
+ },
+ "execution_count": 76,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# H(p) == H(p, p)\n",
+ "ce_loss(target, torch.softmax(target, dim=1))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 78,
+ "id": "73ab9d88",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-03-20T15:46:07.758263Z",
+ "start_time": "2023-03-20T15:46:07.753794Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0.038899999999999935"
+ ]
+ },
+ "execution_count": 78,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# H(p, q) - H(p) == kl(p, q)\n",
+ "1.3832 - 1.3443"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4f0068c3",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.13"
+ },
+ "toc": {
+ "base_numbering": 1,
+ "nav_menu": {},
+ "number_sections": true,
+ "sideBar": true,
+ "skip_h1_title": false,
+ "title_cell": "Table of Contents",
+ "title_sidebar": "Contents",
+ "toc_cell": false,
+ "toc_position": {},
+ "toc_section_display": true,
+ "toc_window_display": true
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}