From 772631dd80deb9946118e274b8672dee551010a2 Mon Sep 17 00:00:00 2001
From: Paul-Louis NECH
Date: Tue, 19 Nov 2019 14:29:54 +0100
Subject: [PATCH] refact(lstm): Fix generate_text, extract params, use PoemTok

---
 KoozDawa/dawa/lstm.py | 61 ++++++++++++++++++++++++++++++++++++-------------------------
 1 file changed, 36 insertions(+), 25 deletions(-)

diff --git a/KoozDawa/dawa/lstm.py b/KoozDawa/dawa/lstm.py
index 73e8863..c553801 100644
--- a/KoozDawa/dawa/lstm.py
+++ b/KoozDawa/dawa/lstm.py
@@ -2,13 +2,14 @@ import warnings
 
 import numpy as np
 from keras import Sequential
+from keras.callbacks import ModelCheckpoint, EarlyStopping
 from keras.layers import Embedding, LSTM, Dropout, Dense
 from keras.utils import to_categorical
 from keras_preprocessing.sequence import pad_sequences
 from keras_preprocessing.text import Tokenizer
 
 from KoozDawa.dawa.loader import load_kawa, clean_text, load_seeds
-from KoozDawa.dawa.tokens import get_sequence_of_tokens
+from KoozDawa.dawa.tokens import PoemTokenizer
 
 warnings.filterwarnings("ignore")
 warnings.simplefilter(action='ignore', category=FutureWarning)
@@ -25,7 +26,7 @@ def generate_padded_sequences(input_sequences, total_words):
     return predictors, label, max_sequence_len
 
 
-def create_model(max_sequence_len, total_words, layers=100, dropout=0.1):  # TODO finetune
+def create_model(max_sequence_len, total_words, layers=128, dropout=0.2):  # TODO finetune layers/dropout
     input_len = max_sequence_len - 1
     model = Sequential()
 
@@ -39,7 +40,9 @@ def create_model(max_sequence_len, total_words, layers=100, dropout=0.1):  # TOD
 
     # Add Output Layer
     model.add(Dense(total_words, activation='softmax'))
-    model.compile(loss='categorical_crossentropy', optimizer='adam')
+    model.compile(optimizer='adam',  # TODO: Try RMSprop(learning_rate=0.01)
+                  loss='categorical_crossentropy',  # TODO: Try sparse_categorical_crossentropy for faster training
+                  metrics=['accuracy'])
 
     # TODO: Try alternative architectures
     # https://medium.com/coinmonks/word-level-lstm-text-generator-creating-automatic-song-lyrics-with-neural-networks-b8a1617104fb#35f4
@@ -47,53 +50,61 @@ def create_model(max_sequence_len, total_words, layers=100, dropout=0.1):  # TOD
 
 
 def generate_text(model: Sequential, tokenizer: Tokenizer, seed_text="", nb_words=5, max_sequence_len=0) -> str:
+    word_indices = {v: k for k, v in tokenizer.word_index.items()}
+    output = seed_text
+
     for _ in range(nb_words):
-        token_list = tokenizer.texts_to_sequences([seed_text])[0]
+        token_list = tokenizer.texts_to_sequences([output])[0]
         token_list = pad_sequences([token_list], maxlen=max_sequence_len - 1, padding='pre')
-        predicted = model.predict_classes(token_list, verbose=2)
-
-        output_word = ""
-        for word, index in tokenizer.word_index.items():
-            if index == predicted:
-                output_word = word
-                break
-        seed_text += " " + output_word
-    return seed_text.capitalize()
+        predicted = model.predict_classes(token_list, verbose=2)[0]
+        output += " " + word_indices[predicted]
+    return output.capitalize()
 
 
 def main():
-    should_train = True
+    # should_train = True
     # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
+    nb_words = 20
     nb_epoch = 100
-    nb_words = 200
-    tokenizer = Tokenizer()
+    nb_layers = 128
+    dropout = .2
+    tokenizer = PoemTokenizer()
 
     # if should_train:
     lines = load_kawa()
 
     corpus = [clean_text(x) for x in lines]
-    print("Corpus:", corpus[:5])
+    print("Corpus:", corpus[:10])
 
-    inp_sequences, total_words = get_sequence_of_tokens(corpus, tokenizer)
+    inp_sequences, total_words = tokenizer.get_sequence_of_tokens(corpus)
     predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
-    model = create_model(max_sequence_len, total_words)
+    model = create_model(max_sequence_len, total_words, layers=nb_layers, dropout=dropout)
     model.summary()
 
-    model.fit(predictors, label, epochs=nb_epoch, verbose=5)
+    file_path = "../models/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
+    checkpoint = ModelCheckpoint(file_path, monitor='accuracy', save_best_only=True)
+    # print_callback = LambdaCallback(on_epoch_end=on_epoch_end)
+    early_stopping = EarlyStopping(monitor='accuracy', patience=5)
+    callbacks_list = [checkpoint, early_stopping]
+
+    for i in range(nb_epoch):
+        model.fit(predictors, label, initial_epoch=i, epochs=i + 1, verbose=2, callbacks=callbacks_list)
+
     #     model.save(model_file)
-    # else:  # FIXME: Load and predict
+    # else:  # FIXME: Load and predict, maybe reuse checkpoints?
     #     model = load_model(model_file)
 
-    for sample in load_seeds(lines):
-        print(generate_text(model, tokenizer, sample, nb_words, max_sequence_len))
+    for i, seed in enumerate(load_seeds(lines, 3)):
+        output = generate_text(model, tokenizer, seed, nb_words, max_sequence_len)
+        print("%i %s -> %s" % (i, seed, output))
 
-    with open("./output/lstm.txt", "a+") as f:
+    with open("./output/dawa.txt", "a+") as f:
         while True:
             input_text = input("> ")
             text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)
             print(text)
-            f.writelines(text)
+            f.writelines("%s\n" % text)
 
 
 if __name__ == '__main__':
-- 
libgit2 0.27.0
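
Note: main() now calls tokenizer.get_sequence_of_tokens(corpus) on the new PoemTokenizer, but KoozDawa/dawa/tokens.py is not part of this patch. A minimal sketch of the interface the change assumes; the class name comes from the import above, while the n-gram logic mirrors the replaced get_sequence_of_tokens helper and is an assumption, not the actual implementation:

    from keras_preprocessing.text import Tokenizer

    class PoemTokenizer(Tokenizer):
        """Tokenizer that also builds the n-gram training sequences."""

        def get_sequence_of_tokens(self, corpus):
            # Learn the vocabulary from the cleaned corpus.
            self.fit_on_texts(corpus)
            total_words = len(self.word_index) + 1

            # For each line, emit every prefix n-gram: [w1 w2], [w1 w2 w3], ...
            input_sequences = []
            for line in corpus:
                token_list = self.texts_to_sequences([line])[0]
                for i in range(1, len(token_list)):
                    input_sequences.append(token_list[:i + 1])
            return input_sequences, total_words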
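
Note: generate_text still relies on Sequential.predict_classes(), which exists in the standalone Keras releases this project targets but was removed from tf.keras in later TensorFlow versions. A sketch of an equivalent call if the project migrates (not part of this patch):

    import numpy as np

    # argmax over the softmax output yields the same predicted class index
    predicted = int(np.argmax(model.predict(token_list), axis=-1)[0])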