diff options
Diffstat (limited to 'learn_torch/seq')
| -rw-r--r-- | learn_torch/seq/char_rnn.py | 66 |
1 files changed, 64 insertions, 2 deletions
# Reconstructed from the post-change side of the diff for
# learn_torch/seq/char_rnn.py (82cc924..552266b). find_files() and
# build_vocab() are elided by the diff hunks and are not reproduced here.
import glob
import os
import unicodedata
import string

import torch
from torch import nn

# Alphabet used for one-hot encoding. Defined at module level (rather than
# only under the ``__main__`` guard) so the encoding helpers below also work
# when this module is imported.
all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)


def letter_to_index(letter: str) -> int:
    """Return the position of *letter* in ``all_letters``.

    Raises:
        ValueError: if *letter* does not occur in ``all_letters``.
    """
    return all_letters.index(letter)


def letter_to_tensor(letter: str) -> torch.Tensor:
    """Return a ``(1, n_letters)`` one-hot tensor for a single letter."""
    tensor = torch.zeros(1, n_letters)
    tensor[0][letter_to_index(letter)] = 1
    return tensor


def line_to_tensor(line: str) -> torch.Tensor:
    """Return a ``(len(line), 1, n_letters)`` tensor of one-hot rows.

    Row ``i`` encodes ``line[i]``; the middle axis has size 1.
    """
    tensor = torch.zeros(len(line), 1, n_letters)
    for i, letter in enumerate(line):
        # Set the single hot entry directly instead of building a temporary
        # one-hot tensor per letter.
        tensor[i, 0, letter_to_index(letter)] = 1
    return tensor


class RNN(nn.Module):
    """Single-step recurrent cell built from two linear layers.

    Each step concatenates the input with the previous hidden state and feeds
    the result to ``i2h`` (next hidden state) and ``i2o`` (log-probabilities
    over the output classes).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run one step.

        Args:
            input: tensor concatenated with *hidden* along dim 1.
            hidden: hidden state from the previous step.

        Returns:
            ``(output, hidden)``: log-softmax of ``i2o`` over dim 1, and the
            new hidden state produced by ``i2h``.
        """
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)
        return output, hidden

    def init_hidden(self):
        """Return an all-zero initial hidden state of shape ``(1, hidden_size)``."""
        return torch.zeros(1, self.hidden_size)


def category_from_output(output, categories=None):
    """Return ``(index, category, value)`` for the top-scoring entry of *output*.

    Args:
        output: tensor of scores; ``topk(1)`` must yield a single element
            (e.g. shape ``(1, n_categories)``).
        categories: sequence of category names. Defaults to the module-level
            ``all_categories`` set up by the script entry point.
    """
    if categories is None:
        categories = all_categories
    top_v, top_i = output.topk(1)
    category_i = top_i[0].item()
    return category_i, categories[category_i], top_v.item()


def train(x, y, *, model=None, loss_fn=None, learning_rate=None):
    """Run one manual SGD step on a single sequence.

    The sequence is fed to the model one step at a time; the loss is computed
    on the output of the final step only.

    Args:
        x: tensor whose first axis is iterated over; ``x[i]`` is passed to the
            model together with the running hidden state.
        y: target passed to the loss function with the final output.
        model: network to train. Defaults to the module-level ``rnn``.
        loss_fn: loss callable. Defaults to the module-level ``criterion``.
        learning_rate: SGD step size. Defaults to the module-level ``lr``.

    Returns:
        ``(output, loss)``: the final-step output and the loss as a float.

    Raises:
        ValueError: if *x* has no steps (there would be no output to score).
    """
    if model is None:
        model = rnn
    if loss_fn is None:
        loss_fn = criterion
    if learning_rate is None:
        learning_rate = lr

    if x.shape[0] == 0:
        raise ValueError("train() needs a sequence with at least one step")

    hidden = model.init_hidden()
    model.zero_grad()
    for i in range(x.shape[0]):
        # Call the module itself (not .forward) so registered hooks run.
        output, hidden = model(x[i], hidden)
    loss = loss_fn(output, y)
    loss.backward()

    with torch.no_grad():
        for p in model.parameters():
            # A parameter that did not contribute to the loss has no gradient:
            # for a one-step sequence the hidden state produced by i2h is
            # never used, so i2h's grads stay None.
            if p.grad is not None:
                p.add_(p.grad, alpha=-learning_rate)
    return output, loss.item()


if __name__ == '__main__':
    category_lines, all_categories = build_vocab('../text_data/names/*.txt')

    n_categories = len(all_categories)
    n_hidden = 128
    rnn = RNN(n_letters, n_hidden, n_categories)
    lr = 1e-5
    # The model already ends in LogSoftmax, so NLLLoss is the matching loss.
    criterion = nn.NLLLoss()
