From 61a1d7a9a37ccf40fa7404f69418a13fbbed592a Mon Sep 17 00:00:00 2001 From: Paul-Louis NECH Date: Sun, 17 Nov 2019 15:42:32 +0100 Subject: [PATCH] feat(lstm): refact, predict, nocomment --- KoozDawa/lstm.py | 46 ++++++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/KoozDawa/lstm.py b/KoozDawa/lstm.py index 2d5eccb..c55f141 100644 --- a/KoozDawa/lstm.py +++ b/KoozDawa/lstm.py @@ -4,6 +4,7 @@ import warnings import numpy as np from keras import Sequential +from keras.engine.saving import load_model from keras.layers import Embedding, LSTM, Dropout, Dense from keras.preprocessing.text import Tokenizer from keras.utils import to_categorical @@ -26,10 +27,10 @@ def load(): content = f.readlines() all_lines.extend(content) - all_lines = [h for h in all_lines if - h[0] != "["] + all_lines = [h for h in all_lines if h[0] not in ["[", "#"] + ] len(all_lines) - print("Loaded data:", all_lines[0]) + print("Loaded %i lines of data: %s." % (len(all_lines), all_lines[0])) return all_lines @@ -78,7 +79,7 @@ def generate_padded_sequences(input_sequences, total_words): return predictors, label, max_sequence_len -def create_model(max_sequence_len, total_words): +def create_model(max_sequence_len, total_words, layers=100, dropout=0.1): # TODO finetune input_len = max_sequence_len - 1 model = Sequential() @@ -86,8 +87,8 @@ def create_model(max_sequence_len, total_words): model.add(Embedding(total_words, 10, input_length=input_len)) # Add Hidden Layer 1 - LSTM Layer - model.add(LSTM(100)) # TODO finetune - model.add(Dropout(0.1)) # TODO finetune + model.add(LSTM(layers)) + model.add(Dropout(dropout)) # Add Output Layer model.add(Dense(total_words, activation='softmax')) @@ -113,25 +114,38 @@ def generate_text(seed_text, nb_words, model, max_sequence_len): def main(): - lines = load() + should_train = True + nb_epoch = 20 + model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch + max_sequence_len = 5 # TODO: Test different default - corpus = [clean_text(x) for x in lines] - print(corpus[:10]) + if should_train: + lines = load() - inp_sequences, total_words = get_sequence_of_tokens(corpus[:10]) # Fixme: Corpus cliff for debug - print(inp_sequences[:10]) + corpus = [clean_text(x) for x in lines] + print(corpus[:10]) - predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words) - print(predictors, label, max_sequence_len) + inp_sequences, total_words = get_sequence_of_tokens(corpus) + print(inp_sequences[:10]) - model = create_model(max_sequence_len, total_words) - model.summary() + predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words) + print(predictors, label, max_sequence_len) - model.fit(predictors, label, epochs=10, verbose=5) + model = create_model(max_sequence_len, total_words) + model.summary() + + model.fit(predictors, label, epochs=nb_epoch, verbose=5) + model.save(model_file) + else: + model = load_model(model_file) print(generate_text("", 10, model, max_sequence_len)) print(generate_text("L'étoile", 10, model, max_sequence_len)) + while True: + input_text = input("> ") + print(generate_text(input_text, 10, model, max_sequence_len)) + if __name__ == '__main__': main() -- libgit2 0.27.0