From a1930ed563d5a3905c9504dbcff2fb00653233da Mon Sep 17 00:00:00 2001 From: zhang Date: Thu, 1 Sep 2022 22:39:16 +0800 Subject: residual connection --- basics/python/circular/a.py | 2 +- fine_tune/bert/tutorials/06_attn_01.ipynb | 121 ++++ .../bert/tutorials/07_add_norm_residual_conn.ipynb | 778 +++++++++++++++++++++ learn_torch/basics/add_norm.py | 17 + learn_torch/basics/mha.py | 16 + 5 files changed, 933 insertions(+), 1 deletion(-) create mode 100644 fine_tune/bert/tutorials/07_add_norm_residual_conn.ipynb create mode 100644 learn_torch/basics/add_norm.py create mode 100644 learn_torch/basics/mha.py diff --git a/basics/python/circular/a.py b/basics/python/circular/a.py index b32c70d..1b8a255 100644 --- a/basics/python/circular/a.py +++ b/basics/python/circular/a.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: class A: - def foo(self, b: 'B'): + def foo(self, b: B): return 2*b.foo() def t(self): diff --git a/fine_tune/bert/tutorials/06_attn_01.ipynb b/fine_tune/bert/tutorials/06_attn_01.ipynb index e933ea9..5bb5796 100644 --- a/fine_tune/bert/tutorials/06_attn_01.ipynb +++ b/fine_tune/bert/tutorials/06_attn_01.ipynb @@ -670,6 +670,127 @@ "source": [ "attn_emb.shape" ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-29T13:20:49.611750Z", + "start_time": "2022-08-29T13:20:49.607269Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 2.3786e+00, -9.7945e-02, -2.8436e-01, ..., 9.6968e-02,\n", + " -1.8537e-01, 2.1132e-01],\n", + " [ 9.8886e-02, 2.2942e-01, -3.9613e-01, ..., 2.0846e-01,\n", + " -1.1224e-01, -2.1420e-01],\n", + " [-6.3356e-01, 1.1083e+00, -5.2859e-04, ..., -3.1870e-01,\n", + " 7.2633e-02, -1.0239e-01],\n", + " ...,\n", + " [-7.9659e-01, -4.9400e-01, -4.9216e-02, ..., -3.6446e-01,\n", + " 3.1565e-01, -7.1713e-01],\n", + " [-9.2551e-01, -5.0348e-01, -1.0398e-01, ..., -2.1418e-01,\n", + " 1.6604e-01, -5.9637e-01],\n", + " [ 1.2915e-01, -7.5900e-03, 1.8397e-01, ..., 
2.6980e-01,\n", + " 1.3651e-01, 1.9180e-01]], grad_fn=)" + ] + }, + "execution_count": 69, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V_first_head_first_layer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5. all" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-29T13:13:35.213456Z", + "start_time": "2022-08-29T13:13:35.195874Z" + } + }, + "outputs": [], + "source": [ + "Q_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.query.weight.T \\\n", + " + model.encoder.layer[0].attention.self.query.bias\n", + "K_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.key.weight.T \\\n", + " + model.encoder.layer[0].attention.self.key.bias\n", + "V_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.value.weight.T \\\n", + " + model.encoder.layer[0].attention.self.value.bias" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-29T13:13:44.657849Z", + "start_time": "2022-08-29T13:13:44.639225Z" + } + }, + "outputs": [], + "source": [ + "scores = torch.nn.Softmax(dim=-1)(Q_first_layer @ K_first_layer.T / math.sqrt(64))" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": { + "ExecuteTime": { + "end_time": "2022-08-29T13:20:43.532343Z", + "start_time": "2022-08-29T13:20:43.527666Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 2.3786e+00, -9.7945e-02, -2.8436e-01, ..., 9.6968e-02,\n", + " -1.8537e-01, 2.1132e-01],\n", + " [-6.2984e-01, 1.1037e+00, -2.0764e-04, ..., -3.1767e-01,\n", + " 7.2197e-02, -1.0126e-01],\n", + " [ 3.7932e-02, 2.9540e-01, -3.6121e-01, ..., 1.6694e-01,\n", + " -9.9637e-02, -2.0356e-01],\n", + " ...,\n", + " [-7.0623e-01, -3.0113e-01, 9.4959e-02, ..., -1.7217e-01,\n", + " 1.8647e-01, -4.4743e-01],\n", + " [ 6.2047e-02, -4.0626e-02, 1.6757e-01, ..., 
2.2690e-01,\n", + " 1.4684e-01, 1.3030e-01],\n", + " [ 2.3748e+00, -9.7801e-02, -2.8357e-01, ..., 9.7254e-02,\n", + " -1.8481e-01, 2.1127e-01]], grad_fn=)" + ] + }, + "execution_count": 68, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "scores @ V_first_layer[:, :64]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/fine_tune/bert/tutorials/07_add_norm_residual_conn.ipynb b/fine_tune/bert/tutorials/07_add_norm_residual_conn.ipynb new file mode 100644 index 0000000..a4923b6 --- /dev/null +++ b/fine_tune/bert/tutorials/07_add_norm_residual_conn.ipynb @@ -0,0 +1,778 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T14:15:38.735309Z", + "start_time": "2022-09-01T14:15:32.756820Z" + } + }, + "outputs": [], + "source": [ + "import torch\n", + "from transformers.models.bert import BertModel, BertTokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T14:15:38.739820Z", + "start_time": "2022-09-01T14:15:38.737370Z" + } + }, + "outputs": [], + "source": [ + "model_name = 'bert-base-uncased'" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T14:15:49.089439Z", + "start_time": "2022-09-01T14:15:38.742338Z" + }, + "collapsed": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']\n", + "- This IS expected if 
you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + } + ], + "source": [ + "tokenizer = BertTokenizer.from_pretrained(model_name)\n", + "model = BertModel.from_pretrained(model_name, output_hidden_states=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T13:51:06.838060Z", + "start_time": "2022-09-01T13:51:06.833922Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "BertConfig {\n", + " \"_name_or_path\": \"bert-base-uncased\",\n", + " \"architectures\": [\n", + " \"BertForMaskedLM\"\n", + " ],\n", + " \"attention_probs_dropout_prob\": 0.1,\n", + " \"classifier_dropout\": null,\n", + " \"gradient_checkpointing\": false,\n", + " \"hidden_act\": \"gelu\",\n", + " \"hidden_dropout_prob\": 0.1,\n", + " \"hidden_size\": 768,\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 3072,\n", + " \"layer_norm_eps\": 1e-12,\n", + " \"max_position_embeddings\": 512,\n", + " \"model_type\": \"bert\",\n", + " \"num_attention_heads\": 12,\n", + " \"num_hidden_layers\": 12,\n", + " \"pad_token_id\": 0,\n", + " \"position_embedding_type\": \"absolute\",\n", + " \"transformers_version\": \"4.11.2\",\n", + " \"type_vocab_size\": 2,\n", + " \"use_cache\": true,\n", + " \"vocab_size\": 30522\n", + "}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# \"intermediate_size\": 3072,\n", + "model.config" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": 
"2022-09-01T13:51:14.795848Z", + "start_time": "2022-09-01T13:51:14.790686Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "BertModel(\n", + " (embeddings): BertEmbeddings(\n", + " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", + " (position_embeddings): Embedding(512, 768)\n", + " (token_type_embeddings): Embedding(2, 768)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (encoder): BertEncoder(\n", + " (layer): ModuleList(\n", + " (0): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (1): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " 
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (2): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (3): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, 
elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (4): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (5): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): 
Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (6): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (7): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", 
+ " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (8): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (9): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): 
BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (10): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (11): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): 
Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (pooler): BertPooler(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (activation): Tanh()\n", + " )\n", + ")" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- BertLayer\n", + " - attention: BertAttention\n", + " - self: BertSelfAttention\n", + " - output: BertSelfOutput\n", + " - intermediate: BertIntermediate, 768=>4*768\n", + " - output: BertOutput, 4*768 => 768" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T14:27:52.643782Z", + "start_time": "2022-09-01T14:27:52.637766Z" + } + }, + "outputs": [], + "source": [ + "test_sent = 'this is a test sentence'\n", + "\n", + "model_input = tokenizer(test_sent, return_tensors='pt')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T13:54:55.708402Z", + "start_time": "2022-09-01T13:54:55.705699Z" + } + }, + "source": [ + "## 1. 
model output " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T14:27:54.263945Z", + "start_time": "2022-09-01T14:27:53.832074Z" + } + }, + "outputs": [], + "source": [ + "model.eval()\n", + "with torch.no_grad():\n", + " output = model(**model_input)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T14:27:57.269504Z", + "start_time": "2022-09-01T14:27:57.233630Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 0.1686, -0.2858, -0.3261, ..., -0.0276, 0.0383, 0.1640],\n", + " [-0.6485, 0.6739, -0.0932, ..., 0.4475, 0.6696, 0.1820],\n", + " [-0.6270, -0.0633, -0.3143, ..., 0.3427, 0.4636, 0.4594],\n", + " ...,\n", + " [ 0.6010, -0.6970, -0.2001, ..., 0.2960, 0.2060, -1.7181],\n", + " [ 0.8323, 0.2878, 0.0021, ..., 0.2628, -1.1310, -1.2708],\n", + " [-0.1481, -0.2948, -0.1690, ..., -0.5009, 0.2544, -0.0700]]])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# embeddings\n", + "output[2][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T14:27:58.920132Z", + "start_time": "2022-09-01T14:27:58.910943Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 0.1556, -0.0080, -0.0707, ..., 0.0786, 0.0213, 0.0616],\n", + " [-0.5333, 0.5799, 0.1044, ..., 0.0241, 0.4888, 0.0161],\n", + " [-1.0609, -0.3058, -0.5043, ..., 0.1874, 0.2874, 0.4032],\n", + " ...,\n", + " [ 0.8206, -0.6656, -0.7054, ..., 0.1347, 0.1117, -1.9040],\n", + " [ 1.1128, 0.6603, -0.1509, ..., 0.3253, -1.0006, -1.9106],\n", + " [-0.0736, 0.0346, 0.0376, ..., -0.4506, 0.6585, -0.0502]]])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# first bert layer output\n", + "output[2][1]" + ] + }, + { + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + "## 2. from scratch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- BertLayer\n", + " - attention: BertAttention\n", + " - self: BertSelfAttention\n", + " - output: BertSelfOutput\n", + " - intermediate: BertIntermediate, 768=>4\\*768\n", + " - output: BertOutput" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T14:28:48.054122Z", + "start_time": "2022-09-01T14:28:48.051525Z" + } + }, + "outputs": [], + "source": [ + "embeddings = output[2][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T14:28:44.503060Z", + "start_time": "2022-09-01T14:28:44.500257Z" + } + }, + "outputs": [], + "source": [ + "layer = model.encoder.layer[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T13:58:51.899111Z", + "start_time": "2022-09-01T13:58:51.895787Z" + } + }, + "source": [ + "### 2.1 第一次 add & norm,发生在 mha 内部" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T14:28:49.159187Z", + "start_time": "2022-09-01T14:28:49.154735Z" + } + }, + "outputs": [], + "source": [ + "mha_output = layer.attention.self(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T14:28:52.210980Z", + "start_time": "2022-09-01T14:28:52.204900Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([[[ 0.2979, 0.0801, -0.0037, ..., -0.0142, 0.1290, 0.0828],\n", + " [ 0.3935, 0.1356, -0.0920, ..., 0.0211, 0.1677, 0.0011],\n", + " [ 0.1696, 0.1449, -0.1039, ..., 0.1604, 0.2172, 0.0310],\n", + " ...,\n", + " [-0.0617, 0.1968, -0.0669, ..., 0.1126, 0.1933, -0.0204],\n", + " [-0.2835, 0.1495, -0.0021, ..., 0.0973, 0.1865, -0.0636],\n", + " [ 0.2575, 0.1120, -0.1008, ..., 0.0175, 0.1508, 
0.0878]]],\n", + " grad_fn=),)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mha_output" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T14:29:17.700787Z", + "start_time": "2022-09-01T14:29:17.695502Z" + } + }, + "outputs": [], + "source": [ + "attn_output = layer.attention.output(mha_output[0], embeddings)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.2 第二次 add & norm,发生在 mlp 内部" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T14:29:54.742326Z", + "start_time": "2022-09-01T14:29:54.737498Z" + } + }, + "outputs": [], + "source": [ + "mlp1 = layer.intermediate(attn_output)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T14:29:58.524652Z", + "start_time": "2022-09-01T14:29:58.519921Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 7, 3072])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mlp1.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T14:30:12.269029Z", + "start_time": "2022-09-01T14:30:12.259448Z" + } + }, + "outputs": [], + "source": [ + "mlp2 = layer.output(mlp1, attn_output)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "ExecuteTime": { + "end_time": "2022-09-01T14:30:13.372805Z", + "start_time": "2022-09-01T14:30:13.366902Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 0.1556, -0.0080, -0.0707, ..., 0.0786, 0.0213, 0.0616],\n", + " [-0.5333, 0.5799, 0.1044, ..., 0.0241, 0.4888, 0.0161],\n", + " [-1.0609, -0.3058, -0.5043, ..., 0.1874, 0.2874, 0.4032],\n", + " ...,\n", + " [ 0.8206, -0.6656, -0.7054, ..., 
0.1347, 0.1117, -1.9040],\n", + " [ 1.1128, 0.6603, -0.1509, ..., 0.3253, -1.0006, -1.9106],\n", + " [-0.0736, 0.0346, 0.0376, ..., -0.4506, 0.6585, -0.0502]]],\n", + " grad_fn=)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mlp2" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/learn_torch/basics/add_norm.py b/learn_torch/basics/add_norm.py new file mode 100644 index 0000000..6d733f7 --- /dev/null +++ b/learn_torch/basics/add_norm.py @@ -0,0 +1,17 @@ + +import torch +from transformers.models.bert import BertModel, BertTokenizer + + +if __name__ == '__main__': + model_name = 'bert-base-uncased' + tokenizer = BertTokenizer.from_pretrained(model_name) + model = BertModel.from_pretrained(model_name, output_hidden_states=True) + + test_sent = 'this is a test sentence' + + model_input = tokenizer(test_sent, return_tensors='pt') + model.eval() + with torch.no_grad(): + output = model(**model_input) + diff --git a/learn_torch/basics/mha.py b/learn_torch/basics/mha.py new file mode 100644 index 0000000..d9d392d --- /dev/null +++ b/learn_torch/basics/mha.py @@ -0,0 +1,16 @@ +import torch +from torch import nn + +if __name__ == '__main__': + + mha = nn.MultiheadAttention(embed_dim=768, num_heads=12, kdim=10, vdim=20) + + query = torch.randn(10, 1, 768) + key = torch.randn(5, 1, 10) + value = torch.randn(5, 1, 20) + + attn_output, attn_output_weights = mha(query, key, value) + print(attn_output.shape) + print(attn_output_weights.shape) + + -- cgit v1.2.3