| author | zhang <zch921005@126.com> | 2022-06-19 09:20:25 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2022-06-19 09:20:25 +0800 |
| commit | ed026519d959ecc60a895f379c228de5df77ffb0 (patch) | |
| tree | c2fc265144824c4824814c003c6e0403ebaa2d8c /learn_torch/bert/fill_mask.py | |
| parent | 2a42272daeab92ab26481745776dc51ed144924f (diff) | |
daily update
Diffstat (limited to 'learn_torch/bert/fill_mask.py')
| -rw-r--r-- | learn_torch/bert/fill_mask.py | 28 |
1 file changed, 28 insertions, 0 deletions
```diff
diff --git a/learn_torch/bert/fill_mask.py b/learn_torch/bert/fill_mask.py
new file mode 100644
index 0000000..24e177f
--- /dev/null
+++ b/learn_torch/bert/fill_mask.py
@@ -0,0 +1,28 @@
+
+import torch
+from datasets import load_dataset
+from transformers import BertTokenizer
+
+
+# Define the dataset
+class Dataset(torch.utils.data.Dataset):
+    def __init__(self, split):
+        dataset = load_dataset(path='seamew/ChnSentiCorp', split=split)
+
+        def f(data):
+            return len(data['text']) > 30
+
+        self.dataset = dataset.filter(f)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, i):
+        text = self.dataset[i]['text']
+
+        return text
+
+if __name__ == '__main__':
+    dataset = Dataset('train')
+    print(len(dataset), dataset[0])
+    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
```
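The commit stops after loading the tokenizer; the diff itself never shows a mask being filled. The sketch below is not part of the commit: it assumes `bert-base-chinese` and the `BertForMaskedLM` head from `transformers`, and the example sentence is a made-up review in the style of ChnSentiCorp. It masks a single token with `tokenizer.mask_token_id` and asks the pretrained MLM head to recover it.

```python
# Hypothetical follow-up, not in the commit: fill one [MASK] in a sentence
# using the pretrained masked-language-model head of bert-base-chinese.
import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForMaskedLM.from_pretrained('bert-base-chinese')
model.eval()


def fill_one_mask(text, mask_pos=5):
    # Encode the sentence, truncating to a fixed length.
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=128)
    input_ids = inputs['input_ids'].clone()

    # Remember the original token at mask_pos, then replace it with [MASK].
    original_id = input_ids[0, mask_pos].item()
    input_ids[0, mask_pos] = tokenizer.mask_token_id

    # Predict the masked position and take the highest-scoring token.
    with torch.no_grad():
        logits = model(input_ids=input_ids,
                       attention_mask=inputs['attention_mask']).logits
    predicted_id = logits[0, mask_pos].argmax(dim=-1).item()

    return tokenizer.decode([original_id]), tokenizer.decode([predicted_id])


if __name__ == '__main__':
    original, predicted = fill_one_mask('这家酒店的位置很好，房间也很干净。')
    print('original:', original, 'predicted:', predicted)
```

For actual MLM training on the `Dataset` class above, one would typically wrap it in a `torch.utils.data.DataLoader` with a collate function that tokenizes each batch and masks tokens at random (for example via `transformers.DataCollatorForLanguageModeling`), rather than masking a fixed position as in this sketch.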
