diff --git a/KoozDawa/dawa/lstm.py b/KoozDawa/dawa/lstm.py
index 4c3014f..02e558f 100644
--- a/KoozDawa/dawa/lstm.py
+++ b/KoozDawa/dawa/lstm.py
@@ -49,7 +49,7 @@ def generate_text(model, tokenizer, seed_text="", nb_words=5, max_sequence_len=0
     for _ in range(nb_words):
         token_list = tokenizer.texts_to_sequences([seed_text])[0]
         token_list = pad_sequences([token_list], maxlen=max_sequence_len - 1, padding='pre')
-        predicted = model.predict_classes(token_list, verbose=0)
+        predicted = model.predict_classes(token_list, verbose=2)
 
         output_word = ""
         for word, index in tokenizer.word_index.items():
@@ -63,22 +63,18 @@ def generate_text(model, tokenizer, seed_text="", nb_words=5, max_sequence_len=0
 def main():
     should_train = True
     nb_epoch = 100
+    max_sequence_len = 61  # TODO: Test different default
     model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
-    max_sequence_len = 5  # TODO: Test different default
 
     tokenizer = Tokenizer()
 
     if should_train:
         lines = load_kawa()
 
         corpus = [clean_text(x) for x in lines]
-        print(corpus[:10])
+        print("Corpus:", corpus[:2])
 
         inp_sequences, total_words = get_sequence_of_tokens(corpus, tokenizer)
-        print(inp_sequences[:10])
-
         predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
-        print(predictors, label, max_sequence_len)
-
         model = create_model(max_sequence_len, total_words)
         model.summary()
@@ -87,12 +83,17 @@ def main():
     else:
         model = load_model(model_file)
 
-    for sample in ["", "L'étoile ", "Elle ", "Les punchlines "]:
-        print(generate_text(model, tokenizer, sample, 100, max_sequence_len))
+    for sample in ["",
+                   "L'étoile du sol",
+                   "Elle me l'a toujours dit",
+                   "Les punchlines sont pour ceux"]:
+        nb_words = 50
+        print(generate_text(model, tokenizer, sample, nb_words, max_sequence_len))
 
     while True:
         input_text = input("> ")
-        print(generate_text(model, tokenizer, input_text, 100, max_sequence_len))
+        print(generate_text(model, tokenizer, input_text, nb_words, max_sequence_len))
+        print(generate_text(model, tokenizer, input_text, nb_words, max_sequence_len))
 
 
 if __name__ == '__main__':
diff --git a/KoozDawa/dawa/tokens.py b/KoozDawa/dawa/tokens.py
index 481f5fe..3acfb5a 100644
--- a/KoozDawa/dawa/tokens.py
+++ b/KoozDawa/dawa/tokens.py
@@ -11,8 +11,6 @@ def get_sequence_of_tokens(corpus, tokenizer=Tokenizer()):
     # convert data to sequence of tokens
     input_sequences = []
 
-    # FIXME Debug: truncate corpus
-    corpus = corpus[:50]
     for line in corpus:
         token_list = tokenizer.texts_to_sequences([line])[0]
         for i in range(1, len(token_list)):