summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorchzhang <zch921005@126.com>2023-01-07 19:18:06 +0800
committerchzhang <zch921005@126.com>2023-01-07 19:18:06 +0800
commit3d3474d86db3e88782ea0642bb4fcfb1016a60d3 (patch)
tree3016591ac2ff47767a9d7d2b1204883342a3313d
parenta8f0b05e9df53fe83a72b6549b4d232d742f0636 (diff)
bceloss
-rw-r--r--learn_torch/loss/01_BCELoss_binary_cross_entropy.ipynb510
1 files changed, 510 insertions, 0 deletions
diff --git a/learn_torch/loss/01_BCELoss_binary_cross_entropy.ipynb b/learn_torch/loss/01_BCELoss_binary_cross_entropy.ipynb
new file mode 100644
index 0000000..ccbb946
--- /dev/null
+++ b/learn_torch/loss/01_BCELoss_binary_cross_entropy.ipynb
@@ -0,0 +1,510 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "9f6ad93a",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-01-07T10:09:29.696293Z",
+ "start_time": "2023-01-07T10:09:28.230709Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torch import nn"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d34bf59a",
+ "metadata": {},
+ "source": [
+ "- references\n",
+ " - [BCELoss — PyTorch 1.13 documentation](https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html)\n",
+ " - [[pytorch 模型拓扑结构] 深入理解 nn.CrossEntropyLoss 计算过程(nn.NLLLoss(nn.LogSoftmax))](https://www.bilibili.com/video/BV1NY4y1E76o/)\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d39cf8de",
+ "metadata": {},
+ "source": [
+ "## 1. BCELoss 计算过程"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1ddf9667",
+ "metadata": {},
+ "source": [
+ "- inputs: \n",
+ " - 未经过 sigmoid 的 network 的输出(一个样本对应一维的输出)\n",
+ "- 计算过程 & output: \n",
+ " - step1:计算 sigmoid,将 1d 的 logits 转换为 p(class=1|x) 的概率\n",
+ " - step2:计算 $\\ell_i=-\\left(y_i\\log (\\hat {y_i}) + (1-y_i)\\log (1-\\hat {y_i})\\right)$\n",
+ " - step3:计算均值 $\\frac1n\\sum_i\\ell_i$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "3cba36d3",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-01-07T11:02:14.357018Z",
+ "start_time": "2023-01-07T11:02:14.351717Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "m = nn.Sigmoid()\n",
+ "loss = nn.BCELoss()\n",
+ "inputs = torch.randn(3, requires_grad=True)\n",
+ "target = torch.empty(3).random_(2)\n",
+ "output = loss(m(inputs), target)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "9a366456",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-01-07T11:02:22.223227Z",
+ "start_time": "2023-01-07T11:02:22.216737Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tensor([-1.7723, -1.5965, 0.8863], requires_grad=True) tensor([0., 1., 0.])\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(inputs, target)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "fae9997d",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-01-07T11:03:55.452492Z",
+ "start_time": "2023-01-07T11:03:55.446176Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([0.1453, 0.1685, 0.7081], grad_fn=<SigmoidBackward0>)"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# [0, 0, 1]\n",
+ "m(inputs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fd2cc509",
+ "metadata": {},
+ "source": [
+ "$$\n",
+ "\\hat y=\\sigma(z)=\\frac{1}{1+\\exp(-z)}\n",
+ "$$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "d09517ae",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-01-07T11:04:19.625453Z",
+ "start_time": "2023-01-07T11:04:19.619159Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([0.1453, 0.1685, 0.7081], grad_fn=<MulBackward0>)"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "1/(1+torch.exp(-inputs))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "13c2126e",
+ "metadata": {},
+ "source": [
+ "$$\n",
+ "\\begin{split}\n",
+ "\\ell_i&=-\\left(y_i\\log (\\hat {y_i}) + (1-y_i)\\log (1-\\hat {y_i})\\right)\\\\\n",
+ "&=-\\left(y_i\\log (\\sigma(z_i)) + (1-y_i)\\log (1-\\sigma(z_i))\\right)\n",
+ "\\end{split}\n",
+ "$$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "7ac07baa",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-01-07T11:06:48.652622Z",
+ "start_time": "2023-01-07T11:06:48.647220Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(1.0565, grad_fn=<BinaryCrossEntropyBackward0>)"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "output"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "6aa5c0b8",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-01-07T11:07:44.711358Z",
+ "start_time": "2023-01-07T11:07:44.704928Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([0.1570, 1.7810, 1.2314], grad_fn=<NegBackward0>)"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "-(target * torch.log(m(inputs)) + (1-target)*torch.log(1-m(inputs)))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2026f312",
+ "metadata": {},
+ "source": [
+ "$$\\frac1n\\sum_i\\ell_i$$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "fc8736fb",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-01-07T11:08:16.771663Z",
+ "start_time": "2023-01-07T11:08:16.766048Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(1.0565, grad_fn=<MeanBackward0>)"
+ ]
+ },
+ "execution_count": 28,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.mean(-(target * torch.log(m(inputs)) + (1-target)*torch.log(1-m(inputs))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b4c8a5dd",
+ "metadata": {},
+ "source": [
+ "### 1.1 backward"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "1b54a9bf",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-01-07T11:09:39.746774Z",
+ "start_time": "2023-01-07T11:09:39.743250Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "output.backward()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "8497f15f",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-01-07T11:09:44.563726Z",
+ "start_time": "2023-01-07T11:09:44.557957Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([ 0.0484, -0.2772, 0.2360])"
+ ]
+ },
+ "execution_count": 30,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "inputs.grad"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2c3541cf",
+ "metadata": {},
+ "source": [
+ "$$\n",
+ "\\begin{split}\n",
+ "\\frac{\\partial \\ell_i}{\\partial z_i}&=-\\left(y_i\\frac{\\sigma(z_i)(1-\\sigma(z_i))}{\\sigma(z_i)}-(1-y_i)\\frac{\\sigma(z_i)(1-\\sigma(z_i))}{1-\\sigma(z_i)}\\right)\\\\\n",
+ "&=-\\left(y_i(1-\\sigma(z_i) - (1-y_i)\\sigma(z_i)\\right)\\\\\n",
+ "&=-(y_i-\\sigma(z_i))\n",
+ "\\end{split}\n",
+ "$$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "1decf75d",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-01-07T11:11:02.571951Z",
+ "start_time": "2023-01-07T11:11:02.566570Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([ 0.1453, -0.8315, 0.7081], grad_fn=<NegBackward0>)"
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "-(target - m(inputs))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6e70a435",
+ "metadata": {},
+ "source": [
+ "$$\n",
+ "\\ell=\\frac{1}3(\\ell_1+\\ell_2+\\ell_3)\\\\\n",
+ "\\frac{\\partial \\ell}{\\partial z_i}=\\frac{1}3\\frac{\\partial \\ell_i}{\\partial z_i}\n",
+ "$$"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "136f5d23",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-01-07T11:11:29.565714Z",
+ "start_time": "2023-01-07T11:11:29.559382Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([ 0.0484, -0.2772, 0.2360], grad_fn=<DivBackward0>)"
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "-(target - m(inputs))/3"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a07ffaed",
+ "metadata": {},
+ "source": [
+ "### 1.2 BCELoss vs. BCEWithLogitsLoss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "64e7ae63",
+ "metadata": {},
+ "source": [
+ "- BCEWithLogitsLoss = sigmoid + BCELoss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "2e715d5e",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-01-07T11:12:21.011891Z",
+ "start_time": "2023-01-07T11:12:21.008933Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "loss2 = nn.BCEWithLogitsLoss()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "id": "398d5958",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-01-07T11:12:51.673851Z",
+ "start_time": "2023-01-07T11:12:51.668815Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(1.0565, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)"
+ ]
+ },
+ "execution_count": 35,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# output = loss(m(inputs), target)\n",
+ "loss2(inputs, target)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c430b5b6",
+ "metadata": {},
+ "source": [
+ "### 1.3 cross entropy loss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "77f9b61f",
+ "metadata": {},
+ "source": [
+ "\n",
+ "$$\n",
+ "H(p,q)=-\\sum_x p(x)\\log q(x)\n",
+ "$$\n",
+ "\n",
+ "- 度量两个概率分布的距离\n",
+ " - $(y_i, 1-y_i)$ vs. $(\\hat{y_i}, 1-\\hat{y_i})$"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "516315ec",
+ "metadata": {},
+ "source": [
+ "### 1.4 BCELoss vs. CrossEntropyLoss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "dd116be2",
+ "metadata": {},
+ "source": [
+ "- 二分类 vs. 多分类\n",
+ " - 单输出,多输出\n",
+ "- 概率化过程:sigmoid vs. softmax\n",
+ "- 都用的是 cross entropy loss 来计算 loss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "59e15fe7",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}