Diffstat (limited to 'learn_torch/seq2seq/base_model.py')
-rw-r--r--  learn_torch/seq2seq/base_model.py  63
1 file changed, 63 insertions, 0 deletions
diff --git a/learn_torch/seq2seq/base_model.py b/learn_torch/seq2seq/base_model.py
new file mode 100644
index 0000000..bc292cf
--- /dev/null
+++ b/learn_torch/seq2seq/base_model.py
@@ -0,0 +1,63 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
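+# torchtext.legacy holds the old Field/BucketIterator API; it exists in
+# torchtext 0.9-0.11 and was removed in 0.12.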
+from torchtext.legacy.datasets import Multi30k
+from torchtext.legacy.data import Field, BucketIterator
+
+import spacy
+import numpy as np
+
+import random
+import math
+import time
+
+
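+# Seed every RNG in play (Python, NumPy, PyTorch CPU and CUDA) and force
+# deterministic cuDNN kernels so runs are reproducible.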
+SEED = 1234
+
+random.seed(SEED)
+np.random.seed(SEED)
+torch.manual_seed(SEED)
+torch.cuda.manual_seed(SEED)
+torch.backends.cudnn.deterministic = True
+
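+# Both spaCy models must be downloaded first, e.g.
+#   python -m spacy download de_core_news_sm
+#   python -m spacy download en_core_web_sm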
+spacy_de = spacy.load('de_core_news_sm')
+spacy_en = spacy.load('en_core_web_sm')
+
+def tokenize_de(text):
+ """
+ Tokenizes German text from a string into a list of strings (tokens) and reverses it
+ """
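+    # Reversing the source tokens follows Sutskever et al. (2014), who found it
+    # eases optimization by shortening early source-to-target dependencies.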
+ return [tok.text for tok in spacy_de.tokenizer(text)][::-1]
+
+def tokenize_en(text):
+ """
+ Tokenizes English text from a string into a list of strings (tokens)
+ """
+ return [tok.text for tok in spacy_en.tokenizer(text)]
+
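+# Fields define the preprocessing pipeline: tokenize, lowercase, and wrap
+# each sentence in <sos>/<eos> markers.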
+SRC = Field(tokenize = tokenize_de,
+ init_token = '<sos>',
+ eos_token = '<eos>',
+ lower = True)
+
+TRG = Field(tokenize = tokenize_en,
+ init_token = '<sos>',
+ eos_token = '<eos>',
+ lower = True)
+
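+# Multi30k: ~30k parallel German-English sentence pairs (image captions).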
+train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'),
+ fields = (SRC, TRG))
+