diff --git a/KoozDawa/dawa/loader.py b/KoozDawa/dawa/loader.py index a85ecfc..9ccf9df 100644 --- a/KoozDawa/dawa/loader.py +++ b/KoozDawa/dawa/loader.py @@ -28,8 +28,8 @@ def clean_text(lines): In dataset preparation step, we will first perform text cleaning of the data which includes removal of punctuations and lower casing all the words. """ - lines = "".join(v for v in lines if v not in string.punctuation).lower() - lines = lines.encode("utf8").decode("ascii", 'ignore') + lines = "".join(v for v in lines if v not in string.punctuation) + # lines = lines.encode("utf8").decode("ascii", 'ignore') return lines diff --git a/KoozDawa/dawa/lstm.py b/KoozDawa/dawa/lstm.py index 02e558f..478fe19 100644 --- a/KoozDawa/dawa/lstm.py +++ b/KoozDawa/dawa/lstm.py @@ -2,7 +2,6 @@ import warnings import numpy as np from keras import Sequential -from keras.engine.saving import load_model from keras.layers import Embedding, LSTM, Dropout, Dense from keras.utils import to_categorical from keras_preprocessing.sequence import pad_sequences @@ -57,14 +56,14 @@ def generate_text(model, tokenizer, seed_text="", nb_words=5, max_sequence_len=0 output_word = word break seed_text += " " + output_word - return seed_text.title() + return seed_text.capitalize() def main(): should_train = True nb_epoch = 100 max_sequence_len = 61 # TODO: Test different default - model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch + # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch tokenizer = Tokenizer() if should_train: @@ -79,21 +78,24 @@ def main(): model.summary() model.fit(predictors, label, epochs=nb_epoch, verbose=5) - model.save(model_file) - else: - model = load_model(model_file) + # model.save(model_file) + # else: # FIXME: Load and predict + # model = load_model(model_file) for sample in ["", "L'étoile du sol", "Elle me l'a toujours dit", "Les punchlines sont pour ceux"]: - nb_words = 50 + nb_words = 200 print(generate_text(model, tokenizer, sample, nb_words, max_sequence_len)) - while True: - input_text = input("> ") - print(generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)) - print(generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)) + with open("../output/lstm.txt", "a") as f: + while True: + input_text = input("> ") + text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len) + + print(text) + f.writelines(text) if __name__ == '__main__':