summaryrefslogtreecommitdiff
path: root/fine_tune/bert/tutorials
diff options
context:
space:
mode:
Diffstat (limited to 'fine_tune/bert/tutorials')
-rw-r--r--fine_tune/bert/tutorials/02_no_grad_requires_grad.ipynb209
1 files changed, 209 insertions, 0 deletions
diff --git a/fine_tune/bert/tutorials/02_no_grad_requires_grad.ipynb b/fine_tune/bert/tutorials/02_no_grad_requires_grad.ipynb
new file mode 100644
index 0000000..4d65154
--- /dev/null
+++ b/fine_tune/bert/tutorials/02_no_grad_requires_grad.ipynb
@@ -0,0 +1,209 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### torch.no_grad() vs. param.requires_grad\n",
+ "\n",
+ "- torch.no_grad()\n",
+ " - 定义了一个上下文管理器,隐式地不进行梯度更新,不会改变 requires_grad\n",
+ " - 适用于 eval 阶段,或 model forward 的过程中某些模块不更新梯度的模块(此时这些模块仅进行特征提取(前向计算),不反向更新)\n",
+ "- param.requires_grad \n",
+ " - 显式地 frozen 掉一些module(layer)的梯度更新\n",
+ " - layer/module 级别,\n",
+ " - 可能会更灵活"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-28T15:06:50.866453Z",
+ "start_time": "2022-06-28T15:06:50.863925Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from transformers import BertModel\n",
+ "import torch\n",
+ "from torch import nn"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-28T15:06:13.198197Z",
+ "start_time": "2022-06-28T15:06:13.195806Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "model_name = 'bert-base-uncased'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-28T15:10:34.138893Z",
+ "start_time": "2022-06-28T15:10:30.552454Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']\n",
+ "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+ ]
+ }
+ ],
+ "source": [
+ "bert = BertModel.from_pretrained(model_name)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-28T15:08:11.400841Z",
+ "start_time": "2022-06-28T15:08:11.397764Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def calc_learnable_params(model):\n",
+ " total_param = 0\n",
+ " for name, param in model.named_parameters():\n",
+ " if param.requires_grad:\n",
+ " total_param += param.numel()\n",
+ " return total_param"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-28T15:10:35.508407Z",
+ "start_time": "2022-06-28T15:10:35.502720Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "109482240"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "calc_learnable_params(bert)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-28T15:21:30.906385Z",
+ "start_time": "2022-06-28T15:21:30.902666Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "109482240\n"
+ ]
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " print(calc_learnable_params(bert))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-28T15:22:11.486203Z",
+ "start_time": "2022-06-28T15:22:11.477806Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "for name, param in bert.named_parameters():\n",
+ " if param.requires_grad:\n",
+ " param.requires_grad = False"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-06-28T15:22:16.696133Z",
+ "start_time": "2022-06-28T15:22:16.690878Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "calc_learnable_params(bert)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}