| author | zhang <zch921005@126.com> | 2022-03-25 00:11:34 +0800 |
|---|---|---|
| committer | zhang <zch921005@126.com> | 2022-03-25 00:11:34 +0800 |
| commit | 3ed1c5d0e4018fc70012a6209a859a059f7127b5 (patch) | |
| tree | 2ca318fb2e23daf2964189a751fb61b20aa7515a /network/lstm_add.py | |
| parent | 6f68e1818229e0d2dad760062e6b5bb137b88f5b (diff) | |
Update script
Diffstat (limited to 'network/lstm_add.py')
| -rw-r--r-- | network/lstm_add.py | 138 |
1 file changed, 138 insertions, 0 deletions
```diff
diff --git a/network/lstm_add.py b/network/lstm_add.py
new file mode 100644
index 0000000..1eeb4f9
--- /dev/null
+++ b/network/lstm_add.py
@@ -0,0 +1,138 @@
+from random import seed
+from random import randint
+from numpy import array
+from math import ceil
+from math import log10
+from math import sqrt
+from numpy import argmax
+from keras.models import Sequential
+from keras.layers import Dense
+from keras.layers import LSTM
+from keras.layers import TimeDistributed
+from keras.layers import RepeatVector
+
+
+# generate lists of random integers and their sum
+def random_sum_pairs(n_examples, n_numbers, largest):
+    X, y = list(), list()
+    for i in range(n_examples):
+        in_pattern = [randint(1, largest) for _ in range(n_numbers)]
+        out_pattern = sum(in_pattern)
+        X.append(in_pattern)
+        y.append(out_pattern)
+    return X, y
+
+
+# convert data to strings
+def to_string(X, y, n_numbers, largest):
+    max_length = n_numbers * ceil(log10(largest + 1)) + n_numbers - 1
+    Xstr = list()
+    for pattern in X:
+        strp = '+'.join([str(n) for n in pattern])
+        strp = ''.join([' ' for _ in range(max_length - len(strp))]) + strp
+        Xstr.append(strp)
+    max_length = ceil(log10(n_numbers * (largest + 1)))
+    ystr = list()
+    for pattern in y:
+        strp = str(pattern)
+        strp = ''.join([' ' for _ in range(max_length - len(strp))]) + strp
+        ystr.append(strp)
+    return Xstr, ystr
+
+
+# integer encode strings
+def integer_encode(X, y, alphabet):
+    char_to_int = dict((c, i) for i, c in enumerate(alphabet))
+    Xenc = list()
+    for pattern in X:
+        integer_encoded = [char_to_int[char] for char in pattern]
+        Xenc.append(integer_encoded)
+    yenc = list()
+    for pattern in y:
+        integer_encoded = [char_to_int[char] for char in pattern]
+        yenc.append(integer_encoded)
+    return Xenc, yenc
+
+
+# one hot encode
+def one_hot_encode(X, y, max_int):
+    Xenc = list()
+    for seq in X:
+        pattern = list()
+        for index in seq:
+            vector = [0 for _ in range(max_int)]
+            vector[index] = 1
+            pattern.append(vector)
+        Xenc.append(pattern)
+    yenc = list()
+    for seq in y:
+        pattern = list()
+        for index in seq:
+            vector = [0 for _ in range(max_int)]
+            vector[index] = 1
+            pattern.append(vector)
+        yenc.append(pattern)
+    return Xenc, yenc
+
+
+# generate an encoded dataset
+def generate_data(n_samples, n_numbers, largest, alphabet):
+    # generate pairs
+    X, y = random_sum_pairs(n_samples, n_numbers, largest)
+    # convert to strings
+    X, y = to_string(X, y, n_numbers, largest)
+    # integer encode
+    X, y = integer_encode(X, y, alphabet)
+    # one hot encode
+    X, y = one_hot_encode(X, y, len(alphabet))
+    # return as numpy arrays
+    X, y = array(X), array(y)
+    return X, y
+
+
+# invert encoding
+def invert(seq, alphabet):
+    int_to_char = dict((i, c) for i, c in enumerate(alphabet))
+    strings = list()
+    for pattern in seq:
+        string = int_to_char[argmax(pattern)]
+        strings.append(string)
+    return ''.join(strings)
+
+
+# define dataset
+seed(1)
+n_samples = 5000
+n_numbers = 2
+largest = 100
+alphabet = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', ' ']
+n_chars = len(alphabet)
+n_in_seq_length = n_numbers * ceil(log10(largest + 1)) + n_numbers - 1
+n_out_seq_length = ceil(log10(n_numbers * (largest + 1)))
+# define LSTM configuration
+n_batch = 10
+n_epoch = 50
+# create LSTM
+model = Sequential()
+# n_in_seq_length: 7, n_chars: 12
+model.add(LSTM(100, input_shape=(n_in_seq_length, n_chars)))
+model.add(RepeatVector(n_out_seq_length))
+model.add(LSTM(50, return_sequences=True))
+model.add(TimeDistributed(Dense(n_chars, activation='softmax')))
+model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
+print(model.summary())
+# train LSTM
+for i in range(n_epoch):
+    X, y = generate_data(n_samples, n_numbers, largest, alphabet)
+    print(i)
+    model.fit(X, y, epochs=1, batch_size=n_batch)
+
+# evaluate on some new patterns
+X, y = generate_data(n_samples, n_numbers, largest, alphabet)
+result = model.predict(X, batch_size=n_batch, verbose=0)
+# calculate error
+expected = [invert(x, alphabet) for x in y]
+predicted = [invert(x, alphabet) for x in result]
+# show some examples
+for i in range(20):
+    print('Expected=%s, Predicted=%s' % (expected[i], predicted[i]))
```
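For context: the committed script is the classic sequence-to-sequence addition demo. Random sums are rendered as fixed-width strings and one-hot encoded over a 12-symbol alphabet, and an encoder-decoder LSTM (with `RepeatVector` bridging the 7-step input to the 3-step output) learns to emit the sum character by character. Below is a minimal, self-contained sketch of just the data representation; the `one_hot` helper and `str.rjust` padding here are illustrative stand-ins for the `to_string`/`integer_encode`/`one_hot_encode` functions in the diff, not part of the commit.

```python
# Sketch of the encoding used by network/lstm_add.py (str.rjust stands in
# for the explicit space-padding loop in to_string).
from math import ceil, log10
from random import randint

from numpy import argmax, array

alphabet = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', ' ']
char_to_int = {c: i for i, c in enumerate(alphabet)}
n_numbers, largest = 2, 100

# '100+100': 2 numbers of ceil(log10(101)) = 3 digits + 1 separator = 7 chars
n_in = n_numbers * ceil(log10(largest + 1)) + n_numbers - 1
# largest possible sum is 200, so ceil(log10(202)) = 3 output chars
n_out = ceil(log10(n_numbers * (largest + 1)))


def one_hot(s):
    # one 12-way indicator vector per character
    return [[int(i == char_to_int[c]) for i in range(len(alphabet))] for c in s]


terms = [randint(1, largest) for _ in range(n_numbers)]
s_in = '+'.join(str(t) for t in terms).rjust(n_in)   # e.g. '  18+25'
s_out = str(sum(terms)).rjust(n_out)                 # e.g. ' 43'
X, y = array([one_hot(s_in)]), array([one_hot(s_out)])
print(X.shape, y.shape)  # (1, 7, 12) (1, 3, 12): what the LSTM consumes
print(repr(''.join(alphabet[argmax(v)] for v in X[0])))  # round-trips to s_in
```

Argmax-decoding each time step, as the script's `invert` does, recovers the padded strings, which is how the evaluation loop compares expected and predicted sums.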
