In [1]:
import torch
from torch import nn
from transformers.models.bert import BertModel, BertTokenizer, BertForMaskedLM

## 1. Model loading and data preprocessing

In [2]:
# Checkpoint name shared by the tokenizer and both models loaded below.
model_type = "bert-base-uncased"

In [3]:
# Tokenizer that matches the checkpoint's vocabulary.
tokenizer = BertTokenizer.from_pretrained(model_type)
# Bare encoder without a language-modelling head; in the visible cells it is
# only printed for inspection.
bert = BertModel.from_pretrained(model_type)
# Encoder plus masked-LM head. output_hidden_states=True makes the forward pass
# also return the hidden states that later cells read.
mlm = BertForMaskedLM.from_pretrained(model_type, output_hidden_states=True)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relatio

In [4]:
# Print the module tree of the bare encoder (embeddings, then the encoder layer stack).
bert

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [5]:
# Print the module tree of the masked-LM model; a BertModel is nested under its `bert` attribute.
mlm

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [6]:
# Two-sentence sample passage used as the masked-language-modelling input.
# Adjacent string literals are concatenated by Python; each piece carries its
# own trailing space so the joined text has single spaces throughout.
text = (
    "After Abraham Lincoln won the November 1860 presidential election "
    "on an anti-slavery platform, an initial seven slave states declared "
    "their secession from the country to form the Confederacy. "
    "War broke out in April 1861 when secessionist forces attacked "
    "Fort Sumter in South Carolina, just over a month after Lincoln's "
    "inauguration."
)

In [7]:
# Echo the string to confirm the adjacent literals were joined with correct spacing.
text

"After Abraham Lincoln won the November 1860 presidential election on an anti-slavery platform, an initial seven slave states declared their secession from the country to form the Confederacy. War broke out in April 1861 when secessionist forces attacked Fort Sumter in South Carolina, just over a month after Lincoln's inauguration."

In [25]:
# Encode the passage; return_tensors='pt' yields PyTorch tensors with a leading batch dimension.
inputs = tokenizer(text, return_tensors='pt')

In [8]:
# Inspect the encoding: input_ids, token_type_ids and attention_mask.
inputs

{'input_ids': tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,  2602,
          2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  2433,
          1996, 18179,  1012,  2162,  3631,  2041,  1999,  2258,  6863,  2043,
         22965,  2923,  2749,  4457,  3481,  7680,  3334,  1999,  2148,  3792,
          1010,  2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055, 17331,
          1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [20]:
# Shape is (batch size, sequence length); the length includes the added [CLS] and [SEP] tokens.
inputs['input_ids'].shape

torch.Size([1, 62])

In [12]:
# Map the ids back to tokens to see how the passage was split
# (lower-casing, '##' word pieces, added [CLS]/[SEP]).
' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]"

## 2. masking

In [26]:
# Keep an untouched copy of the token ids as prediction targets before any
# position is overwritten with [MASK].
# NOTE(review): every position keeps its real id, so the loss computed later is
# averaged over all tokens, not just the masked ones. Setting the unmasked
# labels to -100 would restrict it to masked positions; confirm which is intended.
inputs['labels'] = inputs['input_ids'].clone().detach()

In [14]:
# At this point the labels are an exact copy of the unmasked input ids.
inputs['labels']

tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,  2602,
          2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  2433,
          1996, 18179,  1012,  2162,  3631,  2041,  1999,  2258,  6863,  2043,
         22965,  2923,  2749,  4457,  3481,  7680,  3334,  1999,  2148,  3792,
          1010,  2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055, 17331,
          1012,   102]])

In [9]:
# Naive first attempt: flag each position independently with probability 0.15.
# Special tokens are not excluded here, so [CLS]/[SEP] can be selected as well.
uniform_draws = torch.rand(inputs['input_ids'].shape)
mask = uniform_draws < 0.15

In [17]:
# Inspect which positions the naive draw selected.
mask

tensor([[False, False,  True, False, False, False, False, False, False,  True,
          True,  True, False, False, False, False, False,  True, False, False,
         False, False, False, False, False, False, False, False,  True, False,
         False, False,  True, False, False, False, False, False, False,  True,
         False, False, False, False,  True, False, False, False, False, False,
         False, False,  True, False, False, False, False, False, False, False,
          True, False]])

In [18]:
# Count the selected positions with the tensor's own reduction rather than
# Python's built-in sum(), which iterates over the elements one at a time.
mask[0].sum()

