Diffstat (limited to 'fine_tune')
-rw-r--r--  fine_tune/bert/tutorials/09_masked_lm.ipynb  1862
-rw-r--r--  fine_tune/bert/tutorials/09_mlm.py              33
2 files changed, 1895 insertions, 0 deletions
diff --git a/fine_tune/bert/tutorials/09_masked_lm.ipynb b/fine_tune/bert/tutorials/09_masked_lm.ipynb
new file mode 100644
index 0000000..ffd1211
--- /dev/null
+++ b/fine_tune/bert/tutorials/09_masked_lm.ipynb
@@ -0,0 +1,1862 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:50:16.399172Z",
+ "start_time": "2022-10-23T11:50:07.776148Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from torch import nn\n",
+ "from transformers.models.bert import BertModel, BertTokenizer, BertForMaskedLM"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. model load and data preprocessing"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:50:21.521103Z",
+ "start_time": "2022-10-23T11:50:21.517814Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "model_type = 'bert-base-uncased'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:50:42.915776Z",
+ "start_time": "2022-10-23T11:50:24.836326Z"
+ },
+ "collapsed": true
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']\n",
+ "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
+ "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+ ]
+ }
+ ],
+ "source": [
+ "tokenizer = BertTokenizer.from_pretrained(model_type)\n",
+ "bert = BertModel.from_pretrained(model_type)\n",
+ "mlm = BertForMaskedLM.from_pretrained(model_type, output_hidden_states=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:50:48.950455Z",
+ "start_time": "2022-10-23T11:50:48.938428Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "BertModel(\n",
+ " (embeddings): BertEmbeddings(\n",
+ " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
+ " (position_embeddings): Embedding(512, 768)\n",
+ " (token_type_embeddings): Embedding(2, 768)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (encoder): BertEncoder(\n",
+ " (layer): ModuleList(\n",
+ " (0): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (1): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (2): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (3): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (4): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (5): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (6): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (7): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (8): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (9): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (10): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (11): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (pooler): BertPooler(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (activation): Tanh()\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "bert"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:50:53.507897Z",
+ "start_time": "2022-10-23T11:50:53.500990Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "BertForMaskedLM(\n",
+ " (bert): BertModel(\n",
+ " (embeddings): BertEmbeddings(\n",
+ " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
+ " (position_embeddings): Embedding(512, 768)\n",
+ " (token_type_embeddings): Embedding(2, 768)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (encoder): BertEncoder(\n",
+ " (layer): ModuleList(\n",
+ " (0): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (1): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (2): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (3): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (4): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (5): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (6): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (7): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (8): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (9): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (10): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (11): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (cls): BertOnlyMLMHead(\n",
+ " (predictions): BertLMPredictionHead(\n",
+ " (transform): BertPredictionHeadTransform(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " )\n",
+ " (decoder): Linear(in_features=768, out_features=30522, bias=True)\n",
+ " )\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "mlm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:52:15.652759Z",
+ "start_time": "2022-10-23T11:52:15.650142Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "text = (\"After Abraham Lincoln won the November 1860 presidential \"\n",
+ " \"election on an anti-slavery platform, an initial seven \"\n",
+ " \"slave states declared their secession from the country \"\n",
+ " \"to form the Confederacy. War broke out in April 1861 \"\n",
+ " \"when secessionist forces attacked Fort Sumter in South \"\n",
+ " \"Carolina, just over a month after Lincoln's \"\n",
+ " \"inauguration.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T10:39:03.903713Z",
+ "start_time": "2022-10-23T10:39:03.894017Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\"After Abraham Lincoln won the November 1860 presidential election on an anti-slavery platform, an initial seven slave states declared their secession from the country to form the Confederacy. War broke out in April 1861 when secessionist forces attacked Fort Sumter in South Carolina, just over a month after Lincoln's inauguration.\""
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "text"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:58:49.006117Z",
+ "start_time": "2022-10-23T11:58:49.001261Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "inputs = tokenizer(text, return_tensors='pt')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:52:29.260747Z",
+ "start_time": "2022-10-23T11:52:29.254577Z"
+ },
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'input_ids': tensor([[ 101, 2044, 8181, 5367, 2180, 1996, 2281, 7313, 4883, 2602,\n",
+ " 2006, 2019, 3424, 1011, 8864, 4132, 1010, 2019, 3988, 2698,\n",
+ " 6658, 2163, 4161, 2037, 22965, 2013, 1996, 2406, 2000, 2433,\n",
+ " 1996, 18179, 1012, 2162, 3631, 2041, 1999, 2258, 6863, 2043,\n",
+ " 22965, 2923, 2749, 4457, 3481, 7680, 3334, 1999, 2148, 3792,\n",
+ " 1010, 2074, 2058, 1037, 3204, 2044, 5367, 1005, 1055, 17331,\n",
+ " 1012, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "inputs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T10:45:03.577003Z",
+ "start_time": "2022-10-23T10:45:03.573112Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 62])"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "inputs['input_ids'].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T10:41:55.118829Z",
+ "start_time": "2022-10-23T10:41:55.114199Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]\""
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. masking"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:58:51.251735Z",
+ "start_time": "2022-10-23T11:58:51.249225Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "inputs['labels'] = inputs['input_ids'].detach().clone()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T10:43:31.904097Z",
+ "start_time": "2022-10-23T10:43:31.899369Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[ 101, 2044, 8181, 5367, 2180, 1996, 2281, 7313, 4883, 2602,\n",
+ " 2006, 2019, 3424, 1011, 8864, 4132, 1010, 2019, 3988, 2698,\n",
+ " 6658, 2163, 4161, 2037, 22965, 2013, 1996, 2406, 2000, 2433,\n",
+ " 1996, 18179, 1012, 2162, 3631, 2041, 1999, 2258, 6863, 2043,\n",
+ " 22965, 2923, 2749, 4457, 3481, 7680, 3334, 1999, 2148, 3792,\n",
+ " 1010, 2074, 2058, 1037, 3204, 2044, 5367, 1005, 1055, 17331,\n",
+ " 1012, 102]])"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "inputs['labels']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:55:03.203643Z",
+ "start_time": "2022-10-23T11:55:03.200563Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "mask = torch.rand(inputs['input_ids'].shape) < 0.15"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T10:44:20.474076Z",
+ "start_time": "2022-10-23T10:44:20.469553Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[False, False, True, False, False, False, False, False, False, True,\n",
+ " True, True, False, False, False, False, False, True, False, False,\n",
+ " False, False, False, False, False, False, False, False, True, False,\n",
+ " False, False, True, False, False, False, False, False, False, True,\n",
+ " False, False, False, False, True, False, False, False, False, False,\n",
+ " False, False, True, False, False, False, False, False, False, False,\n",
+ " True, False]])"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "mask"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T10:44:29.746724Z",
+ "start_time": "2022-10-23T10:44:29.741795Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(11)"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sum(mask[0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T10:45:09.327613Z",
+ "start_time": "2022-10-23T10:45:09.323922Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0.1774193548387097"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "11/62"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:58:55.364837Z",
+ "start_time": "2022-10-23T11:58:55.360049Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "mask_arr = (torch.rand(inputs['input_ids'].shape) < 0.15) \\\n",
+ " * (inputs['input_ids'] != 101) \\\n",
+ " * (inputs['input_ids'] != 102)"
+ ]
+ },
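+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The hard-coded ids 101 and 102 above are this tokenizer's `[CLS]` and `[SEP]` ids. An equivalent sketch without the magic numbers (written to an illustrative `mask_arr_alt` so the mask actually used below is left untouched):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Same 15% random mask, but using the tokenizer's special-token ids instead of 101/102.\n",
+ "mask_arr_alt = (torch.rand(inputs['input_ids'].shape) < 0.15) \\\n",
+ " * (inputs['input_ids'] != tokenizer.cls_token_id) \\\n",
+ " * (inputs['input_ids'] != tokenizer.sep_token_id)"
+ ]
+ },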
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:55:21.753212Z",
+ "start_time": "2022-10-23T11:55:21.748798Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[False, False, False, False, False, False, False, True, False, False,\n",
+ " False, False, False, False, False, False, False, False, True, False,\n",
+ " True, False, False, False, True, False, False, False, False, False,\n",
+ " False, False, False, True, False, False, False, False, False, True,\n",
+ " True, False, False, False, False, False, False, True, False, False,\n",
+ " True, False, True, False, False, False, True, False, False, False,\n",
+ " False, False]])"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "mask_arr"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:55:30.201772Z",
+ "start_time": "2022-10-23T11:55:30.197165Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(11)"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sum(mask_arr[0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:58:57.715475Z",
+ "start_time": "2022-10-23T11:58:57.712828Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "selection = torch.flatten(mask_arr[0].nonzero()).tolist()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:55:40.122327Z",
+ "start_time": "2022-10-23T11:55:40.118662Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[7, 18, 20, 24, 33, 39, 40, 47, 50, 52, 56]"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "selection"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:55:53.650978Z",
+ "start_time": "2022-10-23T11:55:53.646182Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'unk_token': '[UNK]',\n",
+ " 'sep_token': '[SEP]',\n",
+ " 'pad_token': '[PAD]',\n",
+ " 'cls_token': '[CLS]',\n",
+ " 'mask_token': '[MASK]'}"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.special_tokens_map"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:56:08.158223Z",
+ "start_time": "2022-10-23T11:56:08.154330Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "103"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.vocab['[MASK]']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:59:01.275756Z",
+ "start_time": "2022-10-23T11:59:01.272924Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "inputs['input_ids'][0, selection] = 103"
+ ]
+ },
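+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Every selected position was replaced with `[MASK]` above. The original BERT pre-training recipe instead corrupts the selected 15% with an 80/10/10 split: 80% become `[MASK]`, 10% become a random token, and 10% are left unchanged. A minimal sketch of that variant, assuming those standard ratios and working on a copy (`alt_ids`) so the `inputs` built above stay untouched:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch of BERT's 80/10/10 corruption scheme, applied to a copy of the clean ids.\n",
+ "alt_ids = inputs['labels'].detach().clone()  # labels still hold the uncorrupted ids\n",
+ "probs = torch.rand(len(selection))\n",
+ "for p, idx in zip(probs.tolist(), selection):\n",
+ "    if p < 0.8:    # 80%: replace with [MASK]\n",
+ "        alt_ids[0, idx] = tokenizer.mask_token_id\n",
+ "    elif p < 0.9:  # 10%: replace with a random vocabulary id\n",
+ "        alt_ids[0, idx] = torch.randint(len(tokenizer), (1,)).item()\n",
+ "    # remaining 10%: keep the original token\n",
+ "' '.join(tokenizer.convert_ids_to_tokens(alt_ids[0]))"
+ ]
+ },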
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:59:02.884794Z",
+ "start_time": "2022-10-23T11:59:02.878431Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'input_ids': tensor([[ 101, 2044, 8181, 5367, 2180, 1996, 2281, 103, 4883, 2602,\n",
+ " 2006, 2019, 3424, 1011, 8864, 4132, 1010, 2019, 3988, 2698,\n",
+ " 6658, 2163, 4161, 2037, 22965, 2013, 1996, 2406, 2000, 2433,\n",
+ " 1996, 18179, 1012, 2162, 3631, 103, 1999, 103, 6863, 2043,\n",
+ " 22965, 2923, 2749, 4457, 3481, 7680, 3334, 1999, 2148, 3792,\n",
+ " 1010, 103, 2058, 1037, 3204, 2044, 5367, 103, 1055, 17331,\n",
+ " 1012, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[ 101, 2044, 8181, 5367, 2180, 1996, 2281, 7313, 4883, 2602,\n",
+ " 2006, 2019, 3424, 1011, 8864, 4132, 1010, 2019, 3988, 2698,\n",
+ " 6658, 2163, 4161, 2037, 22965, 2013, 1996, 2406, 2000, 2433,\n",
+ " 1996, 18179, 1012, 2162, 3631, 2041, 1999, 2258, 6863, 2043,\n",
+ " 22965, 2923, 2749, 4457, 3481, 7680, 3334, 1999, 2148, 3792,\n",
+ " 1010, 2074, 2058, 1037, 3204, 2044, 5367, 1005, 1055, 17331,\n",
+ " 1012, 102]])}"
+ ]
+ },
+ "execution_count": 30,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "inputs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:59:07.203048Z",
+ "start_time": "2022-10-23T11:59:07.197865Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'[CLS] after abraham lincoln won the november [MASK] presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke [MASK] in [MASK] 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , [MASK] over a month after lincoln [MASK] s inauguration . [SEP]'"
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:59:09.922859Z",
+ "start_time": "2022-10-23T11:59:09.918001Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]\""
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. forward and calculate loss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:00:19.169755Z",
+ "start_time": "2022-10-23T12:00:18.879776Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "mlm.eval()\n",
+ "with torch.no_grad():\n",
+ " output = mlm(**inputs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:00:26.763079Z",
+ "start_time": "2022-10-23T12:00:26.759090Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "odict_keys(['loss', 'logits', 'hidden_states'])"
+ ]
+ },
+ "execution_count": 35,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "output.keys()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:00:33.102388Z",
+ "start_time": "2022-10-23T12:00:33.096135Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[ -7.4276, -7.3447, -7.4013, ..., -6.5880, -6.5265, -4.6203],\n",
+ " [-11.9590, -11.8243, -12.0316, ..., -11.1553, -10.6757, -8.8617],\n",
+ " [ -5.8531, -6.0324, -5.5261, ..., -5.8973, -5.6533, -4.9494],\n",
+ " ...,\n",
+ " [ -4.3205, -4.3884, -4.2894, ..., -3.0957, -2.8461, -8.2620],\n",
+ " [-14.4766, -14.4744, -14.4897, ..., -11.6094, -11.7776, -9.8620],\n",
+ " [-11.2059, -11.7678, -11.3313, ..., -10.9919, -8.8702, -9.4242]]])"
+ ]
+ },
+ "execution_count": 36,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "output.logits"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:00:36.403750Z",
+ "start_time": "2022-10-23T12:00:36.399118Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(0.5636)"
+ ]
+ },
+ "execution_count": 37,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "output.loss"
+ ]
+ },
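+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Because `labels` is a full copy of `input_ids` (no position was set to `-100`), this loss is averaged over all 62 positions, not only the `[MASK]` ones. In a typical MLM training loop the non-masked labels are set to `-100`, the default `ignore_index` of `CrossEntropyLoss`, so that only the masked positions contribute. A minimal sketch of that convention, reusing `selection` (the name `masked_labels` is just for illustration):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: score only the masked positions by setting every other label to -100.\n",
+ "masked_labels = inputs['labels'].detach().clone()\n",
+ "keep = torch.zeros_like(masked_labels, dtype=torch.bool)\n",
+ "keep[0, selection] = True\n",
+ "masked_labels[~keep] = -100\n",
+ "\n",
+ "with torch.no_grad():\n",
+ "    masked_out = mlm(input_ids=inputs['input_ids'],\n",
+ "                     token_type_ids=inputs['token_type_ids'],\n",
+ "                     attention_mask=inputs['attention_mask'],\n",
+ "                     labels=masked_labels)\n",
+ "masked_out.loss  # loss over the [MASK] positions only"
+ ]
+ },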
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:01:00.247788Z",
+ "start_time": "2022-10-23T12:01:00.243721Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "13"
+ ]
+ },
+ "execution_count": 39,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(output['hidden_states'])"
+ ]
+ },
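+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The 13 tensors are the embedding-layer output followed by the outputs of the 12 encoder layers; they are returned because the model was loaded with `output_hidden_states=True`. Each one keeps the same `(batch, sequence, hidden)` shape:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Embedding output + 12 encoder layers = 13 hidden-state tensors, each (1, 62, 768).\n",
+ "for i, h in enumerate(output['hidden_states']):\n",
+ "    print(i, tuple(h.shape))"
+ ]
+ },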
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:01:58.138428Z",
+ "start_time": "2022-10-23T12:01:58.132733Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[-0.3974, 0.0558, -0.3905, ..., -0.2893, -0.1375, 0.3952],\n",
+ " [-0.6964, -0.0369, 0.2051, ..., -0.4537, 0.1505, 0.5892],\n",
+ " [-0.6722, 1.0108, -0.7013, ..., -0.6308, -0.2771, 0.3940],\n",
+ " ...,\n",
+ " [-0.5778, 0.5753, -0.5293, ..., -0.7302, -0.5109, 1.3849],\n",
+ " [ 0.5438, 0.0137, -0.3779, ..., 0.1812, -0.6194, -0.1336],\n",
+ " [-0.4519, -0.3448, -1.0264, ..., -0.1259, -0.4856, 0.3235]]])"
+ ]
+ },
+ "execution_count": 40,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "output['hidden_states'][-1]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. from scratch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:02:16.279415Z",
+ "start_time": "2022-10-23T12:02:16.213720Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[ -7.4276, -7.3447, -7.4013, ..., -6.5880, -6.5265, -4.6203],\n",
+ " [-11.9590, -11.8243, -12.0316, ..., -11.1553, -10.6757, -8.8617],\n",
+ " [ -5.8531, -6.0324, -5.5261, ..., -5.8973, -5.6533, -4.9494],\n",
+ " ...,\n",
+ " [ -4.3205, -4.3884, -4.2894, ..., -3.0957, -2.8461, -8.2620],\n",
+ " [-14.4766, -14.4744, -14.4897, ..., -11.6094, -11.7776, -9.8620],\n",
+ " [-11.2059, -11.7678, -11.3313, ..., -10.9919, -8.8702, -9.4242]]],\n",
+ " grad_fn=<AddBackward0>)"
+ ]
+ },
+ "execution_count": 41,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "mlm.cls(output['hidden_states'][-1])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:02:24.413213Z",
+ "start_time": "2022-10-23T12:02:24.408190Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[ -7.4276, -7.3447, -7.4013, ..., -6.5880, -6.5265, -4.6203],\n",
+ " [-11.9590, -11.8243, -12.0316, ..., -11.1553, -10.6757, -8.8617],\n",
+ " [ -5.8531, -6.0324, -5.5261, ..., -5.8973, -5.6533, -4.9494],\n",
+ " ...,\n",
+ " [ -4.3205, -4.3884, -4.2894, ..., -3.0957, -2.8461, -8.2620],\n",
+ " [-14.4766, -14.4744, -14.4897, ..., -11.6094, -11.7776, -9.8620],\n",
+ " [-11.2059, -11.7678, -11.3313, ..., -10.9919, -8.8702, -9.4242]]])"
+ ]
+ },
+ "execution_count": 42,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "output.logits"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:02:46.240782Z",
+ "start_time": "2022-10-23T12:02:46.238325Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "last_hidden_state = output['hidden_states'][-1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:04:31.581726Z",
+ "start_time": "2022-10-23T12:04:31.577795Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 62, 768])"
+ ]
+ },
+ "execution_count": 44,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "last_hidden_state.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:06:02.323620Z",
+ "start_time": "2022-10-23T12:06:02.246346Z"
+ },
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([1, 62, 768])\n",
+ "torch.Size([1, 62, 30522])\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[ -7.4276, -7.3447, -7.4013, ..., -6.5880, -6.5265, -4.6203],\n",
+ " [-11.9590, -11.8243, -12.0316, ..., -11.1553, -10.6757, -8.8617],\n",
+ " [ -5.8531, -6.0324, -5.5261, ..., -5.8973, -5.6533, -4.9494],\n",
+ " ...,\n",
+ " [ -4.3205, -4.3884, -4.2894, ..., -3.0957, -2.8461, -8.2620],\n",
+ " [-14.4766, -14.4744, -14.4897, ..., -11.6094, -11.7776, -9.8620],\n",
+ " [-11.2059, -11.7678, -11.3313, ..., -10.9919, -8.8702, -9.4242]]])"
+ ]
+ },
+ "execution_count": 46,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "mlm.eval()\n",
+ "with torch.no_grad():\n",
+ " transformed = mlm.cls.predictions.transform(last_hidden_state)\n",
+ " print(transformed.shape)\n",
+ " logits = mlm.cls.predictions.decoder(transformed)\n",
+ " print(logits.shape)\n",
+ "logits"
+ ]
+ },
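+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "So `mlm.cls` is just a dense transform followed by a decoder that projects the 768-d hidden states back onto the 30522-token vocabulary. In the default Hugging Face configuration that decoder's weight matrix is tied to the input word embeddings (only its bias is separate); a quick check, assuming the default tying:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: check whether the MLM decoder shares its weights with the word embeddings.\n",
+ "decoder_w = mlm.cls.predictions.decoder.weight\n",
+ "embedding_w = mlm.bert.embeddings.word_embeddings.weight\n",
+ "print(decoder_w.shape, embedding_w.shape)  # both torch.Size([30522, 768])\n",
+ "print(decoder_w.data_ptr() == embedding_w.data_ptr())  # True when the weights are tied"
+ ]
+ },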
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:08:21.847190Z",
+ "start_time": "2022-10-23T12:08:21.842330Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(0.5636)"
+ ]
+ },
+ "execution_count": 47,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "output.loss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. loss and translate"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:08:26.111715Z",
+ "start_time": "2022-10-23T12:08:26.108950Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "ce = nn.CrossEntropyLoss()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:08:30.584319Z",
+ "start_time": "2022-10-23T12:08:30.579982Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 62, 30522])"
+ ]
+ },
+ "execution_count": 49,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "logits.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:08:32.011230Z",
+ "start_time": "2022-10-23T12:08:32.007146Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 62])"
+ ]
+ },
+ "execution_count": 50,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "inputs['labels'].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:08:38.754098Z",
+ "start_time": "2022-10-23T12:08:38.748962Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([62])"
+ ]
+ },
+ "execution_count": 51,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "inputs['labels'][0].view(-1).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:08:43.837921Z",
+ "start_time": "2022-10-23T12:08:43.826163Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(0.5636)"
+ ]
+ },
+ "execution_count": 52,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ce(logits[0], inputs['labels'][0].view(-1))"
+ ]
+ },
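+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The manual cross-entropy reproduces `output.loss` exactly because, as noted above, the labels cover every position. To look only at the `[MASK]` positions, the argmax predictions can be compared token by token at the indices in `selection` (a small illustrative sketch):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Compare prediction vs. ground truth at the masked positions only.\n",
+ "pred_ids = torch.argmax(logits[0], dim=1)\n",
+ "for idx in selection:\n",
+ "    true_tok = tokenizer.convert_ids_to_tokens([inputs['labels'][0, idx].item()])[0]\n",
+ "    pred_tok = tokenizer.convert_ids_to_tokens([pred_ids[idx].item()])[0]\n",
+ "    print(f'{idx:>3}  true={true_tok:<12}  pred={pred_tok}')"
+ ]
+ },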
+ {
+ "cell_type": "code",
+ "execution_count": 89,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:45:20.920638Z",
+ "start_time": "2022-10-23T11:45:20.914885Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([ 1012, 2044, 8181, 5367, 2180, 1996, 2281, 7313, 4883, 2602,\n",
+ " 2006, 2019, 3424, 1011, 8864, 4132, 1010, 2035, 1996, 2698,\n",
+ " 6658, 2163, 4161, 2037, 22965, 2013, 1996, 2586, 2000, 3693,\n",
+ " 1996, 18179, 1012, 2162, 3631, 2034, 1999, 2258, 6863, 2043,\n",
+ " 22965, 2923, 2749, 4457, 3481, 7680, 3334, 1999, 2148, 3792,\n",
+ " 1010, 2074, 2058, 1037, 3204, 2044, 5367, 1005, 1055, 17331,\n",
+ " 1012, 3519])"
+ ]
+ },
+ "execution_count": 89,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.argmax(logits[0], dim=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 92,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T11:45:59.718848Z",
+ "start_time": "2022-10-23T11:45:59.714203Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\"[CLS] after abraham lincoln won the november 1860 [MASK] election [MASK] an anti - slavery platform , [MASK] [MASK] seven slave [MASK] declared their secession from the [MASK] to [MASK] the confederacy . war broke out [MASK] april 1861 when [MASK] ##ist forces attacked fort sum [MASK] in south [MASK] , just over a month after lincoln ' s [MASK] . [SEP]\""
+ ]
+ },
+ "execution_count": 92,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 53,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:09:36.590470Z",
+ "start_time": "2022-10-23T12:09:36.584581Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\". after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in december 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . s\""
+ ]
+ },
+ "execution_count": 53,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "' '.join(tokenizer.convert_ids_to_tokens(torch.argmax(logits[0], dim=1)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2022-10-23T12:09:47.275769Z",
+ "start_time": "2022-10-23T12:09:47.271126Z"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]\""
+ ]
+ },
+ "execution_count": 54,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0]))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/fine_tune/bert/tutorials/09_mlm.py b/fine_tune/bert/tutorials/09_mlm.py
new file mode 100644
index 0000000..0177c3f
--- /dev/null
+++ b/fine_tune/bert/tutorials/09_mlm.py
@@ -0,0 +1,33 @@
+
+import torch
+from transformers.models.bert import BertModel, BertTokenizer, BertForMaskedLM
+
+
+model_type = 'bert-base-uncased'
+
+tokenizer = BertTokenizer.from_pretrained(model_type)
+bert = BertModel.from_pretrained(model_type)
+mlm = BertForMaskedLM.from_pretrained(model_type, output_hidden_states=True)
+
+
+text = ("After Abraham Lincoln won the November 1860 presidential "
+ "election on an anti-slavery platform, an initial seven "
+ "slave states declared their secession from the country "
+ "to form the Confederacy. War broke out in April 1861 "
+ "when secessionist forces attacked Fort Sumter in South "
+ "Carolina, just over a month after Lincoln's "
+ "inauguration.")
+
+inputs = tokenizer(text, return_tensors='pt')
+inputs['labels'] = inputs['input_ids'].detach().clone()
+
+# Randomly select ~15% of positions, never masking [CLS] (id 101) or [SEP] (id 102).
+mask_arr = (torch.rand(inputs['input_ids'].shape) < 0.15) \
+    * (inputs['input_ids'] != 101) \
+    * (inputs['input_ids'] != 102)
+selection = torch.flatten(mask_arr[0].nonzero()).tolist()
+# Replace the selected positions with the [MASK] token (id 103).
+inputs['input_ids'][0, selection] = 103
+
+mlm.eval()
+with torch.no_grad():
+    output = mlm(**inputs)
+
+# Since `labels` was included in `inputs`, the output carries the MLM loss.
+print(output.loss)
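+
+# A short follow-up mirroring the notebook: decode the argmax predictions and the
+# clean labels back to tokens for a quick visual comparison (illustration only).
+pred_ids = torch.argmax(output.logits[0], dim=1)
+print(' '.join(tokenizer.convert_ids_to_tokens(pred_ids)))
+print(' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0])))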