From 678fab50280b647d95213a9695d07c49542696f2 Mon Sep 17 00:00:00 2001
From: zhang
Date: Sat, 21 May 2022 14:23:49 +0800
Subject: 0521

---
Review notes (text between the "---" line and the first file header is
commentary, not part of the commit message):

  * The line structure of this patch had been collapsed; it is restored
    here.  Python indentation and the position of blank context lines
    were inferred from the syntax and from the hunk line counts.
  * The added lines below differ from the blobs named in the "index"
    headers in the following ways (all line-count neutral):
      - custom_module.py: MySeq registers its blocks under str(idx)
        keys.  The original used the module object itself as the key,
        which makes parameters(), state_dict() and repr() raise
        TypeError because nn.Module joins names as strings.
      - char_rnn.py: local name "tenor" corrected to "tensor".
      - char_rnn.py: train() calls rnn(...) rather than rnn.forward(...)
        so that registered hooks run.
      - char_rnn.py: the discarded "nn.CrossEntropyLoss()" expression
        at the end of the main block is replaced by a comment.
      - short trailing comments were added to the new definitions.
  * In train(), the loss is computed after the time-step loop; calling
    backward() inside the loop would fail on the second step.

 learn_torch/learn_nn/custom_module.py | 21 +++++++++++
 learn_torch/seq/char_rnn.py           | 66 +++++++++++++++++++++++++++++++++--
 2 files changed, 85 insertions(+), 2 deletions(-)
 create mode 100644 learn_torch/learn_nn/custom_module.py

(limited to 'learn_torch')

diff --git a/learn_torch/learn_nn/custom_module.py b/learn_torch/learn_nn/custom_module.py
new file mode 100644
index 0000000..7052b14
--- /dev/null
+++ b/learn_torch/learn_nn/custom_module.py
@@ -0,0 +1,21 @@
+import torch
+from torch import nn
+
+
+class MySeq(torch.nn.Module):  # applies its blocks one after another
+    def __init__(self, *args):
+        super().__init__()
+        for idx, block in enumerate(args):
+            self._modules[str(idx)] = block  # sub-module names must be str
+
+    def forward(self, X):
+        for block in self._modules.values():  # registration order
+            X = block(X)
+        return X
+
+
+if __name__ == '__main__':
+    X = torch.rand(2, 20)
+    net = MySeq(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))
+    net(X)
+
diff --git a/learn_torch/seq/char_rnn.py b/learn_torch/seq/char_rnn.py
index 82cc924..552266b 100644
--- a/learn_torch/seq/char_rnn.py
+++ b/learn_torch/seq/char_rnn.py
@@ -4,6 +4,8 @@ import glob
 import os
 import unicodedata
 import string
+import torch
+from torch import nn
 
 
 def find_files(path):
@@ -31,11 +33,71 @@ def build_vocab(filepath):
     return category_lines, all_categories
 
 
+def letter_to_index(letter):  # position in all_letters; ValueError if absent
+    return all_letters.index(letter)
+
+
+def letter_to_tensor(letter):  # one-hot row of shape (1, n_letters)
+    tensor = torch.zeros(1, n_letters)
+    tensor[0][letter_to_index(letter)] = 1
+    return tensor
+
+
+def line_to_tensor(line):  # one-hot per character: (len(line), 1, n_letters)
+    tensor = torch.zeros(len(line), 1, n_letters)
+    for i, letter in enumerate(line):
+        tensor[i][0] = letter_to_tensor(letter)
+    return tensor
+
+
+class RNN(torch.nn.Module):  # one step: (input, hidden) -> (log-probs, hidden)
+    def __init__(self, input_size, hidden_size, output_size):
+        super(RNN, self).__init__()
+        self.hidden_size = hidden_size
+        self.i2h = torch.nn.Linear(input_size + hidden_size, hidden_size)
+        self.i2o = torch.nn.Linear(input_size + hidden_size, output_size)
+        self.softmax = torch.nn.LogSoftmax(dim=1)
+
+    def forward(self, input, hidden):
+        combined = torch.cat((input, hidden), 1)  # join along dim 1
+        hidden = self.i2h(combined)
+        output = self.i2o(combined)
+        output = self.softmax(output)
+        return output, hidden
+
+    def init_hidden(self):
+        return torch.zeros(1, self.hidden_size)
+
+
+def category_from_output(output):  # -> (index, category name, top value)
+    top_v, top_i = output.topk(1)
+    category_i = top_i[0].item()
+    return category_i, all_categories[category_i], top_v.item()
+
+
+def train(x, y):  # one manual gradient step on a single sequence
+    hidden = rnn.init_hidden()
+    rnn.zero_grad()
+    for i in range(x.shape[0]):  # feed one time step at a time
+        output, hidden = rnn(x[i], hidden)
+    loss = criterion(output, y)  # uses the last step's output only
+    loss.backward()
+    for p in rnn.parameters():
+        p.data.add_(p.grad.data, alpha=-lr)  # p <- p - lr * grad
+    return output, loss.item()
+
 if __name__ == '__main__':
     all_letters = string.ascii_letters + " .,;'"
     n_letters = len(all_letters)
-    print('n_letters = {}'.format(n_letters))
+
     category_lines, all_categories = build_vocab('../text_data/names/*.txt')
-    print(all_categories)
+
+    n_categories = len(all_categories)
+    n_hidden = 128
+    rnn = RNN(n_letters, n_hidden, n_categories)
+    lr = 1e-5
+    criterion = torch.nn.NLLLoss()
+
+    # NLLLoss on LogSoftmax output matches CrossEntropyLoss on raw scores.
 
 
-- 
cgit v1.2.3