summaryrefslogtreecommitdiff
path: root/fine_tune/bert/tutorials/09_mlm.py
blob: 0177c3f76d9f92cdd21116faa91f49f0a1a8fb79 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

import torch
from transformers.models.bert import BertModel, BertTokenizer, BertForMaskedLM


model_type = 'bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(model_type)
bert = BertModel.from_pretrained(model_type)
mlm = BertForMaskedLM.from_pretrained(model_type, output_hidden_states=True)


text = ("After Abraham Lincoln won the November 1860 presidential "
        "election on an anti-slavery platform, an initial seven "
        "slave states declared their secession from the country "
        "to form the Confederacy. War broke out in April 1861 "
        "when secessionist forces attacked Fort Sumter in South "
        "Carolina, just over a month after Lincoln's "
        "inauguration.")

inputs = tokenizer(text, return_tensors='pt')
inputs['labels'] = inputs['input_ids'].detach().clone()

mask_arr = (torch.rand(inputs['input_ids'].shape) < 0.15) \
        * (inputs['input_ids'] != 101) \
        * (inputs['input_ids'] != 102)
selection = torch.flatten(mask_arr[0].nonzero()).tolist()
inputs['input_ids'][0, selection] = 103

mlm.eval()
with torch.no_grad():
    output = mlm(**inputs)
print()