# path: root/learn_torch/seq/char_rnn.py
# blob: 552266bb6c3ebe97fb791b2cb5196378d58d2505

from io import open
import glob
import os
import unicodedata
import string
import torch
from torch import nn


def find_files(path):
    """Return a list of the filesystem paths matching the glob pattern ``path``.

    The order of the returned paths is whatever the filesystem yields; an
    empty list means nothing matched.
    """
    matches = list(glob.iglob(path))
    return matches


def uni_to_ascii(s):
    """Fold ``s`` down to characters from the ``all_letters`` vocabulary.

    The text is NFD-normalized so accented letters decompose into a base
    character followed by combining marks. Combining marks (Unicode category
    'Mn') are discarded, and so is every remaining character that does not
    appear in the module-level ``all_letters`` string.
    """
    kept = []
    for ch in unicodedata.normalize('NFD', s):
        if unicodedata.category(ch) == 'Mn':
            continue  # drop accents / diacritics
        if ch not in all_letters:
            continue  # drop anything outside the model's vocabulary
        kept.append(ch)
    return ''.join(kept)


def build_vocab(filepath):
    """Load one category of example lines per file matched by ``filepath``.

    Each matched file's base name (without extension) becomes a category
    name, and the file's lines, passed through ``uni_to_ascii``, become that
    category's examples.

    Args:
        filepath: glob pattern handed to ``find_files``.

    Returns:
        ``(category_lines, all_categories)``. ``category_lines`` maps each
        category name to its list of folded lines. ``all_categories`` lists
        the category names in sorted file-path order; a name's position in
        this list is the index ``category_from_output`` maps back to it.
    """
    def read_lines(filename):
        # Use ``with`` so the file handle is closed deterministically.
        with open(filename, encoding='utf-8') as handle:
            lines = handle.read().strip().split('\n')
        return [uni_to_ascii(line) for line in lines]

    category_lines = {}
    all_categories = []

    # glob's ordering is filesystem-dependent; sorting keeps category indices
    # (the model's class ids) stable across runs and machines.
    for filename in sorted(find_files(filepath)):
        category = os.path.splitext(os.path.basename(filename))[0]
        lines = read_lines(filename)
        all_categories.append(category)
        category_lines[category] = lines
    return category_lines, all_categories


def letter_to_index(letter):
    """Return the position of ``letter`` within the ``all_letters`` vocabulary.

    Propagates the ``ValueError`` raised by ``str.index`` when the letter
    is not present.
    """
    position = all_letters.index(letter)
    return position


def letter_to_tensor(letter):
    """One-hot encode ``letter`` as a float tensor of shape ``(1, n_letters)``.

    Every entry is 0 except the column given by ``letter_to_index(letter)``,
    which is set to 1.
    """
    one_hot = torch.zeros(1, n_letters)
    hot_column = letter_to_index(letter)
    one_hot[0, hot_column] = 1
    return one_hot


def line_to_tensor(line):
    """Encode ``line`` as a ``(len(line), 1, n_letters)`` stack of one-hot rows.

    Slice ``[i, 0]`` holds the ``letter_to_tensor`` encoding of ``line[i]``.
    The size-1 middle axis matches the leading dimension of the hidden state
    returned by ``RNN.init_hidden()``, so each ``result[i]`` can be fed to
    ``RNN.forward`` directly.
    """
    encoded = torch.zeros(len(line), 1, n_letters)
    for step, char in enumerate(line):
        encoded[step, 0] = letter_to_tensor(char)
    return encoded


class RNN(torch.nn.Module):
    """Single-cell recurrent classifier.

    At every step the input and the previous hidden state are concatenated
    along dim 1. One linear layer (``i2h``) maps that to the next hidden
    state. Another (``i2o``) maps it to ``output_size`` scores, which
    ``LogSoftmax`` turns into log-probabilities.

    NOTE(review): ``i2h`` has no activation, so the hidden-state recurrence
    is purely linear; confirm this is intended.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        joint_size = input_size + hidden_size
        # Registration order (i2h before i2o) fixes parameters() order and
        # the order in which the layers are randomly initialized.
        self.i2h = torch.nn.Linear(joint_size, hidden_size)
        self.i2o = torch.nn.Linear(joint_size, output_size)
        self.softmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Advance one step; return ``(log_probs, next_hidden)``."""
        joint = torch.cat([input, hidden], dim=1)
        next_hidden = self.i2h(joint)
        log_probs = self.softmax(self.i2o(joint))
        return log_probs, next_hidden

    def init_hidden(self):
        """Return an all-zero hidden state of shape ``(1, hidden_size)``."""
        return torch.zeros(1, self.hidden_size)


def category_from_output(output):
    """Decode a model output into its highest-scoring category.

    Assumes ``output`` holds a single row of scores, because the score is
    read with ``.item()``, which requires exactly one element.

    Returns:
        ``(index, name, score)``: the arg-max column of ``output``, the
        matching entry of ``all_categories``, and that column's value as a
        Python number (a log-probability when ``output`` comes from ``RNN``).
    """
    best_values, best_indices = torch.topk(output, 1)
    index = best_indices[0].item()
    name = all_categories[index]
    return index, name, best_values.item()


def train(x, y, *, model=None, loss_fn=None, learning_rate=None):
    """Run one plain-SGD step on a single example sequence.

    The sequence is fed one time step at a time (each slice of ``x`` along
    dim 0), threading the hidden state through. The loss is computed on the
    output of the final step only.

    Args:
        x: input sequence indexed along dim 0; presumably the
            ``(seq_len, 1, n_letters)`` tensor from ``line_to_tensor``.
        y: target accepted by ``loss_fn`` for the final-step output.
        model: recurrent module exposing ``init_hidden()``; defaults to the
            module-level ``rnn``.
        loss_fn: loss callable; defaults to the module-level ``criterion``.
        learning_rate: SGD step size; defaults to the module-level ``lr``.

    Returns:
        ``(output, loss)``: the final-step output tensor and the loss as a
        Python float.

    Raises:
        ValueError: if ``x`` has no time steps, so there is no output to score.
    """
    if x.shape[0] == 0:
        raise ValueError('train() needs a sequence with at least one time step')

    # Fall back to the script-level globals so existing train(x, y) calls work.
    model = rnn if model is None else model
    loss_fn = criterion if loss_fn is None else loss_fn
    learning_rate = lr if learning_rate is None else learning_rate

    hidden = model.init_hidden()
    model.zero_grad()
    for step_input in x:
        # Call the module rather than .forward() so registered hooks run.
        output, hidden = model(step_input, hidden)
    loss = loss_fn(output, y)
    loss.backward()

    # Manual SGD update; no_grad keeps the in-place write out of autograd.
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is not None:
                p.add_(p.grad, alpha=-learning_rate)
    return output, loss.item()


# Character vocabulary. Defined at module level, not only under __main__,
# so the helper functions that read all_letters / n_letters also work when
# this file is imported.
all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)

if __name__ == '__main__':
    # NOTE(review): relative glob, resolved against the current working
    # directory rather than this file's location.
    category_lines, all_categories = build_vocab('../text_data/names/*.txt')
    if not all_categories:
        # glob silently matches nothing when run from the wrong directory.
        raise SystemExit("no training files matched '../text_data/names/*.txt'")

    n_categories = len(all_categories)
    n_hidden = 128
    rnn = RNN(n_letters, n_hidden, n_categories)
    lr = 1e-5
    # NLLLoss expects log-probabilities, which RNN produces via LogSoftmax.
    criterion = torch.nn.NLLLoss()