summaryrefslogtreecommitdiff
path: root/fine_tune/bert/tutorials/09_mlm.py
diff options
context:
space:
mode:
authorzhang <zch921005@126.com>2022-10-23 20:21:40 +0800
committerzhang <zch921005@126.com>2022-10-23 20:21:40 +0800
commitb5c7383a7abfa87396bf585b789c9e0474e12652 (patch)
treec86ad7f1d02e4eb3a349d0168695c271f3fbee11 /fine_tune/bert/tutorials/09_mlm.py
parent3c2d5d232372b8f917b67be94551c8faf0754cb7 (diff)
masked language model
Diffstat (limited to 'fine_tune/bert/tutorials/09_mlm.py')
-rw-r--r--fine_tune/bert/tutorials/09_mlm.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/fine_tune/bert/tutorials/09_mlm.py b/fine_tune/bert/tutorials/09_mlm.py
new file mode 100644
index 0000000..0177c3f
--- /dev/null
+++ b/fine_tune/bert/tutorials/09_mlm.py
@@ -0,0 +1,33 @@
+
+import torch
+from transformers.models.bert import BertModel, BertTokenizer, BertForMaskedLM
+
+
+model_type = 'bert-base-uncased'
+
+tokenizer = BertTokenizer.from_pretrained(model_type)
+bert = BertModel.from_pretrained(model_type)
+mlm = BertForMaskedLM.from_pretrained(model_type, output_hidden_states=True)
+
+
+text = ("After Abraham Lincoln won the November 1860 presidential "
+ "election on an anti-slavery platform, an initial seven "
+ "slave states declared their secession from the country "
+ "to form the Confederacy. War broke out in April 1861 "
+ "when secessionist forces attacked Fort Sumter in South "
+ "Carolina, just over a month after Lincoln's "
+ "inauguration.")
+
+inputs = tokenizer(text, return_tensors='pt')
+inputs['labels'] = inputs['input_ids'].detach().clone()
+
+mask_arr = (torch.rand(inputs['input_ids'].shape) < 0.15) \
+ * (inputs['input_ids'] != 101) \
+ * (inputs['input_ids'] != 102)
+selection = torch.flatten(mask_arr[0].nonzero()).tolist()
+inputs['input_ids'][0, selection] = 103
+
+mlm.eval()
+with torch.no_grad():
+ output = mlm(**inputs)
+print() \ No newline at end of file