{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:50:16.399172Z", "start_time": "2022-10-23T11:50:07.776148Z" } }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from transformers.models.bert import BertModel, BertTokenizer, BertForMaskedLM" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. model load and data preprocessing" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:50:21.521103Z", "start_time": "2022-10-23T11:50:21.517814Z" } }, "outputs": [], "source": [ "model_type = 'bert-base-uncased'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:50:42.915776Z", "start_time": "2022-10-23T11:50:24.836326Z" }, "collapsed": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n", "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] } ], "source": [ "tokenizer = BertTokenizer.from_pretrained(model_type)\n", "bert = BertModel.from_pretrained(model_type)\n", "mlm = BertForMaskedLM.from_pretrained(model_type, output_hidden_states=True)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:50:48.950455Z", "start_time": "2022-10-23T11:50:48.938428Z" } }, "outputs": [ { "data": { "text/plain": [ "BertModel(\n", " (embeddings): BertEmbeddings(\n", " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (token_type_embeddings): Embedding(2, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (encoder): BertEncoder(\n", " (layer): ModuleList(\n", " (0): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, 
bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (1): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (2): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " 
)\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (3): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (4): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): 
LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (5): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (6): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " 
(intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (7): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (8): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): 
BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (9): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (10): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, 
elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (11): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " (pooler): BertPooler(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (activation): Tanh()\n", " )\n", ")" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bert" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:50:53.507897Z", "start_time": "2022-10-23T11:50:53.500990Z" } }, "outputs": [ { "data": { "text/plain": [ "BertForMaskedLM(\n", " (bert): BertModel(\n", " (embeddings): BertEmbeddings(\n", " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n", " (position_embeddings): Embedding(512, 768)\n", " (token_type_embeddings): Embedding(2, 768)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (encoder): BertEncoder(\n", " (layer): ModuleList(\n", " (0): BertLayer(\n", " (attention): BertAttention(\n", " (self): 
BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (1): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (2): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, 
out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (3): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (4): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, 
inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (5): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (6): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " 
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (7): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (8): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " 
(intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (9): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (10): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): 
BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (11): BertLayer(\n", " (attention): BertAttention(\n", " (self): BertSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): BertSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): BertIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " )\n", " (output): BertOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " )\n", " (cls): BertOnlyMLMHead(\n", " (predictions): BertLMPredictionHead(\n", " (transform): BertPredictionHeadTransform(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (decoder): Linear(in_features=768, out_features=30522, bias=True)\n", " )\n", " )\n", ")" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mlm" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:52:15.652759Z", "start_time": "2022-10-23T11:52:15.650142Z" } }, "outputs": [], "source": [ "text = (\"After Abraham Lincoln won the November 1860 presidential \"\n", " \"election on an anti-slavery platform, an initial seven \"\n", " \"slave states 
declared their secession from the country \"\n", " \"to form the Confederacy. War broke out in April 1861 \"\n", " \"when secessionist forces attacked Fort Sumter in South \"\n", " \"Carolina, just over a month after Lincoln's \"\n", " \"inauguration.\")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T10:39:03.903713Z", "start_time": "2022-10-23T10:39:03.894017Z" } }, "outputs": [ { "data": { "text/plain": [ "\"After Abraham Lincoln won the November 1860 presidential election on an anti-slavery platform, an initial seven slave states declared their secession from the country to form the Confederacy. War broke out in April 1861 when secessionist forces attacked Fort Sumter in South Carolina, just over a month after Lincoln's inauguration.\"" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "text" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:58:49.006117Z", "start_time": "2022-10-23T11:58:49.001261Z" } }, "outputs": [], "source": [ "inputs = tokenizer(text, return_tensors='pt')" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:52:29.260747Z", "start_time": "2022-10-23T11:52:29.254577Z" }, "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "{'input_ids': tensor([[ 101, 2044, 8181, 5367, 2180, 1996, 2281, 7313, 4883, 2602,\n", " 2006, 2019, 3424, 1011, 8864, 4132, 1010, 2019, 3988, 2698,\n", " 6658, 2163, 4161, 2037, 22965, 2013, 1996, 2406, 2000, 2433,\n", " 1996, 18179, 1012, 2162, 3631, 2041, 1999, 2258, 6863, 2043,\n", " 22965, 2923, 2749, 4457, 3481, 7680, 3334, 1999, 2148, 3792,\n", " 1010, 2074, 2058, 1037, 3204, 2044, 5367, 1005, 1055, 17331,\n", " 1012, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T10:45:03.577003Z", "start_time": "2022-10-23T10:45:03.573112Z" } }, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 62])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs['input_ids'].shape" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T10:41:55.118829Z", "start_time": "2022-10-23T10:41:55.114199Z" } }, "outputs": [ { "data": { "text/plain": [ "\"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]\"" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 
masking" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:58:51.251735Z", "start_time": "2022-10-23T11:58:51.249225Z" } }, "outputs": [], "source": [ "# MLM targets: an untouched copy of the token ids, taken before any masking.\n", "# NOTE: every position (incl. [CLS]/[SEP] and unmasked tokens) stays a target, so the loss\n", "# computed later averages over all tokens; set unmasked positions to -100 to score only [MASK]ed ones.\n", "inputs['labels'] = inputs['input_ids'].detach().clone()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T10:43:31.904097Z", "start_time": "2022-10-23T10:43:31.899369Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[ 101, 2044, 8181, 5367, 2180, 1996, 2281, 7313, 4883, 2602,\n", " 2006, 2019, 3424, 1011, 8864, 4132, 1010, 2019, 3988, 2698,\n", " 6658, 2163, 4161, 2037, 22965, 2013, 1996, 2406, 2000, 2433,\n", " 1996, 18179, 1012, 2162, 3631, 2041, 1999, 2258, 6863, 2043,\n", " 22965, 2923, 2749, 4457, 3481, 7680, 3334, 1999, 2148, 3792,\n", " 1010, 2074, 2058, 1037, 3204, 2044, 5367, 1005, 1055, 17331,\n", " 1012, 102]])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs['labels']" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:55:03.203643Z", "start_time": "2022-10-23T11:55:03.200563Z" } }, "outputs": [], "source": [ "MLM_PROB = 0.15  # fraction of tokens to mask (the rate used in BERT pre-training)\n", "mask = torch.rand(inputs['input_ids'].shape) < MLM_PROB" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T10:44:20.474076Z", "start_time": "2022-10-23T10:44:20.469553Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[False, False, True, False, False, False, False, False, False, True,\n", " True, True, False, False, False, False, False, True, False, False,\n", " False, False, False, False, False, False, False, False, True, False,\n", " False, False, True, False, False, False, False, False, False, True,\n", " False, False, False, False, True, False, False, False, False, False,\n", " False, False, True, False, False, False, False, False, False, False,\n", " True, False]])" ] }, "execution_count": 17, "metadata": 
{}, "output_type": "execute_result" } ], "source": [ "mask" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T10:44:29.746724Z", "start_time": "2022-10-23T10:44:29.741795Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor(11)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mask[0].sum()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T10:45:09.327613Z", "start_time": "2022-10-23T10:45:09.323922Z" } }, "outputs": [ { "data": { "text/plain": [ "0.1774193548387097" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# observed masking rate: masked positions / total positions\n", "mask.sum().item() / mask.numel()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:58:55.364837Z", "start_time": "2022-10-23T11:58:55.360049Z" } }, "outputs": [], "source": [ "# Sample ~MLM_PROB of the positions, but never pick the special [CLS], [SEP] or [PAD] tokens.\n", "mask_arr = (torch.rand(inputs['input_ids'].shape) < MLM_PROB) \\\n", "    & (inputs['input_ids'] != tokenizer.cls_token_id) \\\n", "    & (inputs['input_ids'] != tokenizer.sep_token_id) \\\n", "    & (inputs['input_ids'] != tokenizer.pad_token_id)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:55:21.753212Z", "start_time": "2022-10-23T11:55:21.748798Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[False, False, False, False, False, False, False, True, False, False,\n", " False, False, False, False, False, False, False, False, True, False,\n", " True, False, False, False, True, False, False, False, False, False,\n", " False, False, False, True, False, False, False, False, False, True,\n", " True, False, False, False, False, False, False, True, False, False,\n", " True, False, True, False, False, False, True, False, False, False,\n", " False, False]])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mask_arr" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:55:30.201772Z", 
"start_time": "2022-10-23T11:55:30.197165Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor(11)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mask_arr[0].sum()" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:58:57.715475Z", "start_time": "2022-10-23T11:58:57.712828Z" } }, "outputs": [], "source": [ "selection = torch.flatten(mask_arr[0].nonzero()).tolist()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:55:40.122327Z", "start_time": "2022-10-23T11:55:40.118662Z" } }, "outputs": [ { "data": { "text/plain": [ "[7, 18, 20, 24, 33, 39, 40, 47, 50, 52, 56]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "selection" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:55:53.650978Z", "start_time": "2022-10-23T11:55:53.646182Z" } }, "outputs": [ { "data": { "text/plain": [ "{'unk_token': '[UNK]',\n", " 'sep_token': '[SEP]',\n", " 'pad_token': '[PAD]',\n", " 'cls_token': '[CLS]',\n", " 'mask_token': '[MASK]'}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.special_tokens_map" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:56:08.158223Z", "start_time": "2022-10-23T11:56:08.154330Z" } }, "outputs": [ { "data": { "text/plain": [ "103" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.mask_token_id" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:59:01.275756Z", "start_time": "2022-10-23T11:59:01.272924Z" } }, "outputs": [], "source": [ "inputs['input_ids'][0, selection] = tokenizer.mask_token_id" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "ExecuteTime": { "end_time": 
"2022-10-23T11:59:02.884794Z", "start_time": "2022-10-23T11:59:02.878431Z" } }, "outputs": [ { "data": { "text/plain": [ "{'input_ids': tensor([[ 101, 2044, 8181, 5367, 2180, 1996, 2281, 103, 4883, 2602,\n", " 2006, 2019, 3424, 1011, 8864, 4132, 1010, 2019, 3988, 2698,\n", " 6658, 2163, 4161, 2037, 22965, 2013, 1996, 2406, 2000, 2433,\n", " 1996, 18179, 1012, 2162, 3631, 103, 1999, 103, 6863, 2043,\n", " 22965, 2923, 2749, 4457, 3481, 7680, 3334, 1999, 2148, 3792,\n", " 1010, 103, 2058, 1037, 3204, 2044, 5367, 103, 1055, 17331,\n", " 1012, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[ 101, 2044, 8181, 5367, 2180, 1996, 2281, 7313, 4883, 2602,\n", " 2006, 2019, 3424, 1011, 8864, 4132, 1010, 2019, 3988, 2698,\n", " 6658, 2163, 4161, 2037, 22965, 2013, 1996, 2406, 2000, 2433,\n", " 1996, 18179, 1012, 2162, 3631, 2041, 1999, 2258, 6863, 2043,\n", " 22965, 2923, 2749, 4457, 3481, 7680, 3334, 1999, 2148, 3792,\n", " 1010, 2074, 2058, 1037, 3204, 2044, 5367, 1005, 1055, 17331,\n", " 1012, 102]])}" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs" ] }, 
{ "cell_type": "code", "execution_count": 31, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:59:07.203048Z", "start_time": "2022-10-23T11:59:07.197865Z" } }, "outputs": [ { "data": { "text/plain": [ "'[CLS] after abraham lincoln won the november [MASK] presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke [MASK] in [MASK] 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , [MASK] over a month after lincoln [MASK] s inauguration . [SEP]'" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# masked input, token by token\n", "' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T11:59:09.922859Z", "start_time": "2022-10-23T11:59:09.918001Z" } }, "outputs": [ { "data": { "text/plain": [ "\"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]\"" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# original (target) sentence, token by token\n", "' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Forward pass and loss\n", "\n", "`inputs` now carries `labels`, so `BertForMaskedLM` also returns a cross-entropy `loss`. Because `labels` is an unmasked copy of every token id, that loss averages over all positions, not only the `[MASK]`ed ones." ] }, 
{ "cell_type": "code", "execution_count": 34, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:00:19.169755Z", "start_time": "2022-10-23T12:00:18.879776Z" } }, "outputs": [], "source": [ "# eval() switches dropout off; no_grad() skips autograd bookkeeping since only the outputs are needed.\n", "mlm.eval()\n", "with torch.no_grad():\n", "    output = mlm(**inputs)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:00:26.763079Z", "start_time": "2022-10-23T12:00:26.759090Z" } }, "outputs": [ { "data": { "text/plain": [ "odict_keys(['loss', 'logits', 'hidden_states'])" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output.keys()" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:00:33.102388Z", "start_time": "2022-10-23T12:00:33.096135Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[[ -7.4276, -7.3447, -7.4013, ..., -6.5880, -6.5265, -4.6203],\n", " [-11.9590, -11.8243, -12.0316, ..., -11.1553, -10.6757, -8.8617],\n", " [ -5.8531, -6.0324, -5.5261, ..., -5.8973, -5.6533, -4.9494],\n", " ...,\n", " [ -4.3205, -4.3884, -4.2894, ..., -3.0957, -2.8461, -8.2620],\n", " [-14.4766, -14.4744, -14.4897, ..., -11.6094, -11.7776, -9.8620],\n", " [-11.2059, -11.7678, -11.3313, ..., -10.9919, -8.8702, -9.4242]]])" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output.logits" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:00:36.403750Z", "start_time": "2022-10-23T12:00:36.399118Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor(0.5636)" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output.loss" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:01:00.247788Z", "start_time": "2022-10-23T12:01:00.243721Z" } }, "outputs": [ { "data": { "text/plain": [ "13" ] }, 
"execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(output.hidden_states)  # embedding output + one entry per encoder layer" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:01:58.138428Z", "start_time": "2022-10-23T12:01:58.132733Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[[-0.3974, 0.0558, -0.3905, ..., -0.2893, -0.1375, 0.3952],\n", " [-0.6964, -0.0369, 0.2051, ..., -0.4537, 0.1505, 0.5892],\n", " [-0.6722, 1.0108, -0.7013, ..., -0.6308, -0.2771, 0.3940],\n", " ...,\n", " [-0.5778, 0.5753, -0.5293, ..., -0.7302, -0.5109, 1.3849],\n", " [ 0.5438, 0.0137, -0.3779, ..., 0.1812, -0.6194, -0.1336],\n", " [-0.4519, -0.3448, -1.0264, ..., -0.1259, -0.4856, 0.3235]]])" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output.hidden_states[-1]  # output of the last encoder layer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Reproducing the MLM head by hand\n", "\n", "`mlm.cls` maps hidden states to vocabulary logits; applied to `output.hidden_states[-1]` it should give back exactly `output.logits`." ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:02:16.279415Z", "start_time": "2022-10-23T12:02:16.213720Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[[ -7.4276, -7.3447, -7.4013, ..., -6.5880, -6.5265, -4.6203],\n", " [-11.9590, -11.8243, -12.0316, ..., -11.1553, -10.6757, -8.8617],\n", " [ -5.8531, -6.0324, -5.5261, ..., -5.8973, -5.6533, -4.9494],\n", " ...,\n", " [ -4.3205, -4.3884, -4.2894, ..., -3.0957, -2.8461, -8.2620],\n", " [-14.4766, -14.4744, -14.4897, ..., -11.6094, -11.7776, -9.8620],\n", " [-11.2059, -11.7678, -11.3313, ..., -10.9919, -8.8702, -9.4242]]],\n", " grad_fn=)" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mlm.cls(output.hidden_states[-1])" ] }, 
{ "cell_type": "code", "execution_count": 42, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:02:24.413213Z", "start_time": "2022-10-23T12:02:24.408190Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor([[[ -7.4276, -7.3447, -7.4013, ..., -6.5880, -6.5265, -4.6203],\n", " [-11.9590, -11.8243, -12.0316, ..., -11.1553, -10.6757, -8.8617],\n", " [ -5.8531, -6.0324, -5.5261, ..., -5.8973, -5.6533, -4.9494],\n", " ...,\n", " [ -4.3205, -4.3884, -4.2894, ..., -3.0957, -2.8461, -8.2620],\n", " [-14.4766, -14.4744, -14.4897, ..., -11.6094, -11.7776, -9.8620],\n", " [-11.2059, -11.7678, -11.3313, ..., -10.9919, -8.8702, -9.4242]]])" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output.logits  # reference: logits from the full forward pass" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:02:46.240782Z", "start_time": "2022-10-23T12:02:46.238325Z" } }, "outputs": [], "source": [ "last_hidden_state = output.hidden_states[-1]" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:04:31.581726Z", "start_time": "2022-10-23T12:04:31.577795Z" } }, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 62, 768])" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "last_hidden_state.shape" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:06:02.323620Z", "start_time": "2022-10-23T12:06:02.246346Z" }, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 62, 768])\n", "torch.Size([1, 62, 30522])\n" ] }, { "data": { "text/plain": [ "tensor([[[ -7.4276, -7.3447, -7.4013, ..., -6.5880, -6.5265, -4.6203],\n", " [-11.9590, -11.8243, -12.0316, ..., -11.1553, -10.6757, -8.8617],\n", " [ -5.8531, -6.0324, -5.5261, ..., -5.8973, -5.6533, -4.9494],\n", " ...,\n", " [ -4.3205, -4.3884, -4.2894, ..., -3.0957, -2.8461, -8.2620],\n", " [-14.4766, -14.4744, -14.4897, ..., -11.6094, -11.7776, -9.8620],\n", " [-11.2059, -11.7678, -11.3313, ..., -10.9919, -8.8702, -9.4242]]])" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "# mlm.cls.predictions = transform (hidden -> hidden) followed by decoder (hidden -> vocabulary logits).\n", "mlm.eval()\n", "with torch.no_grad():\n", "    transformed = mlm.cls.predictions.transform(last_hidden_state)\n", "    print(transformed.shape)  # hidden size unchanged\n", "    logits = mlm.cls.predictions.decoder(transformed)\n", "    print(logits.shape)  # last dim is now the vocabulary size\n", "logits" ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:08:21.847190Z", "start_time": "2022-10-23T12:08:21.842330Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor(0.5636)" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output.loss  # the value to reproduce in the next section" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Recomputing the loss and decoding the predictions\n", "\n", "Cross-entropy over the hand-computed `logits` should match `output.loss`; taking the `argmax` over the vocabulary turns the logits back into tokens." ] }, { "cell_type": "code", "execution_count": 48, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:08:26.111715Z", "start_time": "2022-10-23T12:08:26.108950Z" } }, "outputs": [], "source": [ "# default reduction='mean': averages over every target position\n", "ce = nn.CrossEntropyLoss()" ] }, { "cell_type": "code", "execution_count": 49, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:08:30.584319Z", "start_time": "2022-10-23T12:08:30.579982Z" } }, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 62, 30522])" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "logits.shape" ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:08:32.011230Z", "start_time": "2022-10-23T12:08:32.007146Z" } }, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 62])" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs['labels'].shape" ] }, { "cell_type": "code", "execution_count": 51, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:08:38.754098Z", "start_time": "2022-10-23T12:08:38.748962Z" } }, "outputs": [ { "data": { "text/plain": [ "torch.Size([62])" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs['labels'][0].view(-1).shape" ] }, { "cell_type": 
"code", "execution_count": 52, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:08:43.837921Z", "start_time": "2022-10-23T12:08:43.826163Z" } }, "outputs": [ { "data": { "text/plain": [ "tensor(0.5636)" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Flatten (batch, seq_len, vocab) -> (batch * seq_len, vocab) so this also works for batch size > 1.\n", "ce(logits.view(-1, logits.size(-1)), inputs['labels'].view(-1))" ] }, 
{ "cell_type": "code", "execution_count": 53, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:09:36.590470Z", "start_time": "2022-10-23T12:09:36.584581Z" } }, "outputs": [ { "data": { "text/plain": [ "\". after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in december 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . s\"" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# most likely token at every position (not only the masked ones)\n", "' '.join(tokenizer.convert_ids_to_tokens(torch.argmax(logits[0], dim=1)))" ] }, 
{ "cell_type": "code", "execution_count": 54, "metadata": { "ExecuteTime": { "end_time": "2022-10-23T12:09:47.275769Z", "start_time": "2022-10-23T12:09:47.271126Z" } }, "outputs": [ { "data": { "text/plain": [ "\"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]\"" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# ground-truth tokens for comparison\n", "' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0]))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }