diff options
Diffstat (limited to 'learn_torch/grad/06_retain_graph.ipynb')
| -rw-r--r-- | learn_torch/grad/06_retain_graph.ipynb | 256 |
1 files changed, 256 insertions, 0 deletions
diff --git a/learn_torch/grad/06_retain_graph.ipynb b/learn_torch/grad/06_retain_graph.ipynb new file mode 100644 index 0000000..c6b17d3 --- /dev/null +++ b/learn_torch/grad/06_retain_graph.ipynb @@ -0,0 +1,256 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-01T14:47:42.093837Z", + "start_time": "2023-03-01T14:47:42.091660Z" + } + }, + "outputs": [], + "source": [ + "from IPython.display import Image" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-01T14:47:43.069334Z", + "start_time": "2023-03-01T14:47:43.066216Z" + } + }, + "outputs": [], + "source": [ + "import torch\n", + "from torch.autograd import Variable" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## multi head (output/branch) architecture" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-01T14:48:02.745705Z", + "start_time": "2023-03-01T14:48:02.740366Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "<IPython.core.display.Image object>" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Image('./imgs/multi_loss.PNG')" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-01T14:49:09.943819Z", + "start_time": "2023-03-01T14:49:09.886836Z" + }, + "scrolled": false + }, + "outputs": [ + { + "ename": "RuntimeError", + "evalue": "Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). 
Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-9-1ef236403047>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;31m# RuntimeError: Trying to backward through the graph a second time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 305\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 306\u001b[0m inputs=inputs)\n\u001b[0;32m--> 307\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 308\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 154\u001b[0m Variable._execution_engine.run_backward(\n\u001b[1;32m 155\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 156\u001b[0;31m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 157\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward." + ] + } + ], + "source": [ + "a = Variable(torch.rand(1, 4), requires_grad=True)\n", + "b = a**2\n", + "c = b*2\n", + "\n", + "d = c.mean()\n", + "e = c.sum()\n", + "\n", + "\n", + "d.backward()\n", + "\n", + "# RuntimeError: Trying to backward through the graph a second time\n", + "e.backward()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- when we do d.backward(), that is fine. \n", + "- After this computation, the parts of the graph that calculate **d will be freed by default to save memory**. 
\n", + "- So if we do e.backward(), the error message will pop up. In order to do e.backward(), we have to set the parameter retain_graph to True in d.backward(), i.e.," + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## retain graph" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-01T14:51:15.093632Z", + "start_time": "2023-03-01T14:51:15.089817Z" + } + }, + "outputs": [], + "source": [ + "a = Variable(torch.rand(1, 4), requires_grad=True)\n", + "b = a**2\n", + "c = b*2\n", + "\n", + "d = c.mean()\n", + "e = c.sum()\n", + "\n", + "\n", + "d.backward(retain_graph=True)\n", + "\n", + "# no RuntimeError this time: the first backward call retained the graph\n", + "e.backward()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## retain graph 下的梯度计算" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-01T14:54:20.850186Z", + "start_time": "2023-03-01T14:54:20.804778Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([1., 2., 3., 4.])\n", + "tensor([ 5., 10., 15., 20.])\n" + ] + } + ], + "source": [ + "a = Variable(torch.tensor([1., 2., 3., 4.]), requires_grad=True)\n", + "b = a**2\n", + "c = b*2\n", + "\n", + "# scalar\n", + "d = c.mean()\n", + "e = c.sum()\n", + "\n", + "\n", + "d.backward(retain_graph=True)\n", + "# tensor([1., 2., 3., 4.])\n", + "print(a.grad)\n", + "e.backward()\n", + "# gradients from the two backward calls accumulate in a.grad\n", + "# tensor([ 5., 10., 15., 20.])\n", + "print(a.grad)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- $d=\\frac{\\sum_i2a_i^2}4$\n", + " - $\\frac{\\partial d}{\\partial a_i}=a_i$\n", + "- $e=\\sum_i2a_i^2$\n", + " - $\\frac{\\partial e}{\\partial a_i}=4a_i$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## multi loss" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# suppose you first back-propagate loss1, then loss2 (you can also do the reverse)\n", + "l1.backward(retain_graph=True)\n", + "l2.backward() # now the graph is freed, and next process of batch gradient descent is ready\n", + "\n", + "optimizer.step() # update the network parameters" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} |
