summaryrefslogtreecommitdiff
path: root/learn_torch
diff options
context:
space:
mode:
Diffstat (limited to 'learn_torch')
-rw-r--r--learn_torch/learn_nn/custom_module.py21
-rw-r--r--learn_torch/seq/char_rnn.py66
2 files changed, 85 insertions, 2 deletions
diff --git a/learn_torch/learn_nn/custom_module.py b/learn_torch/learn_nn/custom_module.py
new file mode 100644
index 0000000..7052b14
--- /dev/null
+++ b/learn_torch/learn_nn/custom_module.py
@@ -0,0 +1,21 @@
+import torch
+from torch import nn
+
+
class MySeq(torch.nn.Module):
    """Minimal sequential container: applies its sub-modules in order.

    Each positional argument is registered as a child module under the
    string form of its position ('0', '1', ...), so that parameter
    discovery, ``state_dict()`` and ``repr()`` work as for any other module.
    """

    def __init__(self, *args):
        """Register every module in ``args`` in the order given."""
        super().__init__()
        for idx, block in enumerate(args):
            # Child names must be strings; keying by the module object itself
            # breaks state_dict()/repr() and collapses repeated modules.
            self.add_module(str(idx), block)

    def forward(self, X):
        """Feed ``X`` through every registered block and return the result."""
        # ``_modules`` preserves registration order (and repeated modules).
        for block in self._modules.values():
            X = block(X)
        return X
+
+
if __name__ == '__main__':
    # Smoke test: push a random (2, 20) batch through a small MLP assembled
    # with MySeq. The result is discarded; this only checks that it runs.
    X = torch.rand(2, 20)
    net = MySeq(
        nn.Linear(20, 256),
        nn.ReLU(),
        nn.Linear(256, 10),
    )
    net(X)
+
diff --git a/learn_torch/seq/char_rnn.py b/learn_torch/seq/char_rnn.py
index 82cc924..552266b 100644
--- a/learn_torch/seq/char_rnn.py
+++ b/learn_torch/seq/char_rnn.py
@@ -4,6 +4,8 @@ import glob
import os
import unicodedata
import string
+import torch
+from torch import nn
def find_files(path):
@@ -31,11 +33,71 @@ def build_vocab(filepath):
return category_lines, all_categories
def letter_to_index(letter):
    """Return the position of ``letter`` within the global ``all_letters``.

    Raises ``ValueError`` (from ``str.index``) if the letter is not present.
    """
    position = all_letters.index(letter)
    return position
+
+
def letter_to_tensor(letter):
    """Encode a single letter as a one-hot tensor of shape (1, n_letters)."""
    one_hot = torch.zeros(1, n_letters)
    # Switch on the single column that corresponds to this letter.
    column = letter_to_index(letter)
    one_hot[0, column] = 1
    return one_hot
+
+
def line_to_tensor(line):
    """Encode ``line`` as a tensor of shape (len(line), 1, n_letters).

    Row ``i`` holds the one-hot encoding of the ``i``-th character.
    """
    encoded = torch.zeros(len(line), 1, n_letters)
    for position, letter in enumerate(line):
        # Same effect as copying letter_to_tensor(letter) into this row:
        # exactly one entry of the row becomes 1.
        encoded[position, 0, letter_to_index(letter)] = 1
    return encoded
+
+
class RNN(torch.nn.Module):
    """Single-step recurrent cell built from two linear layers.

    At each step the input and the previous hidden state are concatenated
    along dim 1; one linear layer maps that to the next hidden state and a
    second one maps it to log-probabilities over the output classes.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Create the input->hidden and input->output projections."""
        super().__init__()
        self.hidden_size = hidden_size
        # Both layers consume the concatenation [input, hidden].
        joint_size = input_size + hidden_size
        self.i2h = torch.nn.Linear(joint_size, hidden_size)
        self.i2o = torch.nn.Linear(joint_size, output_size)
        self.softmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run one step; return ``(log_probs, next_hidden)``."""
        joined = torch.cat([input, hidden], dim=1)
        next_hidden = self.i2h(joined)
        log_probs = self.softmax(self.i2o(joined))
        return log_probs, next_hidden

    def init_hidden(self):
        """Return an all-zero initial hidden state of shape (1, hidden_size)."""
        return torch.zeros(1, self.hidden_size)
+
+
def category_from_output(output):
    """Pick the highest-scoring category from a model output.

    Returns a tuple ``(index, name, score)`` where ``name`` is looked up in
    the global ``all_categories`` and ``score`` is the top value itself.
    """
    best_score, best_index = output.topk(1)
    idx = best_index[0].item()
    label = all_categories[idx]
    return idx, label, best_score.item()
+
+
def train(x, y):
    """Run one plain-SGD training step on a single sequence.

    Uses the module-level ``rnn``, ``criterion`` and ``lr``. The sequence is
    fed one step at a time along dim 0 of ``x``; the loss is computed from
    the output of the last step only.

    Args:
        x: tensor whose first dimension indexes the time steps.
        y: target passed to ``criterion`` together with the final output.

    Returns:
        ``(output, loss_value)``: the last-step output tensor and the loss
        as a Python float.

    Raises:
        ValueError: if ``x`` has no time steps (there is no output to score).
    """
    if x.shape[0] == 0:
        # Without at least one step ``output`` would never be assigned.
        raise ValueError("train() requires a non-empty input sequence")

    hidden = rnn.init_hidden()
    rnn.zero_grad()
    for step in x:
        # Call the module rather than .forward() so registered hooks run.
        output, hidden = rnn(step, hidden)

    loss = criterion(output, y)
    loss.backward()

    # Manual SGD update; no_grad keeps the update out of the autograd graph.
    with torch.no_grad():
        for p in rnn.parameters():
            if p.grad is None:
                # Parameter did not take part in this step's graph.
                continue
            p.add_(p.grad, alpha=-lr)
    return output, loss.item()
+
if __name__ == '__main__':
    # Alphabet used for one-hot encoding: ASCII letters plus a few
    # punctuation characters.
    all_letters = string.ascii_letters + " .,;'"
    n_letters = len(all_letters)

    category_lines, all_categories = build_vocab('../text_data/names/*.txt')

    n_categories = len(all_categories)
    n_hidden = 128
    rnn = RNN(n_letters, n_hidden, n_categories)
    lr = 1e-5
    # RNN.forward already applies LogSoftmax, so NLLLoss is the matching
    # criterion. The stray, discarded nn.CrossEntropyLoss() was removed.
    criterion = torch.nn.NLLLoss()