author    zhang <zch921005@126.com>  2022-06-19 09:20:25 +0800
committer zhang <zch921005@126.com>  2022-06-19 09:20:25 +0800
commit    ed026519d959ecc60a895f379c228de5df77ffb0 (patch)
tree      c2fc265144824c4824814c003c6e0403ebaa2d8c /learn_torch
parent    2a42272daeab92ab26481745776dc51ed144924f (diff)
daily update
Diffstat (limited to 'learn_torch')
-rw-r--r--  learn_torch/bert/fill_mask.py      28
-rw-r--r--  learn_torch/seq2seq/base_model.py  51
2 files changed, 79 insertions, 0 deletions
diff --git a/learn_torch/bert/fill_mask.py b/learn_torch/bert/fill_mask.py
new file mode 100644
index 0000000..24e177f
--- /dev/null
+++ b/learn_torch/bert/fill_mask.py
@@ -0,0 +1,28 @@
+
+import torch
+from datasets import load_dataset
+from transformers import BertTokenizer
+
+
+# Define the dataset
+class Dataset(torch.utils.data.Dataset):
+ def __init__(self, split):
+ dataset = load_dataset(path='seamew/ChnSentiCorp', split=split)
+
+ def f(data):
+ return len(data['text']) > 30
+
+ self.dataset = dataset.filter(f)
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, i):
+ text = self.dataset[i]['text']
+
+ return text
+
+if __name__ == '__main__':
+ dataset = Dataset('train')
+ print(len(dataset), dataset[0])
+ tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
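The script stops right after instantiating the tokenizer. A minimal sketch of how the masking step of the fill-mask task named by the file might continue, assuming a fixed sequence length of 30 and masking the 16th token (both are illustrative choices, not part of the commit):

def collate_fn(batch):
    # Tokenize a batch of sentences into fixed-length tensors (length 30 is an assumption).
    data = tokenizer(batch,
                     padding='max_length',
                     truncation=True,
                     max_length=30,
                     return_tensors='pt')
    input_ids = data['input_ids']
    # Take the 16th token as the label and replace it with [MASK] (illustrative position).
    labels = input_ids[:, 15].clone()
    input_ids[:, 15] = tokenizer.mask_token_id
    return input_ids, data['attention_mask'], data['token_type_ids'], labels

loader = torch.utils.data.DataLoader(dataset, batch_size=16,
                                     collate_fn=collate_fn, shuffle=True)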
diff --git a/learn_torch/seq2seq/base_model.py b/learn_torch/seq2seq/base_model.py
new file mode 100644
index 0000000..bc292cf
--- /dev/null
+++ b/learn_torch/seq2seq/base_model.py
@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from torchtext.legacy.datasets import Multi30k
+from torchtext.legacy.data import Field, BucketIterator
+
+import spacy
+import numpy as np
+
+import random
+import math
+import time
+
+
+SEED = 1234
+
+random.seed(SEED)
+np.random.seed(SEED)
+torch.manual_seed(SEED)
+torch.cuda.manual_seed(SEED)
+torch.backends.cudnn.deterministic = True
+
+spacy_de = spacy.load('de_core_news_sm')
+spacy_en = spacy.load('en_core_web_sm')
+
+def tokenize_de(text):
+ """
+ Tokenizes German text from a string into a list of strings (tokens) and reverses it
+ """
+ return [tok.text for tok in spacy_de.tokenizer(text)][::-1]
+
+def tokenize_en(text):
+ """
+ Tokenizes English text from a string into a list of strings (tokens)
+ """
+ return [tok.text for tok in spacy_en.tokenizer(text)]
+
+SRC = Field(tokenize = tokenize_de,
+ init_token = '<sos>',
+ eos_token = '<eos>',
+ lower = True)
+
+TRG = Field(tokenize = tokenize_en,
+ init_token = '<sos>',
+ eos_token = '<eos>',
+ lower = True)
+
+train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'),
+ fields = (SRC, TRG))
+
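The file ends after loading the Multi30k splits; BucketIterator is imported but not yet used. A minimal sketch of the usual next steps under that assumption (min_freq and batch_size are illustrative values, not from the commit):

# Build vocabularies from the training split only.
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Group sentences of similar length into batches to minimise padding.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=128,
    device=device)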