From 170dc893679d5caeaeb8ee292e36ae435d5f3ca8 Mon Sep 17 00:00:00 2001
From: Paul-Louis NECH
Date: Sat, 23 Nov 2019 00:04:59 +0100
Subject: [PATCH] feat(kawa): Save/generate every 10 epochs

---
 KoozDawa/dawa/lstm.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/KoozDawa/dawa/lstm.py b/KoozDawa/dawa/lstm.py
index c553801..55bd633 100644
--- a/KoozDawa/dawa/lstm.py
+++ b/KoozDawa/dawa/lstm.py
@@ -3,7 +3,7 @@ import warnings
 import numpy as np
 from keras import Sequential
 from keras.callbacks import ModelCheckpoint, EarlyStopping
-from keras.layers import Embedding, LSTM, Dropout, Dense
+from keras.layers import Embedding, LSTM, Dropout, Dense, Bidirectional
 from keras.utils import to_categorical
 from keras_preprocessing.sequence import pad_sequences
 from keras_preprocessing.text import Tokenizer
@@ -19,14 +19,14 @@ warnings.simplefilter(action='ignore', category=FutureWarning)
 def generate_padded_sequences(input_sequences, total_words):
     max_sequence_len = max([len(x) for x in input_sequences])
     input_sequences = np.array(pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre'))
-
-    print("Max len:", max_sequence_len)
     predictors, label = input_sequences[:, :-1], input_sequences[:, -1]
     label = to_categorical(label, num_classes=total_words)
     return predictors, label, max_sequence_len
 
 
-def create_model(max_sequence_len, total_words, layers=128, dropout=0.2):  # TODO finetune layers/dropout
+def create_model(max_sequence_len, total_words, layers=128, dropout=0.3):  # TODO finetune layers/dropout
+    print("Creating model across %i words for %i-long seqs (%i layers, %.2f dropout):" %
+          (total_words, max_sequence_len, layers, dropout))
     input_len = max_sequence_len - 1
     model = Sequential()
 
@@ -35,6 +35,7 @@ def create_model(max_sequence_len, total_words, layers=128, dropout=0.2):  # TOD
 
     # Add Hidden Layer 1 - LSTM Layer
     model.add(LSTM(layers))
+    # model.add(Bidirectional(LSTM(layers), input_shape=(max_sequence_len, total_words)))
     model.add(Dropout(dropout))
 
     # Add Output Layer
@@ -81,20 +82,21 @@ def main():
     model = create_model(max_sequence_len, total_words, layers=nb_layers, dropout=dropout)
     model.summary()
 
-    file_path = "../models/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
-    checkpoint = ModelCheckpoint(file_path, monitor='accuracy', save_best_only=True)
+    file_path = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
+    checkpoint = ModelCheckpoint(file_path, monitor='accuracy', period=10, save_best_only=True)
     # print_callback = LambdaCallback(on_epoch_end=on_epoch_end)
     early_stopping = EarlyStopping(monitor='accuracy', patience=5)
     callbacks_list = [checkpoint, early_stopping]
 
-    for i in range(nb_epoch):
-        model.fit(predictors, label, initial_epoch=i, epochs=i + 1, verbose=2, callbacks=callbacks_list)
+    for i in range(0, nb_epoch, 10):
+        model.fit(predictors, label, initial_epoch=i, epochs=min(i + 10, nb_epoch), verbose=2, callbacks=callbacks_list)
+        print(generate_text(model, tokenizer, "", nb_words, max_sequence_len))
 
     # model.save(model_file)
     # else:  # FIXME: Load and predict, maybe reuse checkpoints?
     #     model = load_model(model_file)
 
-    for i, seed in enumerate(load_seeds(lines, 3)):
+    for i, seed in enumerate(load_seeds(lines, 5)):
         output = generate_text(model, tokenizer, seed, nb_words, max_sequence_len)
         print("%i %s -> %s" % (i, seed, output))
 
-- 
libgit2 0.27.0
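
The heart of this change is pairing ModelCheckpoint(period=10) with a fit()
loop that advances initial_epoch in steps of 10, sampling text between
chunks. Below is a minimal, self-contained sketch of that pattern on a toy
model: the Dense network, random data, and filenames are stand-ins, not the
real lstm.py code, and it assumes the Keras 2.x API of the time, where
ModelCheckpoint still accepted `period` (later tf.keras renamed it
`save_freq`).

    import numpy as np
    from keras import Sequential
    from keras.callbacks import ModelCheckpoint
    from keras.layers import Dense

    # Toy stand-ins for the predictors/label that lstm.py builds from its corpus.
    predictors = np.random.rand(100, 8)
    label = np.random.rand(100, 1)

    model = Sequential([Dense(16, activation='relu', input_shape=(8,)),
                        Dense(1)])
    model.compile(optimizer='adam', loss='mse')

    # period=10: the callback only considers saving every 10th epoch, and
    # save_best_only then writes the file only if the monitored metric improved.
    checkpoint = ModelCheckpoint("toy-{epoch:02d}.hdf5", monitor='loss',
                                 period=10, save_best_only=True)

    nb_epoch = 25
    for i in range(0, nb_epoch, 10):
        # initial_epoch/epochs slice one long run into 10-epoch chunks, leaving
        # room between chunks for side effects (sampling text, extra logging).
        model.fit(predictors, label, initial_epoch=i,
                  epochs=min(i + 10, nb_epoch), verbose=2,
                  callbacks=[checkpoint])
        # The patched code calls generate_text(model, tokenizer, "", ...) here.
        print("finished epoch %i of %i" % (min(i + 10, nb_epoch), nb_epoch))

Chunking via initial_epoch keeps the epoch counter (and thus the
{epoch:02d} checkpoint filenames) continuous across fit() calls, which a
naive loop of independent fit(epochs=10) runs would reset each time.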