summary | refs | log | tree | commit | diff
path: root/learn_torch/bert
diff options
context:
space:
mode:
author zhang <zch921005@126.com> 2022-06-19 09:20:25 +0800
committer zhang <zch921005@126.com> 2022-06-19 09:20:25 +0800
commit ed026519d959ecc60a895f379c228de5df77ffb0 (patch)
tree c2fc265144824c4824814c003c6e0403ebaa2d8c /learn_torch/bert
parent 2a42272daeab92ab26481745776dc51ed144924f (diff)
daily update
Diffstat (limited to 'learn_torch/bert')
-rw-r--r-- learn_torch/bert/fill_mask.py 28
1 files changed, 28 insertions, 0 deletions
diff --git a/learn_torch/bert/fill_mask.py b/learn_torch/bert/fill_mask.py
new file mode 100644
index 0000000..24e177f
--- /dev/null
+++ b/learn_torch/bert/fill_mask.py
@@ -0,0 +1,28 @@
+
+import torch
+from datasets import load_dataset
+from transformers import BertTokenizer
+
+
+# Define the dataset: records loaded from 'seamew/ChnSentiCorp', filtered by text length
+class Dataset(torch.utils.data.Dataset):
+ def __init__(self, split):  # split is passed straight through to load_dataset(split=...)
+ dataset = load_dataset(path='seamew/ChnSentiCorp', split=split)
+
+ def f(data):  # filter predicate: keep a record only when len(data['text']) exceeds 30
+ return len(data['text']) > 30
+
+ self.dataset = dataset.filter(f)  # only the filtered result is stored on the instance
+
+ def __len__(self):  # number of records remaining after the filter
+ return len(self.dataset)
+
+ def __getitem__(self, i):  # returns only the 'text' field of record i
+ text = self.dataset[i]['text']
+
+ return text
+
+if __name__ == '__main__':  # runs only when the file is executed as a script
+ dataset = Dataset('train')
+ print(len(dataset), dataset[0])  # filtered record count and the first record's text
+ tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')  # loaded here; not used further within this hunk