From 3d3474d86db3e88782ea0642bb4fcfb1016a60d3 Mon Sep 17 00:00:00 2001 From: chzhang Date: Sat, 7 Jan 2023 19:18:06 +0800 Subject: bceloss --- .../loss/01_BCELoss_binary_cross_entropy.ipynb | 510 +++++++++++++++++++++ 1 file changed, 510 insertions(+) create mode 100644 learn_torch/loss/01_BCELoss_binary_cross_entropy.ipynb (limited to 'learn_torch/loss') diff --git a/learn_torch/loss/01_BCELoss_binary_cross_entropy.ipynb b/learn_torch/loss/01_BCELoss_binary_cross_entropy.ipynb new file mode 100644 index 0000000..ccbb946 --- /dev/null +++ b/learn_torch/loss/01_BCELoss_binary_cross_entropy.ipynb @@ -0,0 +1,510 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "9f6ad93a", + "metadata": { + "ExecuteTime": { + "end_time": "2023-01-07T10:09:29.696293Z", + "start_time": "2023-01-07T10:09:28.230709Z" + } + }, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn" + ] + }, + { + "cell_type": "markdown", + "id": "d34bf59a", + "metadata": {}, + "source": [ + "- references\n", + " - [BCELoss — PyTorch 1.13 documentation](https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html)\n", + " - [[pytorch 模型拓扑结构] 深入理解 nn.CrossEntropyLoss 计算过程(nn.NLLLoss(nn.LogSoftmax))](https://www.bilibili.com/video/BV1NY4y1E76o/)\n", + " " + ] + }, + { + "cell_type": "markdown", + "id": "d39cf8de", + "metadata": {}, + "source": [ + "## 1. BCELoss 计算过程" + ] + }, + { + "cell_type": "markdown", + "id": "1ddf9667", + "metadata": {}, + "source": [ + "- inputs: \n", + " - 未经过 sigmoid 的 network 的输出(一个样本对应一维的输出)\n", + "- 计算过程 & output: \n", + " - step1:计算 sigmoid,将 1d 的 logits 转换为 p(class=1|x) 的概率\n", + " - step2:计算 $\\ell_i=-\\left(y_i\\log (\\hat {y_i}) + (1-y_i)\\log (1-\\hat {y_i})\\right)$\n", + " - step3:计算均值 $\\frac1n\\sum_i\\ell_i$" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "3cba36d3", + "metadata": { + "ExecuteTime": { + "end_time": "2023-01-07T11:02:14.357018Z", + "start_time": "2023-01-07T11:02:14.351717Z" + } + }, + "outputs": [], + "source": [ + "m = nn.Sigmoid()\n", + "loss = nn.BCELoss()\n", + "inputs = torch.randn(3, requires_grad=True)\n", + "target = torch.empty(3).random_(2)\n", + "output = loss(m(inputs), target)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "9a366456", + "metadata": { + "ExecuteTime": { + "end_time": "2023-01-07T11:02:22.223227Z", + "start_time": "2023-01-07T11:02:22.216737Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([-1.7723, -1.5965, 0.8863], requires_grad=True) tensor([0., 1., 0.])\n" + ] + } + ], + "source": [ + "print(inputs, target)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "fae9997d", + "metadata": { + "ExecuteTime": { + "end_time": "2023-01-07T11:03:55.452492Z", + "start_time": "2023-01-07T11:03:55.446176Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.1453, 0.1685, 0.7081], grad_fn=)" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# [0, 0, 1]\n", + "m(inputs)" + ] + }, + { + "cell_type": "markdown", + "id": "fd2cc509", + "metadata": {}, + "source": [ + "$$\n", + "\\hat y=\\sigma(z)=\\frac{1}{1+\\exp(-z)}\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "d09517ae", + "metadata": { + "ExecuteTime": { + "end_time": "2023-01-07T11:04:19.625453Z", + "start_time": "2023-01-07T11:04:19.619159Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.1453, 0.1685, 0.7081], grad_fn=)" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "1/(1+torch.exp(-inputs))" + ] + }, + { + "cell_type": "markdown", + "id": "13c2126e", + "metadata": {}, + "source": [ + "$$\n", + "\\begin{split}\n", + "\\ell_i&=-\\left(y_i\\log (\\hat {y_i}) + (1-y_i)\\log (1-\\hat {y_i})\\right)\\\\\n", + "&=-\\left(y_i\\log (\\sigma(z_i)) + (1-y_i)\\log (1-\\sigma(z_i))\\right)\n", + "\\end{split}\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "7ac07baa", + "metadata": { + "ExecuteTime": { + "end_time": "2023-01-07T11:06:48.652622Z", + "start_time": "2023-01-07T11:06:48.647220Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(1.0565, grad_fn=)" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "6aa5c0b8", + "metadata": { + "ExecuteTime": { + "end_time": "2023-01-07T11:07:44.711358Z", + "start_time": "2023-01-07T11:07:44.704928Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0.1570, 1.7810, 1.2314], grad_fn=)" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "-(target * torch.log(m(inputs)) + (1-target)*torch.log(1-m(inputs)))" + ] + }, + { + "cell_type": "markdown", + "id": "2026f312", + "metadata": {}, + "source": [ + "$$\\frac1n\\sum_i\\ell_i$$" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "fc8736fb", + "metadata": { + "ExecuteTime": { + "end_time": "2023-01-07T11:08:16.771663Z", + "start_time": "2023-01-07T11:08:16.766048Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(1.0565, grad_fn=)" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.mean(-(target * torch.log(m(inputs)) + (1-target)*torch.log(1-m(inputs))))" + ] + }, + { + "cell_type": "markdown", + "id": "b4c8a5dd", + "metadata": {}, + "source": [ + "### 1.1 backward" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "1b54a9bf", + "metadata": { + "ExecuteTime": { + "end_time": "2023-01-07T11:09:39.746774Z", + "start_time": "2023-01-07T11:09:39.743250Z" + } + }, + "outputs": [], + "source": [ + "output.backward()" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "8497f15f", + "metadata": { + "ExecuteTime": { + "end_time": "2023-01-07T11:09:44.563726Z", + "start_time": "2023-01-07T11:09:44.557957Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 0.0484, -0.2772, 0.2360])" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "inputs.grad" + ] + }, + { + "cell_type": "markdown", + "id": "2c3541cf", + "metadata": {}, + "source": [ + "$$\n", + "\\begin{split}\n", + "\\frac{\\partial \\ell_i}{\\partial z_i}&=-\\left(y_i\\frac{\\sigma(z_i)(1-\\sigma(z_i))}{\\sigma(z_i)}-(1-y_i)\\frac{\\sigma(z_i)(1-\\sigma(z_i))}{1-\\sigma(z_i)}\\right)\\\\\n", + "&=-\\left(y_i(1-\\sigma(z_i) - (1-y_i)\\sigma(z_i)\\right)\\\\\n", + "&=-(y_i-\\sigma(z_i))\n", + "\\end{split}\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "1decf75d", + "metadata": { + "ExecuteTime": { + "end_time": "2023-01-07T11:11:02.571951Z", + "start_time": "2023-01-07T11:11:02.566570Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 0.1453, -0.8315, 0.7081], grad_fn=)" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "-(target - m(inputs))" + ] + }, + { + "cell_type": "markdown", + "id": "6e70a435", + "metadata": {}, + "source": [ + "$$\n", + "\\ell=\\frac{1}3(\\ell_1+\\ell_2+\\ell_3)\\\\\n", + "\\frac{\\partial \\ell}{\\partial z_i}=\\frac{1}3\\frac{\\partial \\ell_i}{\\partial z_i}\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "136f5d23", + "metadata": { + "ExecuteTime": { + "end_time": "2023-01-07T11:11:29.565714Z", + "start_time": "2023-01-07T11:11:29.559382Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 0.0484, -0.2772, 0.2360], grad_fn=)" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "-(target - m(inputs))/3" + ] + }, + { + "cell_type": "markdown", + "id": "a07ffaed", + "metadata": {}, + "source": [ + "### 1.2 BCELoss vs. BCEWithLogitsLoss" + ] + }, + { + "cell_type": "markdown", + "id": "64e7ae63", + "metadata": {}, + "source": [ + "- BCEWithLogitsLoss = sigmoid + BCELoss" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "2e715d5e", + "metadata": { + "ExecuteTime": { + "end_time": "2023-01-07T11:12:21.011891Z", + "start_time": "2023-01-07T11:12:21.008933Z" + } + }, + "outputs": [], + "source": [ + "loss2 = nn.BCEWithLogitsLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "398d5958", + "metadata": { + "ExecuteTime": { + "end_time": "2023-01-07T11:12:51.673851Z", + "start_time": "2023-01-07T11:12:51.668815Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(1.0565, grad_fn=)" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# output = loss(m(inputs), target)\n", + "loss2(inputs, target)" + ] + }, + { + "cell_type": "markdown", + "id": "c430b5b6", + "metadata": {}, + "source": [ + "### 1.3 cross entropy loss" + ] + }, + { + "cell_type": "markdown", + "id": "77f9b61f", + "metadata": {}, + "source": [ + "\n", + "$$\n", + "H(p,q)=-\\sum_x p(x)\\log q(x)\n", + "$$\n", + "\n", + "- 度量两个概率分布的距离\n", + " - $(y_i, 1-y_i)$ vs. $(\\hat{y_i}, 1-\\hat{y_i})$" + ] + }, + { + "cell_type": "markdown", + "id": "516315ec", + "metadata": {}, + "source": [ + "### 1.4 BCELoss vs. CrossEntropyLoss" + ] + }, + { + "cell_type": "markdown", + "id": "dd116be2", + "metadata": {}, + "source": [ + "- 二分类 vs. 多分类\n", + " - 单输出,多输出\n", + "- 概率化过程:sigmoid vs. softmax\n", + "- 都用的是 cross entropy loss 来计算 loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59e15fe7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} -- cgit v1.2.3