summaryrefslogtreecommitdiff
path: root/fine_tune/bert
diff options
context:
space:
mode:
authorzhang <zch921005@126.com>2022-09-01 22:39:16 +0800
committerzhang <zch921005@126.com>2022-09-01 22:39:16 +0800
commita1930ed563d5a3905c9504dbcff2fb00653233da (patch)
tree33ac93fab92f157de6ad1a91515ec9fcc9985e43 /fine_tune/bert
parent257cbbd2270e9a8756798e256a0bbd29c0fd83db (diff)
residual connection
Diffstat (limited to 'fine_tune/bert')
-rw-r--r--fine_tune/bert/tutorials/06_attn_01.ipynb121
-rw-r--r--fine_tune/bert/tutorials/07_add_norm_residual_conn.ipynb778
2 files changed, 899 insertions, 0 deletions
diff --git a/fine_tune/bert/tutorials/06_attn_01.ipynb b/fine_tune/bert/tutorials/06_attn_01.ipynb
index e933ea9..5bb5796 100644
--- a/fine_tune/bert/tutorials/06_attn_01.ipynb
+++ b/fine_tune/bert/tutorials/06_attn_01.ipynb
@@ -670,6 +670,127 @@
"source": [
"attn_emb.shape"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 69,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-29T13:20:49.611750Z",
+ "start_time": "2022-08-29T13:20:49.607269Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[ 2.3786e+00, -9.7945e-02, -2.8436e-01, ..., 9.6968e-02,\n",
+ " -1.8537e-01, 2.1132e-01],\n",
+ " [ 9.8886e-02, 2.2942e-01, -3.9613e-01, ..., 2.0846e-01,\n",
+ " -1.1224e-01, -2.1420e-01],\n",
+ " [-6.3356e-01, 1.1083e+00, -5.2859e-04, ..., -3.1870e-01,\n",
+ " 7.2633e-02, -1.0239e-01],\n",
+ " ...,\n",
+ " [-7.9659e-01, -4.9400e-01, -4.9216e-02, ..., -3.6446e-01,\n",
+ " 3.1565e-01, -7.1713e-01],\n",
+ " [-9.2551e-01, -5.0348e-01, -1.0398e-01, ..., -2.1418e-01,\n",
+ " 1.6604e-01, -5.9637e-01],\n",
+ " [ 1.2915e-01, -7.5900e-03, 1.8397e-01, ..., 2.6980e-01,\n",
+ " 1.3651e-01, 1.9180e-01]], grad_fn=<AddBackward0>)"
+ ]
+ },
+ "execution_count": 69,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "V_first_head_first_layer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 5. all"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 60,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-29T13:13:35.213456Z",
+ "start_time": "2022-08-29T13:13:35.195874Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "Q_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.query.weight.T \\\n",
+ " + model.encoder.layer[0].attention.self.query.bias\n",
+ "K_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.key.weight.T \\\n",
+ " + model.encoder.layer[0].attention.self.key.bias\n",
+ "V_first_layer = emb_output[0] @ model.encoder.layer[0].attention.self.value.weight.T \\\n",
+ " + model.encoder.layer[0].attention.self.value.bias"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 61,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-29T13:13:44.657849Z",
+ "start_time": "2022-08-29T13:13:44.639225Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "scores = torch.nn.Softmax(dim=-1)(Q_first_layer @ K_first_layer.T / math.sqrt(64))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 68,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-08-29T13:20:43.532343Z",
+ "start_time": "2022-08-29T13:20:43.527666Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[ 2.3786e+00, -9.7945e-02, -2.8436e-01, ..., 9.6968e-02,\n",
+ " -1.8537e-01, 2.1132e-01],\n",
+ " [-6.2984e-01, 1.1037e+00, -2.0764e-04, ..., -3.1767e-01,\n",
+ " 7.2197e-02, -1.0126e-01],\n",
+ " [ 3.7932e-02, 2.9540e-01, -3.6121e-01, ..., 1.6694e-01,\n",
+ " -9.9637e-02, -2.0356e-01],\n",
+ " ...,\n",
+ " [-7.0623e-01, -3.0113e-01, 9.4959e-02, ..., -1.7217e-01,\n",
+ " 1.8647e-01, -4.4743e-01],\n",
+ " [ 6.2047e-02, -4.0626e-02, 1.6757e-01, ..., 2.2690e-01,\n",
+ " 1.4684e-01, 1.3030e-01],\n",
+ " [ 2.3748e+00, -9.7801e-02, -2.8357e-01, ..., 9.7254e-02,\n",
+ " -1.8481e-01, 2.1127e-01]], grad_fn=<MmBackward0>)"
+ ]
+ },
+ "execution_count": 68,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "scores @ V_first_layer[:, :64]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
diff --git a/fine_tune/bert/tutorials/07_add_norm_residual_conn.ipynb b/fine_tune/bert/tutorials/07_add_norm_residual_conn.ipynb
new file mode 100644
index 0000000..a4923b6
--- /dev/null
+++ b/fine_tune/bert/tutorials/07_add_norm_residual_conn.ipynb
@@ -0,0 +1,778 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T14:15:38.735309Z",
+ "start_time": "2022-09-01T14:15:32.756820Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from transformers.models.bert import BertModel, BertTokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T14:15:38.739820Z",
+ "start_time": "2022-09-01T14:15:38.737370Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "model_name = 'bert-base-uncased'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T14:15:49.089439Z",
+ "start_time": "2022-09-01T14:15:38.742338Z"
+ },
+ "collapsed": true
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']\n",
+ "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+ ]
+ }
+ ],
+ "source": [
+ "tokenizer = BertTokenizer.from_pretrained(model_name)\n",
+ "model = BertModel.from_pretrained(model_name, output_hidden_states=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T13:51:06.838060Z",
+ "start_time": "2022-09-01T13:51:06.833922Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "BertConfig {\n",
+ " \"_name_or_path\": \"bert-base-uncased\",\n",
+ " \"architectures\": [\n",
+ " \"BertForMaskedLM\"\n",
+ " ],\n",
+ " \"attention_probs_dropout_prob\": 0.1,\n",
+ " \"classifier_dropout\": null,\n",
+ " \"gradient_checkpointing\": false,\n",
+ " \"hidden_act\": \"gelu\",\n",
+ " \"hidden_dropout_prob\": 0.1,\n",
+ " \"hidden_size\": 768,\n",
+ " \"initializer_range\": 0.02,\n",
+ " \"intermediate_size\": 3072,\n",
+ " \"layer_norm_eps\": 1e-12,\n",
+ " \"max_position_embeddings\": 512,\n",
+ " \"model_type\": \"bert\",\n",
+ " \"num_attention_heads\": 12,\n",
+ " \"num_hidden_layers\": 12,\n",
+ " \"pad_token_id\": 0,\n",
+ " \"position_embedding_type\": \"absolute\",\n",
+ " \"transformers_version\": \"4.11.2\",\n",
+ " \"type_vocab_size\": 2,\n",
+ " \"use_cache\": true,\n",
+ " \"vocab_size\": 30522\n",
+ "}"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# \"intermediate_size\": 3072,\n",
+ "model.config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T13:51:14.795848Z",
+ "start_time": "2022-09-01T13:51:14.790686Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "BertModel(\n",
+ " (embeddings): BertEmbeddings(\n",
+ " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
+ " (position_embeddings): Embedding(512, 768)\n",
+ " (token_type_embeddings): Embedding(2, 768)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (encoder): BertEncoder(\n",
+ " (layer): ModuleList(\n",
+ " (0): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (1): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (2): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (3): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (4): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (5): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (6): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (7): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (8): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (9): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (10): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (11): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (pooler): BertPooler(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (activation): Tanh()\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- BertLayer\n",
+ " - attention: BertAttention\n",
+ " - self: BertSelfAttention\n",
+ " - output: BertSelfOutput\n",
+ " - intermediate: BertIntermediate, 768=>4*768\n",
+ " - output: BertOutput, 4*768 => 768"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T14:27:52.643782Z",
+ "start_time": "2022-09-01T14:27:52.637766Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "test_sent = 'this is a test sentence'\n",
+ "\n",
+ "model_input = tokenizer(test_sent, return_tensors='pt')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T13:54:55.708402Z",
+ "start_time": "2022-09-01T13:54:55.705699Z"
+ }
+ },
+ "source": [
+ "## 1. model output "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T14:27:54.263945Z",
+ "start_time": "2022-09-01T14:27:53.832074Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "model.eval()\n",
+ "with torch.no_grad():\n",
+ " output = model(**model_input)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T14:27:57.269504Z",
+ "start_time": "2022-09-01T14:27:57.233630Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[ 0.1686, -0.2858, -0.3261, ..., -0.0276, 0.0383, 0.1640],\n",
+ " [-0.6485, 0.6739, -0.0932, ..., 0.4475, 0.6696, 0.1820],\n",
+ " [-0.6270, -0.0633, -0.3143, ..., 0.3427, 0.4636, 0.4594],\n",
+ " ...,\n",
+ " [ 0.6010, -0.6970, -0.2001, ..., 0.2960, 0.2060, -1.7181],\n",
+ " [ 0.8323, 0.2878, 0.0021, ..., 0.2628, -1.1310, -1.2708],\n",
+ " [-0.1481, -0.2948, -0.1690, ..., -0.5009, 0.2544, -0.0700]]])"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# embeddings\n",
+ "output[2][0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T14:27:58.920132Z",
+ "start_time": "2022-09-01T14:27:58.910943Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[ 0.1556, -0.0080, -0.0707, ..., 0.0786, 0.0213, 0.0616],\n",
+ " [-0.5333, 0.5799, 0.1044, ..., 0.0241, 0.4888, 0.0161],\n",
+ " [-1.0609, -0.3058, -0.5043, ..., 0.1874, 0.2874, 0.4032],\n",
+ " ...,\n",
+ " [ 0.8206, -0.6656, -0.7054, ..., 0.1347, 0.1117, -1.9040],\n",
+ " [ 1.1128, 0.6603, -0.1509, ..., 0.3253, -1.0006, -1.9106],\n",
+ " [-0.0736, 0.0346, 0.0376, ..., -0.4506, 0.6585, -0.0502]]])"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# first bert layer output\n",
+ "output[2][1]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. from scratch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- BertLayer\n",
+ " - attention: BertAttention\n",
+ " - self: BertSelfAttention\n",
+ " - output: BertSelfOutput\n",
+ " - intermediate: BertIntermediate, 768=>4\\*768\n",
+ " - output: BertOutput"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T14:28:48.054122Z",
+ "start_time": "2022-09-01T14:28:48.051525Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "embeddings = output[2][0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T14:28:44.503060Z",
+ "start_time": "2022-09-01T14:28:44.500257Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "layer = model.encoder.layer[0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T13:58:51.899111Z",
+ "start_time": "2022-09-01T13:58:51.895787Z"
+ }
+ },
+ "source": [
+ "### 2.1 第一次 add & norm,发生在 mha 内部"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T14:28:49.159187Z",
+ "start_time": "2022-09-01T14:28:49.154735Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "mha_output = layer.attention.self(embeddings)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T14:28:52.210980Z",
+ "start_time": "2022-09-01T14:28:52.204900Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(tensor([[[ 0.2979, 0.0801, -0.0037, ..., -0.0142, 0.1290, 0.0828],\n",
+ " [ 0.3935, 0.1356, -0.0920, ..., 0.0211, 0.1677, 0.0011],\n",
+ " [ 0.1696, 0.1449, -0.1039, ..., 0.1604, 0.2172, 0.0310],\n",
+ " ...,\n",
+ " [-0.0617, 0.1968, -0.0669, ..., 0.1126, 0.1933, -0.0204],\n",
+ " [-0.2835, 0.1495, -0.0021, ..., 0.0973, 0.1865, -0.0636],\n",
+ " [ 0.2575, 0.1120, -0.1008, ..., 0.0175, 0.1508, 0.0878]]],\n",
+ " grad_fn=<ViewBackward0>),)"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "mha_output"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T14:29:17.700787Z",
+ "start_time": "2022-09-01T14:29:17.695502Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "attn_output = layer.attention.output(mha_output[0], embeddings)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.2 第一次 add & norm,发生在 mlp 内部"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T14:29:54.742326Z",
+ "start_time": "2022-09-01T14:29:54.737498Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "mlp1 = layer.intermediate(attn_output)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T14:29:58.524652Z",
+ "start_time": "2022-09-01T14:29:58.519921Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 7, 3072])"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "mlp1.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T14:30:12.269029Z",
+ "start_time": "2022-09-01T14:30:12.259448Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "mlp2 = layer.output(mlp1, attn_output)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-09-01T14:30:13.372805Z",
+ "start_time": "2022-09-01T14:30:13.366902Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[ 0.1556, -0.0080, -0.0707, ..., 0.0786, 0.0213, 0.0616],\n",
+ " [-0.5333, 0.5799, 0.1044, ..., 0.0241, 0.4888, 0.0161],\n",
+ " [-1.0609, -0.3058, -0.5043, ..., 0.1874, 0.2874, 0.4032],\n",
+ " ...,\n",
+ " [ 0.8206, -0.6656, -0.7054, ..., 0.1347, 0.1117, -1.9040],\n",
+ " [ 1.1128, 0.6603, -0.1509, ..., 0.3253, -1.0006, -1.9106],\n",
+ " [-0.0736, 0.0346, 0.0376, ..., -0.4506, 0.6585, -0.0502]]],\n",
+ " grad_fn=<NativeLayerNormBackward0>)"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "mlp2"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}