tensor(11)

In [21]:
# Realised masking rate for this draw, computed from the mask itself so the
# figure stays correct when the random selection changes on a re-run.
mask[0].sum().item() / mask.shape[1]

0.1774193548387097

In [27]:
# Select roughly 15% of positions for masking without touching special tokens.
# - Special-token ids come from the tokenizer instead of being hard-coded
#   (101/102 are only valid for this particular vocabulary); padding is
#   excluded too so the cell stays correct for padded batches.
# - Boolean tensors are combined with `&` rather than arithmetic `*`.
# - A dedicated, seeded generator makes the selection reproducible under
#   Restart & Run All without touching the global RNG state.
MASK_PROB = 0.15
MASK_SEED = 0
mask_rng = torch.Generator().manual_seed(MASK_SEED)
mask_arr = (
    (torch.rand(inputs['input_ids'].shape, generator=mask_rng) < MASK_PROB)
    & (inputs['input_ids'] != tokenizer.cls_token_id)
    & (inputs['input_ids'] != tokenizer.sep_token_id)
    & (inputs['input_ids'] != tokenizer.pad_token_id)
)

In [11]:
# Inspect the selection; positions holding [CLS]/[SEP] are excluded by construction.
mask_arr

tensor([[False, False, False, False, False, False, False,  True, False, False,
         False, False, False, False, False, False, False, False,  True, False,
          True, False, False, False,  True, False, False, False, False, False,
         False, False, False,  True, False, False, False, False, False,  True,
          True, False, False, False, False, False, False,  True, False, False,
          True, False,  True, False, False, False,  True, False, False, False,
         False, False]])

In [12]:
# Number of positions chosen for masking; Tensor.sum() avoids the
# element-by-element Python loop that the built-in sum() performs.
mask_arr[0].sum()

tensor(11)

In [28]:
# Turn the boolean row into a plain Python list of the selected positions.
selected_positions = mask_arr[0].nonzero()  # shape (num_selected, 1)
selection = selected_positions.flatten().tolist()

In [14]:
# Sequence positions that will be replaced by [MASK].
selection

[7, 18, 20, 24, 33, 39, 40, 47, 50, 52, 56]

In [15]:
# List the tokenizer's special tokens to find the mask token.
tokenizer.special_tokens_map

{'unk_token': '[UNK]',
 'sep_token': '[SEP]',
 'pad_token': '[PAD]',
 'cls_token': '[CLS]',
 'mask_token': '[MASK]'}

In [16]:
# Look up the vocabulary id of the mask token.
tokenizer.vocab['[MASK]']

103

In [29]:
# Overwrite the selected positions with the tokenizer's own [MASK] id instead
# of the hard-coded 103, which is only correct for this vocabulary.
# The assignment edits inputs['input_ids'] in place; inputs['labels'] still
# holds the original ids.
inputs['input_ids'][0, selection] = tokenizer.mask_token_id

In [30]:
# Inspect the batch after masking: selected input_ids now hold the mask id and a 'labels' entry is present.
inputs

