diff options
Diffstat (limited to 'learn_torch/grad/06_retain_graph.ipynb')
| -rw-r--r-- | learn_torch/grad/06_retain_graph.ipynb | 256 |
1 files changed, 256 insertions, 0 deletions
diff --git a/learn_torch/grad/06_retain_graph.ipynb b/learn_torch/grad/06_retain_graph.ipynb new file mode 100644 index 0000000..c6b17d3 --- /dev/null +++ b/learn_torch/grad/06_retain_graph.ipynb @@ -0,0 +1,256 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-01T14:47:42.093837Z", + "start_time": "2023-03-01T14:47:42.091660Z" + } + }, + "outputs": [], + "source": [ + "from IPython.display import Image" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-01T14:47:43.069334Z", + "start_time": "2023-03-01T14:47:43.066216Z" + } + }, + "outputs": [], + "source": [ + "import torch\n", + "from torch.autograd import Variable" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## multi head (output/branch) architecture" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-01T14:48:02.745705Z", + "start_time": "2023-03-01T14:48:02.740366Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMEAAAExCAIAAABKxVFmAAAnm0lEQVR4nOydeXwURd7/q7vnTDIhh0m4AjGaoILKsa6KyIIRlgCyIq5IBJT8EJBLgrg8KAiIyIMucugDikqABaOAKC4CIQRCCAEEwv38FlgRQg6SyZ25u7vqefV0dhySmckcPdPTPf1+zR8z09VV3zAfqr5V9e36yhBCQELCB3C+DZAQPJKGJHxF0pCEr0gakvAVSUMSviJpSMJXJA1J+IqkIQlfkTQk4SuShiR8RcZXw0inI48UUhcuwso7zKtaixoasYhwvGMC3qkTHhdL9H5EkTYY69CBLwsl3AQL9H4ZSZp3/mDZd4A8cQqQZDuFCVz++GPyoUOUGS9hKlWALJTwkABqCELzjz8ZP1oDKyqYjzgu69dHMWwInpiIx8fhCXFYZCRqaITaGlhVDSvvkHn55MlTgIZM2YR4ddZs5dgxgCACZK2E2wRIQ2RRsWHpcvrqdaZzuf8+5YQMxYhheHyc67tQbZ1lf645Zwd16QpzY3JS2KIF8rTBATBYwn0CoSHjuvXGVWsBQnjXLuq33lSOHgVwT3x5hCz7Dhg/Xk3fuAkwTJ01Sz1nph/NlfAQP2uIJHVz51v27GV++zkz1TOnAbncy6ogNG3Zbli2AlCUYmR6xJqPgULBsbUSXuFPDZFk09gJ1JkSLDws4tNP5M8+43uV1NkS3eQZsLZW9nBPzfc5kqMdDPhxfUg/7x3qTAme2DVy725OBAQAkPXrG5n7E/FAKnXpin7WXCAFYQYB/uqHTOs3Gv7771h4WOT+PURSd3dugWXlDf1b+8sxpdcclKzWNg4dierq1TOnqf82lyOTJbzEL/0QefykYeUqgGERG9a5KSAAADtrcwc8Pk6z6QtAEMbPPicPHfbBUgkO8IOGEDK8+x5ASD1npnzQQPfvI3qkuF9Y1rd3+PuLmBHzvWWApr0yVIIbuNeQ+fsf6Rs38YR49aw3PDOla5eY0msxpdc02RvdKa8cP47okQLLys05O7w1VoIDuNYQSRpXrgIAqN/OAjI/b8ZhWNiiBQAA46q1yGTyb1sSzuFYQ+Zvd8KqauLeJOWLo7mt2SHygQNkfXvD2jrzP74JQHMSDuFaQz/+kxllXh3v2Uq0D6imTgYAWH4+EJjmJNrC5XCDdDrq7DkAgOK59PYLm0z1qY+0+jLm1lWAYR41qkgbhCkU1LkLqLFRChThBS41ROYdBhASDz2Ix7naTHWoHpa67j08blWhkD39FJl/xJJ7SPnSGI9vl/AZLkcc8kxJu50QmV/gTEBeoxw1wroNco7baiXchMt+CFZUAgCIbt2cFigrb540xfZRk71RnjbI9pHML7C/6j74vUlM5VXVXtwr4TvcaqiC+UU7Jji+jJD9VkZb10eeNij62kUveim8U0em9fJyb4wWBfwGFnOqoco77EaEw6vk4aO299HXLjr0nTGVKqr4SNtdM9fgsbFM63eqPDdZ4LgMLEZNTdaR4Tz7Ue+3wGJu52V6RgdO+iHbOKWaNNHF34B37eJxwzKCab2xyeMbhUvbwOLH+rUbWEwWMy/Tho3cBhZzuW/f0H8wLCuPvnTaQZ+JkG3OFVV8xLVQ7B0jh/v2retubq7v2Q8QeMxv//LeeuEQbIHFXM7LWvySai2HdboDO4q5XlAQDcZ165tfmURfvY537RK++qMOh35WTZrQroCY3iI2Rjl+XOTe3REb1hLJSfSNm82Z04xrPvPdJE411KWzO36J+zEebsLOyNjWxQxJ6mbNNf59DQBAnTUr6uhB5ZjnPd4PwDDFiPQOhw+ELV0ECML4yTrd9DeBxeKLXZxqKCGe+UVvlTq4hmGqSRPZt+1M4BHydIYPS8tsvaBosQYWW/bsxcLDNF9vUGfN8j4y3eo/qSZNiNy5DY+Ntezd3zR6rC+b1lxqSNabmZZb9u5zeFX1+iTbezK/wFklhiXLPW3Xsj+Xab3Po57eKCCCObCYS58amUz1D/Zmpu6XS7DwsLYFDIs/MGVvtX1stUTkeAetPZ+6pVEaRp0oEOtw5kVgsfv4HljMcTx1c8ZrZFFx+JqPlC8877BAXbdU72p2JibLTz/rZmYRyUkdCg56V3OQQx4/2ZzxKgBAs+VLj+JC3YcqOd80Zhygac2mz73o5DiO0FCMGAYAMH+1xVmBdsMUNdkbHRZwNmCbNjEdm2LUSK/sDXq8DSz2CB8DiznWkPKlMXhCPHX5CpmX76yMPG2QQyVpsjfGlF6z30Gzx+GyJFlYRJWcw8LD7J0tMeF1YLGn+BRYjLjG9N2u2sSUhj8NQTTNeeV3AWHDs8NrE1MMn23wb0N8YbHU/+Gp2sQU03e7AtHa0WO1iSn1vR+HRqNHN3Ifbah8cTS7hGXavI3zyu0xf7eLvnodj41Rv57p14b4QiiBxX6IWMXxsCULmVnYsg+psyXc12+FunxF/+4SAEDYe++I9cF7oQQW+8U4+aCB6jdnABo2Z07zR1gPqqvXvTYFkKRyQoZi9CjO6w8GPAos5gr7wGL37/KXwNVzZyvSh6L6huYJmVDL5Q4a1GqZOqu18v5PsLMJUeJmYDG7XNLqZVj8gZetWgOLAUKW3EPu3+S3ThLDwtetkj3ck/7XtaZhf6FKznNSK3X+YlP689SlK0RyUsRX60V8LFq7gcXIZGIV0/aSKXtryyXPF/+8CCz240CLKZWa3d8qhg+D2pqmv75i3pbj0zEdCJlzdjSNGcf0QE8+HrlnJxYRwaW5QYbrwGI3w9LruveAZZ6Fd3oRWOxfZw1TKiM2rFXPmQlIUv/O4sahI707YoEsLGoa+YJ+/kLGB8oYq/lms+gfA3IRWNwq8JxdV7N/2Rdu6D/Yo/1ULwKLA3Qeo+VAnmHZCni7zLoq2kc5cZxi6BAsItz1XchoJPMOm7d/S5441XKy54J5znZRREb9I4+hhsaoY4fw7nd1Ra12FZ09kdfqIB53QvlaoOi65AexDpHRl864eUeAzqdWDBuiePYZ07Yc09rPqJJzVMk5vVwuHzhAMWwI3qUz3jEBj43FoqOQTg+rqlFNDV1WTh7MJw8XILOZUXqHSPX0qcrMiZhSGRiDecdZYDF1/KTtvbOwdDakWJO90dZdwbJyd4OMPQ8sDvT51MhgNG/eajmQR52/6E552UMPgkiNZsNazBp4Hzo4Cyz2btO61WNYLvAisDjQ5+RjYWrV9Kmq6VNhtZbMzaNKzkNtDayohNVa1NSERUbi8XF45054xwSiV0/F0DS8c6e6bqlNL74S/slKcUcItQLv1BGWlcNqLWGvIW//w5OFRW5qyIvAYt5yLeDxccoJGcoJGe2WxGKi6V9vNI0eq5qSGTZvjlhXpVuBd+kMTp+Fd6qIlPsD2a4XgcW8ach98HvuoevqAYSmz78iD+aHr1oh69eXb6P8zu+BxU8/9fu3d3s/HnjKbuNFYLEA8gax/5os9I3fml4YZ1j6IetrixhngcW2sHTWU+a8XS8Ci4Wgobh77vqMkOnrzU3PjqBOn+XNJv8jTxsMCJw8dRrpDfbfqxfMs71vd+2HzC9gF6zdXCJCJhNZdJyZRw8f5r6pAtAQluBgnY2+Vdr0YoZh4VJkNPJhlN/BVCr5k08AGlpyD7b63j58rz71kba7Y7ZtENvc3rji7+40Sh7MBzQkkpM88ocEoCH7sewuEDJt3d74TDpZfDLQNgUEZ4HF8rRB9iOabXfM9mq7DSIfOMCdFr0LLBaChuLvcXEVllc0j3tV/1+LWvX5IsBFYHHY0oVRxUfcrMfNxSGvA4uFoCFHY9ldIGT+5jvDEm8DHoIWuVz9dhYAwLB8JYCw1cWWo5hvXXV2d/S1i+z2mVsrQwgZlq1gfPaZ0zCNxiMzA55H0XPom7caBw5xUQDv1DFs8buK4X8OoFGBAsLGZ4bRN26q354r7/+4rF8fP7Vj/nan/m/v4rExUacKPV2BE0I/5OxQLAAwmUw9Y2qHowdFKSD6xm/mXT/giV0BQsaPVll+3u+nhnwMLBZAP9Ry1mcbO/FuiZp/fE1Y411EAkVTl69Qp89QZ0qoX87C2tqW761/OxYT3eHgXqczDG9BdfWNQ0fCaq1yQkb48iVe1CCAdWr2pDNYU2P7iMXGoNo6WFGJ+fsofv+D9AbqbAn1i1U3584jo6OFHAwjUu6jr//aPCFTsz2bw1NyoFare22Kj4HFAhjLmH/DuJZNe3bwijpRoHhuOKAo/YL3+DbNeywH8prSn6/v1bd5fKZx3Xqy+KRjAbFrQju3B21gsTA0xG7fyAcOiMzfp57/FqZShS1diKlVZGEReaSQb+u8RP6nAchsZhNnu0Y5MQOPiQnawGJh+EPGVWuJBx9o5TibvtxkWPbfeOfOUcfyfDqMhz/g7bLGP49COp2rQgQRfeY4FhsDrL+9cfWn7OFlRI+UsPlveXHEAllYZFy5ij1RT5kxlvGBfHu0QRgacgxNN6YNp2/8ps6apc6axbc1XkIeP9H8yqS2yz82lC//Nfyju85kCrbAYiFrCADq9NmmMeOAXB5VmCfcw4fYDtXxNQyLKj6Md2kTxkrRbGAxrK0D1tVIHgOLha0hAIBuxhzLP/fJBw/UbPmKb1u8BDU0Nj6bDqtr2l5SDB8W8fk6pzd6Gljcq6c8fahqYga3T8UIXkOwpqbxqWeQ0aTZ8pV8sL9O6PEf9LXrzZOmwttlGEGgNmf/dMjbS/RoP4Da/cBif/wJgtfQXc510SG/J2/kFMv+g/o585DRRDyQGv7xCt2rk2Fdve2qfEB/zTebeTXQLYQxt3eNKvNVIvleWFFhXO9WItigACHjJ+t002Yho0kxIj3yp12yRx+O2PQFZjdFUs+dzauJbuOvE5ECC/nLmdrElLr7etIVlXzb0j5Qb2ie/EZtYkptt1TDp3cdwGX65rvaxJT6+3o2jnyBPwM9Qwz9EOMtPtZP8dxwZLHo31nMty3tAEtvNz03xpJ7CAsP02RvVM+cZn9VOe4l5SsvI5VKnSWQTkg0/RBCiNZq61Ifrk1MsRw9xrctTrEcO17Xq19tYkrDwCHUrzccF6Io/dIPA22ZD4jBp7bR4lwndo06ejAInWvTV9mG5SsBDeWDno5Yv67dVUGhIJKxjKXFub5dFnTONUnqs/5meH8FoKF6+lTNlq9EIyCRzO3tYVeuMaWyQ1E+56E23gGrtbr/N426cAlTKsNXr1SMHM63RRwjqn7od+fabDYs9CacinOocxea0v9CXbiEd+oY+dMu8QlIhP3QXSvXOVvkTz3JfslLxlPzju/1C94DJCn74x80X67HoqM4rDx4EKGGbM41kZzUIe9nFxlPW8NhxlOaNnyw0vT1ZsZLm/hK2NKFIj46UpwaAjTdNGYcvF0G5IrfM57269NuxlM2IgxPiPcl4ylqbNRNm00ePwFksvDlS5TjXuL+DwwmxKkhHjOe0tf/3fzaFHi7DI+Njfh6g6xvb9/+FCHA9wIV9xjW/k9tt9TaxJT6JweZdv3gcdoQCM179zX8aUjLXsTqT92/1bz/ILvO2Tj8ebqq2mPThYm4NGSxNM/MavntP1mHLBbvq6Jp46attfc+WJuY0vzGbGQ2t1MeQsMn61jt6t6c51PTQkNEGrJYGkePrU1MqXvgUUtePidVkmfO1vd+gu1XXGTTgQZD8+vTGe0m9TB+uYmTpgWEeDSkmz2PGb/6D6b+/SuH1dJV1Q1DRjC90eQ3EIQOCpTebngmndFur36WY8c5bFooiERDxv/5gu2BqN9ucl45XVVd9+gfaxNTDCtXtbpkKSpu2UN9Jp0uvc1504JADBqyFJ1gHJFuqZYjR/3UBHn2XG3SA7WJKfajpPHrzbVJPZgu6vXp0GDwU9PBj/A1BCE7h2KcaH9i2rqdneshikIWi27u/N+dd0djXOgg+PUh864f9HPn4wnxUScK/BvvgVDj0JH01evhHy615B4ijx7DwtTh61Yphj7rx0aFgMD3XEnSuHIVAED9dpbfA4YwLGzRAvahW2CxEPclR+7dLQlI8BriK+Op/NnBHY4cIO6/LwCNBj8C15BAMp6KGwH7Q0inq+/1BwBh1Nnj7R7J0yplk43oaxcxlYpNpMK+b6dVi6X+ob6IJKMv/iL6HGpuIuB+yM2Mp7Cs3OFxuyz1qY/YMvG4dYizVxlPxY2QNdRexlMAgGHxB/ap4DjBi4yn4iboHn5wH9cZT9lMA6bsrbaPUcVHWuWBMyz+wL6Am3iR8VTcCLgfcpHxlHWA7JOexpRea5tIMGzpQhfnOzvDi4yn4kbIGqq8w+ZBc3jV3rlxJRQMs89b4A64NaMjmytOQtgacpbxFLCpPP4zSGmyNzpLespiy8Rjn5LHFZ5nPBU3AvaH8I4JsKwcmEygzXle9tnNiB4pruvBVCqPksmh5mZrvQL+78ctAv6HaPFLqrWui2H3cJxM2IuMp+JGyBqyHsDo0C+xP2kQ1dRy264XGU/FjZA1ZMt42hY7B8j0ZbY7tbGZv9w59NmLjKfiRsAacpbxlMWWbNCUvZXML3BdlW2pmjx8tN12vch4Km6EvF9mMtU/2BsAEH25BAsPa1vApgwWh9thsKzcfiG77Tqk40ZpGHWiQBrOWAQ8L2MznpJFxZbcgw7P6o65dbWuew/bR2dbZva4FpDXGU/FjYDHMhcZT1vAsJjSa24uIWqyN7ozw/cu46m4EfBYBqxxjA39B8Oqas3XG+RD0lwUbDVm2RNz66rrRcjfWyssah6fiYWHRf1yzNOElSJG4Bpiz2eZt4BITupw+IB/I9H+E0+tnj9XPWOaHxsSGsIeywAAyhdHE8lJ9I2bps3b/NqQ+btd9NXreGyM+vVMvzYkOASvIYDjYUsWAgAMyz6kzpb4qREfM56KG+FrCAD5oIHqN2cAGjZnTvNHWA+qq9e9NgWQpHJChmL0KM7rFzpi0BCblkCRPhTVNzRPyITadnbQPAJqtUydvmU8FTeC96ltILO5+YWXqUtX8Lh7Ir5cz8nhUdT5i7rJb8BqLZGcFLl3ty8JK0WMSPohdp81aDOeihvx9EMtBGXGU3EjOg1ZCbaMp+JGnBoCQZbxVNyIV0NWgiTjqbgRuYZsOM54Gh4O5HLZI738nfFU3ISKhhxA0Y1DRiCdLup0Ed+mCBvxzO09xbhhI/3rDVhVTZ27wLctwiZENQRLb5us83/2JDW+zRE2Iaoh3Vv/hf6TAsayZy+AkG+LBEwoasi8+0fq1GnbR9TURB4/watFwibkNIQaGg2LP2j1pTSc+ULIaUj/3vttH5Un9x9sP7uZhBNCS0Nk8UmL9QjHViCTyZJ3mA+LxEAoachi0c9b4OyieefuwFojHkJIQ8ZPPoVlTg+eIgsKkd4QWItEQqhoiP73r8YvvnRZgrb8vD9wBomI0NAQQvqsv7G5Wl0gDWfeERIaMm/LoS5carcY9csZVFcfEItERUjsucKKSni7DFZU0hUVsLwSVlRQl6+g6hr7MhiOI5ks7J2/qTI9O55RQsBnNrgP3rlTq6AOw9IPTV9vVk6ZpHr5JVhewbwqKmHlHVRXx5+ZQiUkNNQWWGYNk733XuL++6TULT4SEv5QW2B5pbV/ks5/4YAQ1RDNno/eVdIQB4SkhhBi51/tnlgl4Q6hqCGotc7IZDJMrebbFjEQkhqy7nhIB7tyRUhqiHWGpEc4OCIkNXSb6YcIyRniiJDUENsPJXbl2xCREIoaosuksYxLQlFD7CK1NLHnitDUUIWUtIVDQk9DNI30esanljTEESGnIfbQT0yhkI535YrQ01C5tFPGMaGnIevhaHgXyaHmjJDTEG3Nii851BwSchpq2SyTNMQdoach1h9KlMYyzghRDRGSP8QdoachdiyTNjq4I8Q0RJLIZLIGD0ka4ozQ0hC0TsqwMDWQSefec0ZoaYhmHWrJGeKU0NIQ2w9Ji9TcEmIaYhepu0rRZ1wSYhpqmdhLDjWXhKKGpMdbuSXENMQuDkmL1JwSYhqqvCP5Q5wjzvOHkE5HHimkLlyElXeYV7UWNTRiEeGoqRkgpHzxeaLPo4q0wVKCKU4Ql4ZI0rzzB8u+A+SJU+2fN03g8scfkw8dosx4CVOpAmShGBGLhiA0//iT8aM17LNjAMdl/foohg3BExPx+Dg8IQ6LjEQNjVBbA6uqYeUdMi+fPHmKPaERT4hXZ81Wjh0jJW31DjFoiCwqNixdTl+9znQu99+nnJChGDEMj49zfReqrbPszzXn7GBz/xLJSWGLFsjTBgfKavEgeA0Z1603rloLEMK7dlG/9aZy9CiAezJRQMiy74Dx49X0jZsAw9RZs9RzZvrRXDEiZA2RpG7ufMuevcxvP2emeuY0IJd7WRWEpi3bDctWAIpSjEyPWPOx9NSH+whWQyTZNHYCdaYECw+L+PQTL5LYt4U6W6KbPAPW1soe7qn5PkdytN1EqOtD+nnvUGdK8MSukXt3cyIgAICsX9/I3J+IB1KpS1f0s+YCgf7vCjiC1JBp/UbzD3uw8DDN9mzivmQOa8bj4zTbsrGYaEvuIePHqzmsWcQIT0Pk8ZOGlasAhkVsWEckdee8fkZGm74ABGH87HPykJSQqn2E5g8h1Dj4z/SNm8wEKmuW/ZW6bqmtysaUXvO6HfM/vtG/uwTv2iXq2CFp3cg1AuuHzN//SN+4iSfEq2e94deGlOPHET1SYFm5OWeHXxsSAYLSEEkaV64CAKjfzgIyP5/wj2FhixYAAIyr1rJh/BLOEJKGzN/uhFXVxL1JyhdHt70aU3otpvSaJnsjV83JBw6Q9e0Na+vM//iGqzpFiaA0ZE3Fqnx1vGcr0T6gmjoZAGD5+UBgmhMogtEQ0umos+cAAIrn0gPWqCJtEKZQUOcuoMbGgDUqOASjITLvMICQeOhBPK6dzVQuUShkTz8FELLkHgpco0JDOBo6U+JdJ0TmF9R1S7V/GRZ/4P7tylEjrNsg5zxtN3QQjIbYR8OIbt3cv4VVTPOkKa2+N2VvdV9J+L1JthP4JBwiIA1ZH8nomMBVhayS2t0UY9N6wHKnSc0lBJNHsSWcvr3IMofE3LoKMKzlA0J13XvYLtV17+F6ORuPjWVav1PlRbshgmD6IaTTAwAwD/shTfZGRiI2AVkXD1stI5H5Ba6qsJ7ugBqbvLA5RBCMhlpGMU+WjFWTJsrTBjm8JE8bpJrUkj+acZicj2iouRlYA/g9NTh0EMw/TYtfUq11/xb1gnluXmWff3UIO4oFdEFBaAhHQ9ZTOD3ySzCl0tVVuzBFNp7fIeyMTDoD1AXC0VBCPPOL3ir14B57N8hbYGmZlHTRNYLRkKz3IwAAy959Htzjet7uXuCUZX8u03qfRz1oN8QQjIbkaYMBgZOnTiO9wc1bXHg5ra7KnnrCYRlkMpFFxwEAiuHDPLQ3hBCMhjCVSv7kE4CGltyDbt7S0H+w09AfhBr6D7av3GEp8mA+oCGRnCT5Qy4QjIaYzmAE0xmYv9ri/i31qY+0Xf4h8wvslxmjr110drtp01am3VEjvbI3VBBUPDVJNvQfDKuqNV9vkA9Ja3WxbTy1O0QVH3GWUJEsLGoen4mFh0X9cgzTaLyyOCQQUj8E5HL121kAAMPylQBC3+uLuXXVaUZOhAzLVgAAVDOnSQJyjaA0BIDyxdFEchJ946Zp87Z2C7PRsbb16LaXXEz+zd/toq9ex2Nj1K9n+my1yBHUWGaFLChsnjgZEHjkrm9k/fr6ownq8pWmv7wESDJi7d8Vo0f5owkxIbB+iBnQBg1UvzkD0LA5c5o/wnpQXb3utSmAJJUTMiQBuYPwNAQAUM+drUgfiuobmidkQq0HO2jtArVaps5qrbz/E+HvL+KwZhEjvLGMBZnNzS+8TF26gsfdE/Hlelnf3r7XSZ2/qJv8BqzWEslJkXt3YxERXFgqfgTZD7H7qZrd3yqGD4Pamqa/vmLeluPTMR0ImXN2NI0Zx/RATz4euWenJCD3EWo/1AJCxtWfGtd8BgAgeqSEzX/Li3NkyMIi48pV7Il6yoyx4cuXSA/Ye4TANWTFciDPsGwFm4tD1rePcuI4xdAhWES467uQ0UjmHTZv/5Y8carlZM8F85QvPB8oq8WDGDTEQNGmbTmmtZ/B2jpgXY2UDxygGDYE79IZ75iAx8Zi0VFIp4dV1aimhi4rJw/mk4cLkNnM/BN0iFRPn6rMnOg63kjCGWLRkBVkMJo3b7UcyKPOO90Cs0fWq6c8fahqYoZ02LkviEpDNmC11rJztzlnB57UHVZUwmotamrCIiPx+Di8cye8YwLRq6diaJqU1ZUTxKkhZDQ2/eWvsOJO9OUzfNsifoQ6t3cFQroZWfS/rqGmJlhTy7c14keEGjJ8+JHtHEX6wiW+zRE/YtOQecf3pi++tn2kLkoa8jui0hB1+qx+/sK7vjl/gT9zQgXxaAjeLmvOnApo2v5LNyf5Er4gEg0hnb55fGbbp+JRfQOqb+DJqFBBFBqiad3kN+jfbjq8KHVF/kYMGtIvep8sPunsquQS+RvBa8i0eZt5W46LApKG/I2wNUQWFhmWtHMknjSW+RsBawjV1ulmZrX7kBDjVrNnCEn4BwFrCIuNiTpdpNm2ST19qqz3Iy4Cx6hz0nDmR8Sz54qamy27ftQvXgYIDNB3/VHqeXPUs6fzZ5rIEXA/1ApMo0FGIwBAMXJEVMmJiPVrlBljiaTuAMMkl8ivCOZcWHcgi4oBAPIB/fF7YhUjhytGDmcPlKWv/C/fpokZ8YxlgKbrH+iNzOaoM8e9O4JYwjvEM5ZRFy8jsxlPiJcEFGDEo6GWgWzQQL4NCTlEpKGCQkZDTz/FtyEhh1g0RNFUyXlJQ7wgEg1R5y8Amia6d8Oio/i2JeQQiYbII0eZTuhPT/NtSCgiFg1ZHWrZQGkg4wFRaMhiYWPv5U/159uUUEQMGiJ/OQNoSPRIwcLD+LYlFBGFho4VS84Qj4hBQ1RhEaOhgQP4NiREEbyGkNFI/e//Bzgue/wxvm0JUQSvIerUaYCQ7OGe0ulBfCF4DZFHpYGMZ4SvIdYZkhxq/hC2hpBOT1//NyAIWd8+fNsSughbQy3L03/oy2YQl+AFYWuIYmOGJGeIV4StIfLYcckZ4h0BawjVN9C/3cSUSlmvh/i2JaQRsIbI4hOMM/TEHwEu4L9CBAj4X79lVi8FLvKNkDVUdMLqDEkONc8I6fkypNORRwqpCxdh5R1YeYe6eBlQFN4xgXl16oTHxRK9H1GkDZYOvQ8wQtAQSZp3/mDZd4A8cQqQZDuFCVz++GPyoUOUGS85y1ovwS3BrSEIzT/+ZPxoDayoYD7iuKxfH8WwIXhiIh4fhyfEYZGRqKERamtgVTWsvEPm5ZMnTwEatuQBypqtHDtGSiTlb4JXQ2RRsWHpcvrqdaZzuf8+5YQMxYhh7T7DimrrLPtzzTk72HxkRHJS2KIF8rTBgbI6FAlSDRnXrTeuWgsQwrt2Ub/1pnL0KM8m8AhZ9h0wfryavnETYJg6a5Z6zkw/mhvaBJ+GSFI3d75lz17mt58zUz1zGpDLvawKQtOW7YZlKwBFKUamR6z5GCgUHFsrEXQaIsmmsROoMyVYeFjEp594kVizLdTZEt3kGbC2VvZwT833OZKjzTnBtT6kn/cOdaYET+wauXc3JwICAMj69Y3M/Yl4IJW6dEU/a65PqYMlHBFEGjKt32j+YQ8WHqbZnk3cl8xhzXh8nGZbNhYTbck9ZPx4NYc1SwSRhsjjJw0rVwEMi9iwjkjqznn9jIw2fQEIwvjZ57bMVBKcEBwaQsjw7nsAIfWcmf47QEjWt3f4+4uYEfO9Za1Sw0j4QlBoyPz9j/SNm3hCvHrWG35tSDl+HNEjBZaVm3N2+LWhkCII5mUk2dB/MKyqDv/7CuVLY5yWyi9onjSl7feqSRPDli70oLXCoubxmXhsTIcTBdIcjRP474fM3+6EVdXEvUnKF0c7LGBY/EFdt1SHAmI88eytdd1SDYvbybhgQz5wgKxvb1hbZ/7HNz5YLfE7QaChH//JjDKvjne2Eu1OuLQpeyuZX+Bmi6qpkwEAlp8PeGiphGN4HsuQTlff6w8Awqizx/E4b85zreuWansffe2iW8OTxVL/UF9EktEXf5ECRXyH536IzDsMICQeetA7AQEAYm5dtb1HbmYiVyhkTz8FELLkHvKuUQl7+NbQmRLmN30u3fsqMMz2lt3kdwflqBHWbZBz3rcr8R94zrUAKyoBAES3bu4UNiz+wJS9lZN28XuTmNarqjmpLcThXUMVzC/aMcF1MXunhxPwTh2Z1svLua02NOF5LIOVd9iNCGcFkMnEuYCYFmNjmdbvVHFecwjCcz+EdHrGpXHWDyFUn/qI/ReMB23nALF4IzLr8/ltc5lLeAHP/VDLKGYyObwKyyts7zXZG2NKr7UVkHe0ZOck+F8eEwF8a4j1S6q1Dq/az7PkaYMclnF/hdoedhTzekFBwh6+NdSls5t+SetlaITquqXWdUv1bqbGzsjY1iV8hGd/CE+IZ37RW6XA0SPP9n2Ps/0ye2xlooqP4F27uCgJS8tsvaCEj/DcD8l6My6zZe8+ZwUYH8gl9uvUNkxfZru+y7I/l2m9z6NuWyrhFJ41JE8bDAicPHUa6Q3OysSUXmsrFNWkicz3Trxs1euTXDSKTCay6DgAQDF8mA+2S7TA81iGqVTyJ58gi4otuQeVLzzvvBzmokNqt69qBXkwH9CQSE6S/CFO4H9yqxjBdAbmr7YErEXTJsYNV4waGbAWxQ3/GlK+NAZPiKcuXyHz8gPQHFlYRJWcw8LDXI93Eu7Dv4aAXK5+OwsAYFi+EkDo37YQMixbwThMM6dhGo1/2woZgkBDAChfHE0kJ9E3bpo2b/NrQ+bvdtFXr+OxMerXM/3aUEgRFBoCOB62ZCHTFS37kDpb4qdGqMtX9O8uAQCEvfeO9OA9hwSHhqx56dVvzgA0bM6c5o+wHlRXr3ttCiBJ5YQMxehRnNcfygSLhgAA6rmzFelDUX1D84RMqHW8g+YdUKtl6qzWyvs/wT6mKMEhQfB8mR3IbG5+4WXq0hU87p6IL9fL+vb2vU7q/EXd5DdgtZZITorcuxuLiODCUonfCaJ+iFG0UqnZ/a1i+DCorWn66yvmbTk+HdOBkDlnR9OYcUwP9OTjkXt2SgLyB8HVD7WAkHH1p8Y1nwEAiB4pYfPf8uIcGbKwyLhyFXuinjJjbPjyJdLBjH4iKDVkxXIgz7BsBbxdZj1uoY9y4jjF0CFYRLjru5DRSOYdNm//ljxxquVkzwXzXO2iSPhM8GqIgaJN23JMaz+DtXXAuhopHzhAMWwI3qUz3jEBj43FoqOQTg+rqlFNDV1WTh7MJw8XILOZ+cM6RKqnT1VmTpRydPqb4NaQFWQwmjdvtRzIo85fdKe8rFdPefpQ1cQM6RnWwCAADdmA1VoyN48qOQ+1NbCiElZrUVMTFhmJx8fhnTvhHROIXj0VQ9Pwzp34tjS0EJKGJIKT4JrbSwgRSUMSviJpSMJXJA1J+IqkIQlfkTQk4SuShiR8RdKQhK/8XwAAAP//m8ty+1FQuRMAAAAASUVORK5CYII=\n", + "text/plain": [ + "<IPython.core.display.Image object>" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Image('./imgs/multi_loss.PNG')" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-01T14:49:09.943819Z", + "start_time": "2023-03-01T14:49:09.886836Z" + }, + "scrolled": false + }, + "outputs": [ + { + "ename": "RuntimeError", + "evalue": "Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-9-1ef236403047>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;31m# RuntimeError: Trying to backward through the graph a second time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 305\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 306\u001b[0m inputs=inputs)\n\u001b[0;32m--> 307\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 308\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 154\u001b[0m Variable._execution_engine.run_backward(\n\u001b[1;32m 155\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 156\u001b[0;31m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 157\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward." + ] + } + ], + "source": [ + "a = Variable(torch.rand(1, 4), requires_grad=True)\n", + "b = a**2\n", + "c = b*2\n", + "\n", + "d = c.mean()\n", + "e = c.sum()\n", + "\n", + "\n", + "d.backward()\n", + "\n", + "# RuntimeError: Trying to backward through the graph a second time\n", + "e.backward()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- when we do d.backward(), that is fine. \n", + "- After this computation, the parts of the graph that calculate **d will be freed by default to save memory**. \n", + "- So if we do e.backward(), the error message will pop up. In order to do e.backward(), we have to set the parameter retain_graph to True in d.backward(), i.e.," + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## retain graph" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-01T14:51:15.093632Z", + "start_time": "2023-03-01T14:51:15.089817Z" + } + }, + "outputs": [], + "source": [ + "a = Variable(torch.rand(1, 4), requires_grad=True)\n", + "b = a**2\n", + "c = b*2\n", + "\n", + "d = c.mean()\n", + "e = c.sum()\n", + "\n", + "\n", + "d.backward(retain_graph=True)\n", + "\n", + "# RuntimeError: Trying to backward through the graph a second time\n", + "e.backward()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## retain graph 下的梯度计算" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "ExecuteTime": { + "end_time": "2023-03-01T14:54:20.850186Z", + "start_time": "2023-03-01T14:54:20.804778Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([1., 2., 3., 4.])\n", + "tensor([ 5., 10., 15., 20.])\n" + ] + } + ], + "source": [ + "a = Variable(torch.tensor([1., 2., 3., 4.]), requires_grad=True)\n", + "b = a**2\n", + "c = b*2\n", + "\n", + "# scalar\n", + "d = c.mean()\n", + "e = c.sum()\n", + "\n", + "\n", + "d.backward(retain_graph=True)\n", + "# tensor([1., 2., 3., 4.])\n", + "print(a.grad)\n", + "e.backward()\n", + "# 两次 backwward 累加\n", + "# tensor([ 5., 10., 15., 20.])\n", + "print(a.grad)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- $d=\\frac{\\sum_i2a_i^2}4$\n", + " - $\\frac{\\partial d}{\\partial a_i}=a_i$\n", + "- $e=\\sum_i2a_i^2$\n", + " - $\\frac{\\partial e}{\\partial a_i}=4a_i$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## multi loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# suppose you first back-propagate loss1, then loss2 (you can also do the reverse)\n", + "l1.backward(retain_graph=True)\n", + "l2.backward() # now the graph is freed, and next process of batch gradient descent is ready\n", + "\n", + "optimizer.step() # update the network parameters" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} |
