-rw-r--r--   fine_tune/bert/tutorials/09_masked_lm.ipynb   1862
-rw-r--r--   fine_tune/bert/tutorials/09_mlm.py               33
2 files changed, 1895 insertions, 0 deletions
diff --git a/fine_tune/bert/tutorials/09_masked_lm.ipynb b/fine_tune/bert/tutorials/09_masked_lm.ipynb new file mode 100644 index 0000000..ffd1211 --- /dev/null +++ b/fine_tune/bert/tutorials/09_masked_lm.ipynb @@ -0,0 +1,1862 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:50:16.399172Z", + "start_time": "2022-10-23T11:50:07.776148Z" + } + }, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn\n", + "from transformers.models.bert import BertModel, BertTokenizer, BertForMaskedLM" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. model load and data preprocessing" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:50:21.521103Z", + "start_time": "2022-10-23T11:50:21.517814Z" + } + }, + "outputs": [], + "source": [ + "model_type = 'bert-base-uncased'" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:50:42.915776Z", + "start_time": "2022-10-23T11:50:24.836326Z" + }, + "collapsed": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']\n", + "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n", + "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + } + ], + "source": [ + "tokenizer = BertTokenizer.from_pretrained(model_type)\n", + "bert = BertModel.from_pretrained(model_type)\n", + "mlm = BertForMaskedLM.from_pretrained(model_type, output_hidden_states=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:50:48.950455Z", + "start_time": "2022-10-23T11:50:48.938428Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "BertModel(\n", + " (embeddings): BertEmbeddings(\n", + " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", + " (position_embeddings): Embedding(512, 768)\n", + " (token_type_embeddings): Embedding(2, 768)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (encoder): BertEncoder(\n", + " (layer): ModuleList(\n", + " (0): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (1): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (2): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): 
BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (3): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (4): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (5): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (6): 
BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (7): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (8): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (9): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): 
BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (10): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (11): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (pooler): BertPooler(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (activation): Tanh()\n", + " )\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bert" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:50:53.507897Z", + "start_time": "2022-10-23T11:50:53.500990Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "BertForMaskedLM(\n", + " (bert): BertModel(\n", + " (embeddings): BertEmbeddings(\n", + " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", + " (position_embeddings): Embedding(512, 768)\n", + " (token_type_embeddings): Embedding(2, 768)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (encoder): BertEncoder(\n", + " (layer): ModuleList(\n", + " (0): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, 
out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (1): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (2): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (3): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, 
elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (4): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (5): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (6): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (7): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, 
elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (8): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (9): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (10): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (11): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, 
bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (cls): BertOnlyMLMHead(\n", + " (predictions): BertLMPredictionHead(\n", + " (transform): BertPredictionHeadTransform(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " )\n", + " (decoder): Linear(in_features=768, out_features=30522, bias=True)\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mlm" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:52:15.652759Z", + "start_time": "2022-10-23T11:52:15.650142Z" + } + }, + "outputs": [], + "source": [ + "text = (\"After Abraham Lincoln won the November 1860 presidential \"\n", + " \"election on an anti-slavery platform, an initial seven \"\n", + " \"slave states declared their secession from the country \"\n", + " \"to form the Confederacy. War broke out in April 1861 \"\n", + " \"when secessionist forces attacked Fort Sumter in South \"\n", + " \"Carolina, just over a month after Lincoln's \"\n", + " \"inauguration.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T10:39:03.903713Z", + "start_time": "2022-10-23T10:39:03.894017Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\"After Abraham Lincoln won the November 1860 presidential election on an anti-slavery platform, an initial seven slave states declared their secession from the country to form the Confederacy. 
War broke out in April 1861 when secessionist forces attacked Fort Sumter in South Carolina, just over a month after Lincoln's inauguration.\"" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:58:49.006117Z", + "start_time": "2022-10-23T11:58:49.001261Z" + } + }, + "outputs": [], + "source": [ + "inputs = tokenizer(text, return_tensors='pt')" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:52:29.260747Z", + "start_time": "2022-10-23T11:52:29.254577Z" + }, + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'input_ids': tensor([[ 101, 2044, 8181, 5367, 2180, 1996, 2281, 7313, 4883, 2602,\n", + " 2006, 2019, 3424, 1011, 8864, 4132, 1010, 2019, 3988, 2698,\n", + " 6658, 2163, 4161, 2037, 22965, 2013, 1996, 2406, 2000, 2433,\n", + " 1996, 18179, 1012, 2162, 3631, 2041, 1999, 2258, 6863, 2043,\n", + " 22965, 2923, 2749, 4457, 3481, 7680, 3334, 1999, 2148, 3792,\n", + " 1010, 2074, 2058, 1037, 3204, 2044, 5367, 1005, 1055, 17331,\n", + " 1012, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "inputs" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T10:45:03.577003Z", + "start_time": "2022-10-23T10:45:03.573112Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 62])" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "inputs['input_ids'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T10:41:55.118829Z", + "start_time": "2022-10-23T10:41:55.114199Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]\"" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
masking" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:58:51.251735Z", + "start_time": "2022-10-23T11:58:51.249225Z" + } + }, + "outputs": [], + "source": [ + "inputs['labels'] = inputs['input_ids'].detach().clone()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T10:43:31.904097Z", + "start_time": "2022-10-23T10:43:31.899369Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 101, 2044, 8181, 5367, 2180, 1996, 2281, 7313, 4883, 2602,\n", + " 2006, 2019, 3424, 1011, 8864, 4132, 1010, 2019, 3988, 2698,\n", + " 6658, 2163, 4161, 2037, 22965, 2013, 1996, 2406, 2000, 2433,\n", + " 1996, 18179, 1012, 2162, 3631, 2041, 1999, 2258, 6863, 2043,\n", + " 22965, 2923, 2749, 4457, 3481, 7680, 3334, 1999, 2148, 3792,\n", + " 1010, 2074, 2058, 1037, 3204, 2044, 5367, 1005, 1055, 17331,\n", + " 1012, 102]])" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "inputs['labels']" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:55:03.203643Z", + "start_time": "2022-10-23T11:55:03.200563Z" + } + }, + "outputs": [], + "source": [ + "mask = torch.rand(inputs['input_ids'].shape) < 0.15" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T10:44:20.474076Z", + "start_time": "2022-10-23T10:44:20.469553Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[False, False, True, False, False, False, False, False, False, True,\n", + " True, True, False, False, False, False, False, True, False, False,\n", + " False, False, False, False, False, False, False, False, True, False,\n", + " False, False, True, False, False, False, False, False, False, True,\n", + " False, False, False, False, True, False, False, False, False, False,\n", + " False, False, True, False, False, False, False, False, False, False,\n", + " True, False]])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mask" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T10:44:29.746724Z", + "start_time": "2022-10-23T10:44:29.741795Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(11)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sum(mask[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T10:45:09.327613Z", + "start_time": "2022-10-23T10:45:09.323922Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "0.1774193548387097" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "11/62" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:58:55.364837Z", + "start_time": "2022-10-23T11:58:55.360049Z" + } + }, + "outputs": [], + "source": [ + "mask_arr = (torch.rand(inputs['input_ids'].shape) < 0.15) \\\n", + " * (inputs['input_ids'] != 101) \\\n", + " * (inputs['input_ids'] != 102)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:55:21.753212Z", + "start_time": 
"2022-10-23T11:55:21.748798Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[False, False, False, False, False, False, False, True, False, False,\n", + " False, False, False, False, False, False, False, False, True, False,\n", + " True, False, False, False, True, False, False, False, False, False,\n", + " False, False, False, True, False, False, False, False, False, True,\n", + " True, False, False, False, False, False, False, True, False, False,\n", + " True, False, True, False, False, False, True, False, False, False,\n", + " False, False]])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mask_arr" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:55:30.201772Z", + "start_time": "2022-10-23T11:55:30.197165Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(11)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sum(mask_arr[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:58:57.715475Z", + "start_time": "2022-10-23T11:58:57.712828Z" + } + }, + "outputs": [], + "source": [ + "selection = torch.flatten(mask_arr[0].nonzero()).tolist()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:55:40.122327Z", + "start_time": "2022-10-23T11:55:40.118662Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[7, 18, 20, 24, 33, 39, 40, 47, 50, 52, 56]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "selection" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:55:53.650978Z", + "start_time": "2022-10-23T11:55:53.646182Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'unk_token': '[UNK]',\n", + " 'sep_token': '[SEP]',\n", + " 'pad_token': '[PAD]',\n", + " 'cls_token': '[CLS]',\n", + " 'mask_token': '[MASK]'}" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.special_tokens_map" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:56:08.158223Z", + "start_time": "2022-10-23T11:56:08.154330Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "103" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.vocab['[MASK]']" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:59:01.275756Z", + "start_time": "2022-10-23T11:59:01.272924Z" + } + }, + "outputs": [], + "source": [ + "inputs['input_ids'][0, selection] = 103" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:59:02.884794Z", + "start_time": "2022-10-23T11:59:02.878431Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'input_ids': tensor([[ 101, 2044, 8181, 5367, 2180, 1996, 2281, 103, 4883, 2602,\n", + " 2006, 2019, 3424, 1011, 8864, 4132, 1010, 2019, 3988, 2698,\n", + " 6658, 2163, 4161, 2037, 22965, 2013, 1996, 2406, 2000, 2433,\n", + " 1996, 18179, 1012, 2162, 3631, 103, 1999, 103, 6863, 2043,\n", + " 22965, 2923, 2749, 
4457, 3481, 7680, 3334, 1999, 2148, 3792,\n", + " 1010, 103, 2058, 1037, 3204, 2044, 5367, 103, 1055, 17331,\n", + " 1012, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[ 101, 2044, 8181, 5367, 2180, 1996, 2281, 7313, 4883, 2602,\n", + " 2006, 2019, 3424, 1011, 8864, 4132, 1010, 2019, 3988, 2698,\n", + " 6658, 2163, 4161, 2037, 22965, 2013, 1996, 2406, 2000, 2433,\n", + " 1996, 18179, 1012, 2162, 3631, 2041, 1999, 2258, 6863, 2043,\n", + " 22965, 2923, 2749, 4457, 3481, 7680, 3334, 1999, 2148, 3792,\n", + " 1010, 2074, 2058, 1037, 3204, 2044, 5367, 1005, 1055, 17331,\n", + " 1012, 102]])}" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "inputs" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:59:07.203048Z", + "start_time": "2022-10-23T11:59:07.197865Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'[CLS] after abraham lincoln won the november [MASK] presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke [MASK] in [MASK] 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , [MASK] over a month after lincoln [MASK] s inauguration . [SEP]'" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:59:09.922859Z", + "start_time": "2022-10-23T11:59:09.918001Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]\"" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
forward and calculate loss" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:00:19.169755Z", + "start_time": "2022-10-23T12:00:18.879776Z" + } + }, + "outputs": [], + "source": [ + "mlm.eval()\n", + "with torch.no_grad():\n", + " output = mlm(**inputs)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:00:26.763079Z", + "start_time": "2022-10-23T12:00:26.759090Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "odict_keys(['loss', 'logits', 'hidden_states'])" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:00:33.102388Z", + "start_time": "2022-10-23T12:00:33.096135Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ -7.4276, -7.3447, -7.4013, ..., -6.5880, -6.5265, -4.6203],\n", + " [-11.9590, -11.8243, -12.0316, ..., -11.1553, -10.6757, -8.8617],\n", + " [ -5.8531, -6.0324, -5.5261, ..., -5.8973, -5.6533, -4.9494],\n", + " ...,\n", + " [ -4.3205, -4.3884, -4.2894, ..., -3.0957, -2.8461, -8.2620],\n", + " [-14.4766, -14.4744, -14.4897, ..., -11.6094, -11.7776, -9.8620],\n", + " [-11.2059, -11.7678, -11.3313, ..., -10.9919, -8.8702, -9.4242]]])" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output.logits" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:00:36.403750Z", + "start_time": "2022-10-23T12:00:36.399118Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.5636)" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output.loss" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:01:00.247788Z", + "start_time": "2022-10-23T12:01:00.243721Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "13" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(output['hidden_states'])" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:01:58.138428Z", + "start_time": "2022-10-23T12:01:58.132733Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[-0.3974, 0.0558, -0.3905, ..., -0.2893, -0.1375, 0.3952],\n", + " [-0.6964, -0.0369, 0.2051, ..., -0.4537, 0.1505, 0.5892],\n", + " [-0.6722, 1.0108, -0.7013, ..., -0.6308, -0.2771, 0.3940],\n", + " ...,\n", + " [-0.5778, 0.5753, -0.5293, ..., -0.7302, -0.5109, 1.3849],\n", + " [ 0.5438, 0.0137, -0.3779, ..., 0.1812, -0.6194, -0.1336],\n", + " [-0.4519, -0.3448, -1.0264, ..., -0.1259, -0.4856, 0.3235]]])" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output['hidden_states'][-1]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. 
from scratch" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:02:16.279415Z", + "start_time": "2022-10-23T12:02:16.213720Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ -7.4276, -7.3447, -7.4013, ..., -6.5880, -6.5265, -4.6203],\n", + " [-11.9590, -11.8243, -12.0316, ..., -11.1553, -10.6757, -8.8617],\n", + " [ -5.8531, -6.0324, -5.5261, ..., -5.8973, -5.6533, -4.9494],\n", + " ...,\n", + " [ -4.3205, -4.3884, -4.2894, ..., -3.0957, -2.8461, -8.2620],\n", + " [-14.4766, -14.4744, -14.4897, ..., -11.6094, -11.7776, -9.8620],\n", + " [-11.2059, -11.7678, -11.3313, ..., -10.9919, -8.8702, -9.4242]]],\n", + " grad_fn=<AddBackward0>)" + ] + }, + "execution_count": 41, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mlm.cls(output['hidden_states'][-1])" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:02:24.413213Z", + "start_time": "2022-10-23T12:02:24.408190Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ -7.4276, -7.3447, -7.4013, ..., -6.5880, -6.5265, -4.6203],\n", + " [-11.9590, -11.8243, -12.0316, ..., -11.1553, -10.6757, -8.8617],\n", + " [ -5.8531, -6.0324, -5.5261, ..., -5.8973, -5.6533, -4.9494],\n", + " ...,\n", + " [ -4.3205, -4.3884, -4.2894, ..., -3.0957, -2.8461, -8.2620],\n", + " [-14.4766, -14.4744, -14.4897, ..., -11.6094, -11.7776, -9.8620],\n", + " [-11.2059, -11.7678, -11.3313, ..., -10.9919, -8.8702, -9.4242]]])" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output.logits" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:02:46.240782Z", + "start_time": "2022-10-23T12:02:46.238325Z" + } + }, + "outputs": [], + "source": [ + "last_hidden_state = output['hidden_states'][-1]" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:04:31.581726Z", + "start_time": "2022-10-23T12:04:31.577795Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 62, 768])" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "last_hidden_state.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:06:02.323620Z", + "start_time": "2022-10-23T12:06:02.246346Z" + }, + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 62, 768])\n", + "torch.Size([1, 62, 30522])\n" + ] + }, + { + "data": { + "text/plain": [ + "tensor([[[ -7.4276, -7.3447, -7.4013, ..., -6.5880, -6.5265, -4.6203],\n", + " [-11.9590, -11.8243, -12.0316, ..., -11.1553, -10.6757, -8.8617],\n", + " [ -5.8531, -6.0324, -5.5261, ..., -5.8973, -5.6533, -4.9494],\n", + " ...,\n", + " [ -4.3205, -4.3884, -4.2894, ..., -3.0957, -2.8461, -8.2620],\n", + " [-14.4766, -14.4744, -14.4897, ..., -11.6094, -11.7776, -9.8620],\n", + " [-11.2059, -11.7678, -11.3313, ..., -10.9919, -8.8702, -9.4242]]])" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mlm.eval()\n", + "with torch.no_grad():\n", + " transformed = mlm.cls.predictions.transform(last_hidden_state)\n", + " print(transformed.shape)\n", + " logits = 
mlm.cls.predictions.decoder(transformed)\n", + " print(logits.shape)\n", + "logits" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:08:21.847190Z", + "start_time": "2022-10-23T12:08:21.842330Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.5636)" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output.loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. loss and translate" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:08:26.111715Z", + "start_time": "2022-10-23T12:08:26.108950Z" + } + }, + "outputs": [], + "source": [ + "ce = nn.CrossEntropyLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:08:30.584319Z", + "start_time": "2022-10-23T12:08:30.579982Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 62, 30522])" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "logits.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:08:32.011230Z", + "start_time": "2022-10-23T12:08:32.007146Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 62])" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "inputs['labels'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:08:38.754098Z", + "start_time": "2022-10-23T12:08:38.748962Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([62])" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "inputs['labels'][0].view(-1).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:08:43.837921Z", + "start_time": "2022-10-23T12:08:43.826163Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.5636)" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ce(logits[0], inputs['labels'][0].view(-1))" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:45:20.920638Z", + "start_time": "2022-10-23T11:45:20.914885Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 1012, 2044, 8181, 5367, 2180, 1996, 2281, 7313, 4883, 2602,\n", + " 2006, 2019, 3424, 1011, 8864, 4132, 1010, 2035, 1996, 2698,\n", + " 6658, 2163, 4161, 2037, 22965, 2013, 1996, 2586, 2000, 3693,\n", + " 1996, 18179, 1012, 2162, 3631, 2034, 1999, 2258, 6863, 2043,\n", + " 22965, 2923, 2749, 4457, 3481, 7680, 3334, 1999, 2148, 3792,\n", + " 1010, 2074, 2058, 1037, 3204, 2044, 5367, 1005, 1055, 17331,\n", + " 1012, 3519])" + ] + }, + "execution_count": 89, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.argmax(logits[0], dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T11:45:59.718848Z", + "start_time": "2022-10-23T11:45:59.714203Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + 
"\"[CLS] after abraham lincoln won the november 1860 [MASK] election [MASK] an anti - slavery platform , [MASK] [MASK] seven slave [MASK] declared their secession from the [MASK] to [MASK] the confederacy . war broke out [MASK] april 1861 when [MASK] ##ist forces attacked fort sum [MASK] in south [MASK] , just over a month after lincoln ' s [MASK] . [SEP]\"" + ] + }, + "execution_count": 92, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:09:36.590470Z", + "start_time": "2022-10-23T12:09:36.584581Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\". after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in december 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . s\"" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "' '.join(tokenizer.convert_ids_to_tokens(torch.argmax(logits[0], dim=1)))" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": { + "ExecuteTime": { + "end_time": "2022-10-23T12:09:47.275769Z", + "start_time": "2022-10-23T12:09:47.271126Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]\"" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0]))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/fine_tune/bert/tutorials/09_mlm.py b/fine_tune/bert/tutorials/09_mlm.py new file mode 100644 index 0000000..0177c3f --- /dev/null +++ b/fine_tune/bert/tutorials/09_mlm.py @@ -0,0 +1,33 @@ + +import torch +from transformers.models.bert import BertModel, BertTokenizer, BertForMaskedLM + + +model_type = 'bert-base-uncased' + +tokenizer = BertTokenizer.from_pretrained(model_type) +bert = BertModel.from_pretrained(model_type) +mlm = BertForMaskedLM.from_pretrained(model_type, output_hidden_states=True) + + +text = ("After Abraham Lincoln won the November 1860 presidential " + "election on an anti-slavery platform, an initial seven " + "slave states declared their secession from the country " + "to form the Confederacy. 
War broke out in April 1861 " + "when secessionist forces attacked Fort Sumter in South " + "Carolina, just over a month after Lincoln's " + "inauguration.") + +inputs = tokenizer(text, return_tensors='pt') +inputs['labels'] = inputs['input_ids'].detach().clone() + +mask_arr = (torch.rand(inputs['input_ids'].shape) < 0.15) \ + * (inputs['input_ids'] != 101) \ + * (inputs['input_ids'] != 102) +selection = torch.flatten(mask_arr[0].nonzero()).tolist() +inputs['input_ids'][0, selection] = 103 + +mlm.eval() +with torch.no_grad(): + output = mlm(**inputs) +print()
\ No newline at end of file
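
A possible extension of 09_mlm.py (a minimal sketch, not part of the commit, reusing the inputs, selection and mlm objects defined above): BertForMaskedLM ignores label positions set to -100, the default ignore_index of CrossEntropyLoss, so the loss can be restricted to the masked tokens instead of being averaged over every position as in the notebook.

    masked_only_labels = torch.full_like(inputs['labels'], -100)        # -100 positions are ignored by the loss
    masked_only_labels[0, selection] = inputs['labels'][0, selection]   # keep targets only where [MASK] was inserted

    mlm.eval()
    with torch.no_grad():
        out = mlm(input_ids=inputs['input_ids'],
                  token_type_ids=inputs['token_type_ids'],
                  attention_mask=inputs['attention_mask'],
                  labels=masked_only_labels)
    print(out.loss)   # cross-entropy averaged over the masked positions only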

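Similarly, a sketch (same assumptions; output is the forward-pass result from the script) for inspecting the top candidate tokens at each masked position instead of the single argmax decode used at the end of the notebook:

    probs = output.logits[0].softmax(dim=-1)                  # (seq_len, vocab_size) token probabilities
    for pos in selection:                                     # positions that were replaced with [MASK]
        top5 = torch.topk(probs[pos], k=5)
        print(pos, tokenizer.convert_ids_to_tokens(top5.indices.tolist()))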