{'input_ids': tensor([[  101,  2044,  8181,  5367,  2180,  1996,  2281,   103,  4883,  2602,
          2006,  2019,  3424,  1011,  8864,  4132,  1010,  2019,  3988,  2698,
          6658,  2163,  4161,  2037, 22965,  2013,  1996,  2406,  2000,  2433,
          1996, 18179,  1012,  2162,  3631,   103,  1999,   103,  6863,  2043,
         22965,  2923,  2749,  4457,  3481,  7680,  3334,  1999,  2148,  3792,
          1010,   103,  2058,  1037,  3204,  2044,  5367,   103,  1055, 17331,
          1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([

In [31]:
# Decode the masked input to see where [MASK] was inserted.
' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

'[CLS] after abraham lincoln won the november [MASK] presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke [MASK] in [MASK] 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , [MASK] over a month after lincoln [MASK] s inauguration . [SEP]'

In [32]:
# Decode the labels: they still contain the original, unmasked tokens.
' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]"

## 3. Forward pass and loss calculation

In [34]:
# Put the model in evaluation mode (the printed architecture contains Dropout
# layers) and skip autograd bookkeeping, since the outputs are only inspected.
mlm.eval()
with torch.no_grad():
    # `inputs` includes 'labels', so the returned output also carries a loss.
    output = mlm(**inputs)

In [35]:
# Fields available on the model output for this call.
output.keys()

odict_keys(['loss', 'logits', 'hidden_states'])

In [36]:
# Raw prediction scores: one score per vocabulary entry for every position.
output.logits

tensor([[[ -7.4276,  -7.3447,  -7.4013,  ...,  -6.5880,  -6.5265,  -4.6203],
         [-11.9590, -11.8243, -12.0316,  ..., -11.1553, -10.6757,  -8.8617],
         [ -5.8531,  -6.0324,  -5.5261,  ...,  -5.8973,  -5.6533,  -4.9494],
         ...,
         [ -4.3205,  -4.3884,  -4.2894,  ...,  -3.0957,  -2.8461,  -8.2620],
         [-14.4766, -14.4744, -14.4897,  ..., -11.6094, -11.7776,  -9.8620],
         [-11.2059, -11.7678, -11.3313,  ..., -10.9919,  -8.8702,  -9.4242]]])

In [37]:
# Loss returned by the model for the supplied labels.
output.loss

tensor(0.5636)

In [39]:
# Number of hidden-state tensors returned.
# NOTE(review): presumably the embedding output plus one tensor per encoder
# layer; confirm against the transformers documentation.
len(output['hidden_states'])

13

In [40]:
# Final entry of hidden_states; the next section feeds this tensor to the MLM head.
output['hidden_states'][-1]

tensor([[[-0.3974,  0.0558, -0.3905,  ..., -0.2893, -0.1375,  0.3952],
         [-0.6964, -0.0369,  0.2051,  ..., -0.4537,  0.1505,  0.5892],
         [-0.6722,  1.0108, -0.7013,  ..., -0.6308, -0.2771,  0.3940],
         ...,
         [-0.5778,  0.5753, -0.5293,  ..., -0.7302, -0.5109,  1.3849],
         [ 0.5438,  0.0137, -0.3779,  ...,  0.1812, -0.6194, -0.1336],
         [-0.4519, -0.3448, -1.0264,  ..., -0.1259, -0.4856,  0.3235]]])

## 4. Recomputing the logits from scratch

In [41]:
# Feed the final hidden states straight through the MLM head; the displayed
# values match output.logits. This call is not wrapped in torch.no_grad(),
# which is why the result carries a grad_fn.
mlm.cls(output['hidden_states'][-1])

tensor([[[ -7.4276,  -7.3447,  -7.4013,  ...,  -6.5880,  -6.5265,  -4.6203],
         [-11.9590, -11.8243, -12.0316,  ..., -11.1553, -10.6757,  -8.8617],
         [ -5.8531,  -6.0324,  -5.5261,  ...,  -5.8973,  -5.6533,  -4.9494],
         ...,
         [ -4.3205,  -4.3884,  -4.2894,  ...,  -3.0957,  -2.8461,  -8.2620],
         [-14.4766, -14.4744, -14.4897,  ..., -11.6094, -11.7776,  -9.8620],
         [-11.2059, -11.7678, -11.3313,  ..., -10.9919,  -8.8702,  -9.4242]]],
       grad_fn=<AddBackward0>)

In [42]:
# Show the model's own logits again for comparison with the head output above.
output.logits

tensor([[[ -7.4276,  -7.3447,  -7.4013,  ...,  -6.5880,  -6.5265,  -4.6203],
         [-11.9590, -11.8243, -12.0316,  ..., -11.1553, -10.6757,  -8.8617],
         [ -5.8531,  -6.0324,  -5.5261,  ...,  -5.8973,  -5.6533,  -4.9494],
         ...,
         [ -4.3205,  -4.3884,  -4.2894,  ...,  -3.0957,  -2.8461,  -8.2620],
         [-14.4766, -14.4744, -14.4897,  ..., -11.6094, -11.7776,  -9.8620],
         [-11.2059, -11.7678, -11.3313,  ..., -10.9919,  -8.8702,  -9.4242]]])

In [43]:
# Keep a handle on the final hidden states; they are the input to the MLM head.
last_hidden_state = output.hidden_states[-1]

In [44]:
# (batch size, sequence length, hidden size)
last_hidden_state.shape

torch.Size([1, 62, 768])

In [46]:
# Reproduce the MLM head step by step from the last hidden state.
mlm.eval()
with torch.no_grad():
    # Transform sub-module: the printed shape shows the hidden size is kept.
    transformed = mlm.cls.predictions.transform(last_hidden_state)
    print(transformed.shape)
    # Decoder: projects every position onto the vocabulary to give the logits.
    logits = mlm.cls.predictions.decoder(transformed)
    print(logits.shape)
# Bare last expression so the notebook displays the tensor.
logits

torch.Size([1, 62, 768])
torch.Size([1, 62, 30522])


tensor([[[ -7.4276,  -7.3447,  -7.4013,  ...,  -6.5880,  -6.5265,  -4.6203],
         [-11.9590, -11.8243, -12.0316,  ..., -11.1553, -10.6757,  -8.8617],
         [ -5.8531,  -6.0324,  -5.5261,  ...,  -5.8973,  -5.6533,  -4.9494],
         ...,
         [ -4.3205,  -4.3884,  -4.2894,  ...,  -3.0957,  -2.8461,  -8.2620],
         [-14.4766, -14.4744, -14.4897,  ..., -11.6094, -11.7776,  -9.8620],
         [-11.2059, -11.7678, -11.3313,  ..., -10.9919,  -8.8702,  -9.4242]]])

In [47]:
# Reference loss from the model, reproduced manually in the next section.
output.loss

tensor(0.5636)

## 5. Loss and decoding the predictions

In [48]:
# Cross-entropy criterion averaged over positions (the default reduction,
# written out explicitly); used to reproduce the model's loss by hand.
ce = nn.CrossEntropyLoss(reduction='mean')

In [49]:
# (batch size, sequence length, vocabulary size)
logits.shape

torch.Size([1, 62, 30522])

In [50]:
# (batch size, sequence length)
inputs['labels'].shape

torch.Size([1, 62])

In [51]:
# Targets for the single sequence as a 1-D tensor; indexing with [0] already
# yields 1-D here, so view(-1) leaves the shape unchanged.
inputs['labels'][0].view(-1).shape

torch.Size([62])

In [52]:
# Reproduce the model's loss by hand. Logits are flattened to
# (positions, vocabulary) and labels to (positions,) across the whole batch,
# instead of indexing only the first sequence, so the cell also works when the
# batch holds more than one sequence. For a single sequence the result is unchanged.
ce(logits.reshape(-1, logits.size(-1)), inputs['labels'].reshape(-1))

tensor(0.5636)

In [89]:
# Highest-scoring vocabulary id at each position of the first sequence.
logits[0].argmax(dim=-1)

tensor([ 1012,  2044,  8181,  5367,  2180,  1996,  2281,  7313,  4883,  2602,
         2006,  2019,  3424,  1011,  8864,  4132,  1010,  2035,  1996,  2698,
         6658,  2163,  4161,  2037, 22965,  2013,  1996,  2586,  2000,  3693,
         1996, 18179,  1012,  2162,  3631,  2034,  1999,  2258,  6863,  2043,
        22965,  2923,  2749,  4457,  3481,  7680,  3334,  1999,  2148,  3792,
         1010,  2074,  2058,  1037,  3204,  2044,  5367,  1005,  1055, 17331,
         1012,  3519])

In [92]:
# Decode the masked model input for comparison with the predictions.
# NOTE(review): the saved output of this cell shows different [MASK] positions
# from the earlier display of the same tensor, so it appears to come from
# another run; re-run the notebook top to bottom to refresh it.
' '.join(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]))

"[CLS] after abraham lincoln won the november 1860 [MASK] election [MASK] an anti - slavery platform , [MASK] [MASK] seven slave [MASK] declared their secession from the [MASK] to [MASK] the confederacy . war broke out [MASK] april 1861 when [MASK] ##ist forces attacked fort sum [MASK] in south [MASK] , just over a month after lincoln ' s [MASK] . [SEP]"

In [53]:
# Decode the model's top prediction at every position into a readable string.
predicted_ids = torch.argmax(logits[0], dim=1)
predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_ids)
' '.join(predicted_tokens)

". after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in december 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . s"

In [54]:
# Ground-truth tokens, for comparison with the decoded predictions.
' '.join(tokenizer.convert_ids_to_tokens(inputs['labels'][0]))

"[CLS] after abraham lincoln won the november 1860 presidential election on an anti - slavery platform , an initial seven slave states declared their secession from the country to form the confederacy . war broke out in april 1861 when secession ##ist forces attacked fort sum ##ter in south carolina , just over a month after lincoln ' s inauguration . [SEP]"