diff --git a/.gitignore b/.gitignore
index 92a65a9..c26a768 100644
--- a/.gitignore
+++ b/.gitignore
@@ -130,3 +130,6 @@ dmypy.json
 
 # IDE
 .idea/
+
+# Outputs
+output/
diff --git a/KoozDawa/data/apocalypse b/KoozDawa/data/apocalypse.txt
similarity index 100%
rename from KoozDawa/data/apocalypse
rename to KoozDawa/data/apocalypse.txt
diff --git a/KoozDawa/dawa/loader.py b/KoozDawa/dawa/loader.py
index 9ccf9df..682d16c 100644
--- a/KoozDawa/dawa/loader.py
+++ b/KoozDawa/dawa/loader.py
@@ -1,5 +1,7 @@
 import os
 import string
+from pprint import pprint
+from random import choice, randint
 
 from numpy.random import seed
 from tensorflow_core.python.framework.random_seed import set_random_seed
@@ -11,7 +13,9 @@ def load_kawa(root="./"):
     seed(1)
     data_dir = root + 'data/'
     all_lines = []
-    for filename in os.listdir(data_dir):
+    files = os.listdir(data_dir)
+    print("%i files in data folder." % len(files))
+    for filename in files:
         with open(data_dir + filename) as f:
             content = f.readlines()
         all_lines.extend(content)
@@ -23,6 +27,19 @@ def load_kawa(root="./"):
     return all_lines
 
 
+def load_seeds(kawa=None, nb_seeds=10):
+    if kawa is None:
+        kawa = load_kawa()
+    seeds = []
+    for i in range(nb_seeds):
+        plain_kawa = filter(lambda k: k != "\n", kawa)
+        chosen = choice(list(plain_kawa))
+        split = chosen.split(" ")
+        nb_words = randint(1, len(split))
+        seeds.append(split[:nb_words])
+    return seeds
+
+
 def clean_text(lines):
     """ In dataset preparation step, we will first perform text cleaning of the data
 
@@ -37,6 +54,8 @@ def main():
     lines = load_kawa("../")
     clean = clean_text(lines)
     print(clean)
+    print("Some seeds:\n\n")
+    pprint(load_seeds(lines))
 
 
 if __name__ == '__main__':
diff --git a/KoozDawa/dawa/lstm.py b/KoozDawa/dawa/lstm.py
index 478fe19..ed2c0f7 100644
--- a/KoozDawa/dawa/lstm.py
+++ b/KoozDawa/dawa/lstm.py
@@ -7,7 +7,7 @@ from keras.utils import to_categorical
 from keras_preprocessing.sequence import pad_sequences
 from keras_preprocessing.text import Tokenizer
 
-from KoozDawa.dawa.loader import load_kawa, clean_text
+from KoozDawa.dawa.loader import load_kawa, clean_text, load_seeds
 from KoozDawa.dawa.tokens import get_sequence_of_tokens
 
 warnings.filterwarnings("ignore")
@@ -61,35 +61,31 @@ def generate_text(model, tokenizer, seed_text="", nb_words=5, max_sequence_len=0
 def main():
     should_train = True
-    nb_epoch = 100
-    max_sequence_len = 61  # TODO: Test different default
     # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
+    nb_epoch = 100
+    nb_words = 200
 
     tokenizer = Tokenizer()
 
-    if should_train:
-        lines = load_kawa()
+    # if should_train:
+    lines = load_kawa()
 
-        corpus = [clean_text(x) for x in lines]
-        print("Corpus:", corpus[:2])
+    corpus = [clean_text(x) for x in lines]
+    print("Corpus:", corpus[:5])
 
-        inp_sequences, total_words = get_sequence_of_tokens(corpus, tokenizer)
-        predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
-        model = create_model(max_sequence_len, total_words)
-        model.summary()
+    inp_sequences, total_words = get_sequence_of_tokens(corpus, tokenizer)
+    predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
+    model = create_model(max_sequence_len, total_words)
+    model.summary()
 
-        model.fit(predictors, label, epochs=nb_epoch, verbose=5)
-        # model.save(model_file)
+    model.fit(predictors, label, epochs=nb_epoch, verbose=5)
+    # model.save(model_file)
 
     # else:  # FIXME: Load and predict
-    #     model = load_model(model_file)
+    # model = load_model(model_file)
-    for sample in ["",
-                   "L'étoile du sol",
-                   "Elle me l'a toujours dit",
-                   "Les punchlines sont pour ceux"]:
-        nb_words = 200
+    for sample in load_seeds(lines):
         print(generate_text(model, tokenizer, sample, nb_words, max_sequence_len))
 
-    with open("../output/lstm.txt", "a") as f:
+    with open("./output/lstm.txt", "a+") as f:
         while True:
             input_text = input("> ")
             text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)
diff --git a/KoozDawa/output/selection.txt b/KoozDawa/output/selection.txt
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/KoozDawa/output/selection.txt