{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "9f6ad93a", "metadata": { "ExecuteTime": { "end_time": "2023-01-07T10:09:29.696293Z", "start_time": "2023-01-07T10:09:28.230709Z" } }, "outputs": [], "source": [ "import torch\n", "from torch import nn" ] }, { "cell_type": "markdown", "id": "d34bf59a", "metadata": {}, "source": [ "- references\n", " - [BCELoss — PyTorch 1.13 documentation](https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html)\n", " - [[pytorch 模型拓扑结构] 深入理解 nn.CrossEntropyLoss 计算过程(nn.NLLLoss(nn.LogSoftmax))](https://www.bilibili.com/video/BV1NY4y1E76o/)\n", " " ] }, { "cell_type": "markdown", "id": "d39cf8de", "metadata": {}, "source": [ "## 1. BCELoss 计算过程" ] }, { "cell_type": "markdown", "id": "1ddf9667", "metadata": {}, "source": [ "- inputs: \n", " - 未经过 sigmoid 的 network 的输出(一个样本对应一维的输出)\n", "- 计算过程 & output: \n", " - step1:计算 sigmoid,将 1d 的 logits 转换为 p(class=1|x) 的概率\n", " - step2:计算 $\\ell_i=-\\left(y_i\\log (\\hat {y_i}) + (1-y_i)\\log (1-\\hat {y_i})\\right)$\n", " - step3:计算均值 $\\frac1n\\sum_i\\ell_i$" ] }, { "cell_type": "code", "execution_count": 22, "id": "3cba36d3", "metadata": { "ExecuteTime": { "end_time": "2023-01-07T11:02:14.357018Z", "start_time": "2023-01-07T11:02:14.351717Z" } }, "outputs": [], "source": [ "m = nn.Sigmoid()\n", "loss = nn.BCELoss()\n", "inputs = torch.randn(3, requires_grad=True)\n", "target = torch.empty(3).random_(2)\n", "output = loss(m(inputs), target)" ] }, { "cell_type": "code", "execution_count": 23, "id": "9a366456", "metadata": { "ExecuteTime": { "end_time": "2023-01-07T11:02:22.223227Z", "start_time": "2023-01-07T11:02:22.216737Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([-1.7723, -1.5965, 0.8863], requires_grad=True) tensor([0., 1., 0.])\n" ] } ], "source": [ "print(inputs, target)" ] }, { "cell_type": "code", "execution_count": 24, "id": "fae9997d", "metadata": { "ExecuteTime": { "end_time": "2023-01-07T11:03:55.452492Z", "start_time": "2023-01-07T11:03:55.446176Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([0.1453, 0.1685, 0.7081], grad_fn=)" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# [0, 0, 1]\n", "m(inputs)" ] }, { "cell_type": "markdown", "id": "fd2cc509", "metadata": {}, "source": [ "$$\n", "\\hat y=\\sigma(z)=\\frac{1}{1+\\exp(-z)}\n", "$$" ] }, { "cell_type": "code", "execution_count": 25, "id": "d09517ae", "metadata": { "ExecuteTime": { "end_time": "2023-01-07T11:04:19.625453Z", "start_time": "2023-01-07T11:04:19.619159Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([0.1453, 0.1685, 0.7081], grad_fn=)" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "1/(1+torch.exp(-inputs))" ] }, { "cell_type": "markdown", "id": "13c2126e", "metadata": {}, "source": [ "$$\n", "\\begin{split}\n", "\\ell_i&=-\\left(y_i\\log (\\hat {y_i}) + (1-y_i)\\log (1-\\hat {y_i})\\right)\\\\\n", "&=-\\left(y_i\\log (\\sigma(z_i)) + (1-y_i)\\log (1-\\sigma(z_i))\\right)\n", "\\end{split}\n", "$$" ] }, { "cell_type": "code", "execution_count": 26, "id": "7ac07baa", "metadata": { "ExecuteTime": { "end_time": "2023-01-07T11:06:48.652622Z", "start_time": "2023-01-07T11:06:48.647220Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor(1.0565, grad_fn=)" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output" ] }, { "cell_type": "code", "execution_count": 27, "id": "6aa5c0b8", "metadata": { "ExecuteTime": { "end_time": "2023-01-07T11:07:44.711358Z", "start_time": "2023-01-07T11:07:44.704928Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([0.1570, 1.7810, 1.2314], grad_fn=)" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "-(target * torch.log(m(inputs)) + (1-target)*torch.log(1-m(inputs)))" ] }, { "cell_type": "markdown", "id": "2026f312", "metadata": {}, "source": [ "$$\\frac1n\\sum_i\\ell_i$$" ] }, { "cell_type": "code", "execution_count": 28, "id": "fc8736fb", "metadata": { "ExecuteTime": { "end_time": "2023-01-07T11:08:16.771663Z", "start_time": "2023-01-07T11:08:16.766048Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor(1.0565, grad_fn=)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.mean(-(target * torch.log(m(inputs)) + (1-target)*torch.log(1-m(inputs))))" ] }, { "cell_type": "markdown", "id": "b4c8a5dd", "metadata": {}, "source": [ "### 1.1 backward" ] }, { "cell_type": "code", "execution_count": 29, "id": "1b54a9bf", "metadata": { "ExecuteTime": { "end_time": "2023-01-07T11:09:39.746774Z", "start_time": "2023-01-07T11:09:39.743250Z" } }, "outputs": [], "source": [ "output.backward()" ] }, { "cell_type": "code", "execution_count": 30, "id": "8497f15f", "metadata": { "ExecuteTime": { "end_time": "2023-01-07T11:09:44.563726Z", "start_time": "2023-01-07T11:09:44.557957Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([ 0.0484, -0.2772, 0.2360])" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs.grad" ] }, { "cell_type": "markdown", "id": "2c3541cf", "metadata": {}, "source": [ "$$\n", "\\begin{split}\n", "\\frac{\\partial \\ell_i}{\\partial z_i}&=-\\left(y_i\\frac{\\sigma(z_i)(1-\\sigma(z_i))}{\\sigma(z_i)}-(1-y_i)\\frac{\\sigma(z_i)(1-\\sigma(z_i))}{1-\\sigma(z_i)}\\right)\\\\\n", "&=-\\left(y_i(1-\\sigma(z_i) - (1-y_i)\\sigma(z_i)\\right)\\\\\n", "&=-(y_i-\\sigma(z_i))\n", "\\end{split}\n", "$$" ] }, { "cell_type": "code", "execution_count": 31, "id": "1decf75d", "metadata": { "ExecuteTime": { "end_time": "2023-01-07T11:11:02.571951Z", "start_time": "2023-01-07T11:11:02.566570Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([ 0.1453, -0.8315, 0.7081], grad_fn=)" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "-(target - m(inputs))" ] }, { "cell_type": "markdown", "id": "6e70a435", "metadata": {}, "source": [ "$$\n", "\\ell=\\frac{1}3(\\ell_1+\\ell_2+\\ell_3)\\\\\n", "\\frac{\\partial \\ell}{\\partial z_i}=\\frac{1}3\\frac{\\partial \\ell_i}{\\partial z_i}\n", "$$" ] }, { "cell_type": "code", "execution_count": 32, "id": "136f5d23", "metadata": { "ExecuteTime": { "end_time": "2023-01-07T11:11:29.565714Z", "start_time": "2023-01-07T11:11:29.559382Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([ 0.0484, -0.2772, 0.2360], grad_fn=)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "-(target - m(inputs))/3" ] }, { "cell_type": "markdown", "id": "a07ffaed", "metadata": {}, "source": [ "### 1.2 BCELoss vs. BCEWithLogitsLoss" ] }, { "cell_type": "markdown", "id": "64e7ae63", "metadata": {}, "source": [ "- BCEWithLogitsLoss = sigmoid + BCELoss" ] }, { "cell_type": "code", "execution_count": 33, "id": "2e715d5e", "metadata": { "ExecuteTime": { "end_time": "2023-01-07T11:12:21.011891Z", "start_time": "2023-01-07T11:12:21.008933Z" } }, "outputs": [], "source": [ "loss2 = nn.BCEWithLogitsLoss()" ] }, { "cell_type": "code", "execution_count": 35, "id": "398d5958", "metadata": { "ExecuteTime": { "end_time": "2023-01-07T11:12:51.673851Z", "start_time": "2023-01-07T11:12:51.668815Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor(1.0565, grad_fn=)" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# output = loss(m(inputs), target)\n", "loss2(inputs, target)" ] }, { "cell_type": "markdown", "id": "c430b5b6", "metadata": {}, "source": [ "### 1.3 cross entropy loss" ] }, { "cell_type": "markdown", "id": "77f9b61f", "metadata": {}, "source": [ "\n", "$$\n", "H(p,q)=-\\sum_x p(x)\\log q(x)\n", "$$\n", "\n", "- 度量两个概率分布的距离\n", " - $(y_i, 1-y_i)$ vs. $(\\hat{y_i}, 1-\\hat{y_i})$" ] }, { "cell_type": "markdown", "id": "516315ec", "metadata": {}, "source": [ "### 1.4 BCELoss vs. CrossEntropyLoss" ] }, { "cell_type": "markdown", "id": "dd116be2", "metadata": {}, "source": [ "- 二分类 vs. 多分类\n", " - 单输出,多输出\n", "- 概率化过程:sigmoid vs. softmax\n", "- 都用的是 cross entropy loss 来计算 loss" ] }, { "cell_type": "code", "execution_count": null, "id": "59e15fe7", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }