diff options
| author | zhang <zch921005@126.com> | 2022-06-19 09:20:25 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2022-06-19 09:20:25 +0800 |
| commit | ed026519d959ecc60a895f379c228de5df77ffb0 (patch) | |
| tree | c2fc265144824c4824814c003c6e0403ebaa2d8c /learn_torch/seq2seq/base_model.py | |
| parent | 2a42272daeab92ab26481745776dc51ed144924f (diff) | |
daily update
Diffstat (limited to 'learn_torch/seq2seq/base_model.py')
| -rw-r--r-- | learn_torch/seq2seq/base_model.py | 51 |
1 file changed, 51 insertions, 0 deletions
# Data-loading / preprocessing setup for a German->English seq2seq model:
# seed every RNG for reproducibility, build the spaCy tokenizers, declare
# the torchtext fields, and load the Multi30k translation splits.
import torch
import torch.nn as nn
import torch.optim as optim

from torchtext.legacy.datasets import Multi30k
from torchtext.legacy.data import Field, BucketIterator

import spacy
import numpy as np

import random
import math
import time


SEED = 1234

# Seed each independent RNG source so training runs are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# Pretrained spaCy pipelines, used here only for their tokenizers.
spacy_de = spacy.load('de_core_news_sm')
spacy_en = spacy.load('en_core_web_sm')


def tokenize_de(text):
    """Tokenize a German string and return its tokens in reverse order.

    The source sequence is reversed before being fed to the encoder
    (a common seq2seq trick to shorten source-target dependency paths).
    """
    return list(reversed([tok.text for tok in spacy_de.tokenizer(text)]))


def tokenize_en(text):
    """Tokenize an English string into a list of token strings."""
    return [tok.text for tok in spacy_en.tokenizer(text)]


# Source-side (German) field: reversed tokens, <sos>/<eos> markers, lowercased.
SRC = Field(tokenize=tokenize_de,
            init_token='<sos>',
            eos_token='<eos>',
            lower=True)

# Target-side (English) field: same markers, natural token order.
TRG = Field(tokenize=tokenize_en,
            init_token='<sos>',
            eos_token='<eos>',
            lower=True)

# Multi30k German->English parallel corpus, split into train/valid/test.
train_data, valid_data, test_data = Multi30k.splits(exts=('.de', '.en'),
                                                    fields=(SRC, TRG))
