| -rw-r--r-- | fine_tune/bert/tutorials/02_no_grad_requires_grad.ipynb | 209 |
1 file changed, 209 insertions, 0 deletions
diff --git a/fine_tune/bert/tutorials/02_no_grad_requires_grad.ipynb b/fine_tune/bert/tutorials/02_no_grad_requires_grad.ipynb
new file mode 100644
index 0000000..4d65154
--- /dev/null
+++ b/fine_tune/bert/tutorials/02_no_grad_requires_grad.ipynb
@@ -0,0 +1,209 @@

### torch.no_grad() vs. param.requires_grad

- torch.no_grad()
  - defines a context manager under which no gradients are tracked; it does not change requires_grad
  - suited to the eval phase, or to modules in a model's forward pass that should not be updated (such modules only perform feature extraction, i.e. the forward computation, with no backward update); a concrete check follows the model-loading cell below
- param.requires_grad
  - explicitly freezes gradient updates for selected modules (layers)
  - works at layer/module granularity
  - can therefore be more flexible

```python
from transformers import BertModel
import torch
from torch import nn

model_name = 'bert-base-uncased'
```

```python
bert = BertModel.from_pretrained(model_name)
```

Loading prints the expected warning that the checkpoint's `cls.*` pre-training head weights (MLM and next-sentence prediction) are not used when initializing the bare `BertModel`.
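As a quick check of the first point (this cell is not in the original notebook, and the random dummy input is an illustrative assumption), a forward pass inside torch.no_grad() produces outputs with no autograd graph attached, while the parameters themselves keep requires_grad=True:

```python
# Sketch, not part of the notebook: no_grad() suppresses graph construction for
# the outputs but leaves the parameters' requires_grad flags untouched.
dummy_ids = torch.randint(0, bert.config.vocab_size, (1, 8))  # fake token ids

out = bert(input_ids=dummy_ids)
print(out.last_hidden_state.requires_grad)   # True: a graph was built

with torch.no_grad():
    out = bert(input_ids=dummy_ids)
    print(out.last_hidden_state.requires_grad)                   # False: no graph
    print(bert.embeddings.word_embeddings.weight.requires_grad)  # True: unchanged
```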
A helper counts how many parameters are currently trainable, i.e. have requires_grad=True:

```python
def calc_learnable_params(model):
    """Count the parameters that would receive gradient updates."""
    total_param = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            total_param += param.numel()
    return total_param
```

```python
calc_learnable_params(bert)
# 109482240
```

Running the same count inside torch.no_grad() still reports every parameter as trainable, confirming that the context manager leaves requires_grad untouched:

```python
with torch.no_grad():
    print(calc_learnable_params(bert))
# 109482240
```

Explicitly setting requires_grad = False, by contrast, freezes the whole model:

```python
for name, param in bert.named_parameters():
    if param.requires_grad:
        param.requires_grad = False

calc_learnable_params(bert)
# 0
```
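In practice the two mechanisms are often combined for fine-tuning: requires_grad=False permanently freezes the encoder, while torch.no_grad() additionally avoids storing activations for the frozen forward pass. A minimal sketch, assuming a linear two-class head and an Adam optimizer (neither appears in the notebook):

```python
import torch
from torch import nn
from transformers import BertModel

bert = BertModel.from_pretrained('bert-base-uncased')
for param in bert.parameters():
    param.requires_grad = False                      # freeze the encoder for good

classifier = nn.Linear(bert.config.hidden_size, 2)   # hypothetical 2-class head

# Only parameters that still require gradients are handed to the optimizer.
optimizer = torch.optim.Adam(
    (p for p in classifier.parameters() if p.requires_grad), lr=1e-3
)

def forward(input_ids, attention_mask):
    # The frozen encoder acts purely as a feature extractor, so its forward pass
    # can also run under no_grad() to avoid keeping intermediate activations.
    with torch.no_grad():
        features = bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
    return classifier(features)
```

Freezing via requires_grad=False is what actually prevents the encoder from being updated; wrapping its forward pass in no_grad() is a memory optimization on top of that.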
