From 380e23b1c9af3f857c3a65ed343459b2393e66d0 Mon Sep 17 00:00:00 2001 From: Paul-Louis NECH Date: Tue, 26 Nov 2019 16:50:40 +0100 Subject: [PATCH] refactor(dawa): Move unrandomize to lstm --- KoozDawa/dawa.py | 10 ++-------- KoozDawa/tweet.py | 6 ++++-- glossolalia/lstm.py | 23 ++++++++++++++--------- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/KoozDawa/dawa.py b/KoozDawa/dawa.py index 8cb0142..11f1222 100644 --- a/KoozDawa/dawa.py +++ b/KoozDawa/dawa.py @@ -15,8 +15,7 @@ def train(): dropout = .3 # TODO finetune layers/dropout validation_split = 0.2 lstm = LisSansTaMaman(nb_layers, dropout, validation_split, debug=True) - filename_model = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % ( - nb_layers, dropout, nb_epoch) + filename_model = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch) filename_output = "./output/dawa_%i-d%.1f_%s.txt" % ( nb_layers, dropout, datetime.now().strftime("%y%m%d_%H%M")) callbacks_list = [ @@ -32,12 +31,7 @@ def train(): callbacks=callbacks_list, validation_split=validation_split) - for seed in ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]: - print(lstm.predict(seed, nb_words)) - - # model.save(model_file) - # else: # FIXME: Load and predict, maybe reuse checkpoints? 
- # model = load_model(model_file) + print(lstm.predict_seeds(nb_words=nb_words)) for i, seed in enumerate(load_seeds(corpus, 5)): output = lstm.predict(seed, nb_words) diff --git a/KoozDawa/tweet.py b/KoozDawa/tweet.py index 04260c8..0799730 100644 --- a/KoozDawa/tweet.py +++ b/KoozDawa/tweet.py @@ -8,14 +8,16 @@ def tweet(): # le soleil est triste # on a pas un martyr parce qu't'es la # des neiges d'insuline - # une hypothèse qu'engendre la haine n'est qu'une prison vide # Un jour de l'an commencé sur les autres - # Relater l'passionnel dans les casseroles d'eau de marécages + # une hypothèse qu'engendre la haine n'est qu'une prison vide # sniff de Caravage rapide + # Relater l'passionnel dans les casseroles d'eau de marécages # La nuit c'est le soleil + # Les rues d'ma vie se terminent par la cannelle # Les rues d'ma vie se terminent par des partouzes de ciel # des glaçons pour les yeux brisées + # je suis pas juste un verbe que t'observe Tweeper("KoozDawa").tweet("tassepés en panel") diff --git a/glossolalia/lstm.py b/glossolalia/lstm.py index e8c00b2..51dec94 100644 --- a/glossolalia/lstm.py +++ b/glossolalia/lstm.py @@ -15,15 +15,6 @@ warnings.filterwarnings("ignore") warnings.simplefilter(action='ignore', category=FutureWarning) -def debug_unrandomize(): - from numpy.random import seed - from tensorflow_core.python.framework.random_seed import set_random_seed - - # set seeds for reproducibility - set_random_seed(2) - seed(1) - - class LisSansTaMaman(object): """ A LSTM model adapted for french lyrical texts.""" @@ -65,6 +56,11 @@ class LisSansTaMaman(object): validation_split=validation_split, epochs=epochs, initial_epoch=initial_epoch) + def predict_seeds(self, seeds: List[str] = None, nb_words=None): + if seeds is None: + seeds = ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"] + return [self.predict(seed, nb_words) for seed in seeds] + def predict(self, seed="", nb_words=None): if nb_words is None: nb_words = 20 # TODO: Guess based on model a good number of
words @@ -115,3 +111,12 @@ def generate_text(model: Sequential, tokenizer: Tokenizer, seed_text="", nb_word predicted = model.predict_classes(token_list, verbose=2)[0] output += " " + word_indices[predicted] return output.capitalize() + + +def debug_unrandomize(): + from numpy.random import seed + from tensorflow_core.python.framework.random_seed import set_random_seed + + # set seeds for reproducibility + set_random_seed(2) + seed(1) -- libgit2 0.27.0