summaryrefslogtreecommitdiff
path: root/learn_torch/bert/fill_mask.py
blob: 24e177f40c6852b549e42000dc342dd17bfb16ba (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

import torch
from datasets import load_dataset
from transformers import BertTokenizer


# Define the dataset
class Dataset(torch.utils.data.Dataset):
    """Text-only view of a hub sentiment dataset split, filtered by length.

    Each item is the raw ``'text'`` field of a row. Rows whose ``'text'``
    has ``len() <= min_length`` are dropped once, when the object is built.
    """

    def __init__(self, split: str, *, min_length: int = 30,
                 path: str = 'seamew/ChnSentiCorp'):
        """Load ``split`` of the dataset at ``path`` and keep long rows.

        Args:
            split: Split name forwarded to ``load_dataset`` (e.g. ``'train'``).
            min_length: Keep only rows whose ``'text'`` is strictly longer
                than this (measured with ``len``). Defaults to 30, the value
                that used to be hard-coded.
            path: Dataset identifier forwarded to ``load_dataset``. Defaults
                to the previously hard-coded ``'seamew/ChnSentiCorp'``.
        """
        dataset = load_dataset(path=path, split=split)

        def is_long_enough(row):
            # Predicate for ``Dataset.filter``: True keeps the row.
            return len(row['text']) > min_length

        self.dataset = dataset.filter(is_long_enough)

    def __len__(self) -> int:
        """Return the number of rows that survived the length filter."""
        return len(self.dataset)

    def __getitem__(self, i: int):
        """Return the ``'text'`` field of row ``i`` of the filtered dataset."""
        return self.dataset[i]['text']

if __name__ == '__main__':
    # Smoke check: build the filtered training split, then report its size
    # together with the first retained text.
    train_set = Dataset('train')
    n_rows = len(train_set)
    first_text = train_set[0]
    print(n_rows, first_text)
    # Load the pretrained Chinese BERT tokenizer (not used further in this block).
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')