From 42b38e3e6206818308d528b693ced1f570254ce6 Mon Sep 17 00:00:00 2001 From: Paul-Louis NECH Date: Tue, 26 Nov 2019 16:27:45 +0100 Subject: [PATCH] refactor(dawa): Generalize LSTM/Tweeper --- KoozDawa/dawa.py | 76 +++++++++++++++++++++++++++++++++------------------------------------------- KoozDawa/tweeper.py | 38 -------------------------------------- KoozDawa/tweet.py | 24 ++++++++++++++++++++++++ glossolalia/lstm.py | 67 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---- glossolalia/tweeper.py | 33 +++++++++++++++++++++++++++++++++ 5 files changed, 153 insertions(+), 85 deletions(-) delete mode 100755 KoozDawa/tweeper.py create mode 100644 KoozDawa/tweet.py create mode 100755 glossolalia/tweeper.py diff --git a/KoozDawa/dawa.py b/KoozDawa/dawa.py index 5d09c59..8cb0142 100644 --- a/KoozDawa/dawa.py +++ b/KoozDawa/dawa.py @@ -1,65 +1,55 @@ +from datetime import datetime + from keras.callbacks import ModelCheckpoint, EarlyStopping from glossolalia.loader import load_seeds, load_text -from glossolalia.lstm import generate_padded_sequences, create_model, generate_text -from glossolalia.tokens import PoemTokenizer +from glossolalia.lstm import LisSansTaMaman -def main(): +def train(): # should_train = True # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch nb_words = 20 - nb_epoch = 50 - nb_layers = 64 - dropout = .2 - tokenizer = PoemTokenizer() + nb_epoch = 100 + nb_layers = 100 + dropout = .3 # TODO finetune layers/dropout + validation_split = 0.2 + lstm = LisSansTaMaman(nb_layers, dropout, validation_split, debug=True) + filename_model = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % ( + nb_layers, dropout, nb_epoch) + filename_output = "./output/dawa_%i-d%.1f_%s.txt" % ( + nb_layers, dropout, datetime.now().strftime("%y%m%d_%H%M")) + callbacks_list = [ + ModelCheckpoint(filename_model, monitor='val_accuracy', period=10, save_best_only=True), + EarlyStopping(monitor='val_accuracy', patience=5)] - # if should_train: corpus = load_text() print("Corpus:", corpus[:10]) + lstm.create_model(corpus[:1000]) + with open(filename_output, "a+") as f: + for i in range(0, nb_epoch, 10): + lstm.fit(epochs=min(i + 10, nb_epoch), initial_epoch=i, + callbacks=callbacks_list, + validation_split=validation_split) - inp_sequences, total_words = tokenizer.get_sequence_of_tokens(corpus) - predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words) - model = create_model(max_sequence_len, total_words, layers=nb_layers, dropout=dropout) - model.summary() - - file_path = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch) - checkpoint = ModelCheckpoint(file_path, monitor='accuracy', period=10, save_best_only=True) - # print_callback = LambdaCallback(on_epoch_end=on_epoch_end) - early_stopping = EarlyStopping(monitor='accuracy', patience=5) - callbacks_list = [checkpoint, early_stopping] + for seed in ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]: + print(lstm.predict(seed, nb_words)) - for i in range(0, nb_epoch, 10): - model.fit(predictors, label, initial_epoch=i, epochs=min(i + 10, nb_epoch), verbose=2, callbacks=callbacks_list) - for seed in ["", "Je", "Tu", "Le", "La", "Les", "Un", "On", "Nous"]: - print(generate_text(model, tokenizer, seed, nb_words, max_sequence_len)) + # model.save(model_file) + # else: # FIXME: Load and predict, maybe reuse checkpoints? + # model = load_model(model_file) - # model.save(model_file) - # else: # FIXME: Load and predict, maybe reuse checkpoints? - # model = load_model(model_file) + for i, seed in enumerate(load_seeds(corpus, 5)): + output = lstm.predict(seed, nb_words) + print("%i %s -> %s" % (i, seed, output)) + f.writelines(output) - for i, seed in enumerate(load_seeds(corpus, 5)): - output = generate_text(model, tokenizer, seed, nb_words, max_sequence_len) - print("%i %s -> %s" % (i, seed, output)) - - with open("./output/dawa.txt", "a+") as f: while True: input_text = input("> ") - text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len) - + text = lstm.predict(input_text, nb_words) print(text) f.writelines("%s\n" % text) -def debug_unrandomize(): - from numpy.random import seed - from tensorflow_core.python.framework.random_seed import set_random_seed - - # set seeds for reproducibility - set_random_seed(2) - seed(1) - - if __name__ == '__main__': - debug_unrandomize() - main() + train() diff --git a/KoozDawa/tweeper.py b/KoozDawa/tweeper.py deleted file mode 100755 index 9deff21..0000000 --- a/KoozDawa/tweeper.py +++ /dev/null @@ -1,38 +0,0 @@ -#! /usr/bin/env python -import os -import time -import tweepy -from didyoumean3.didyoumean import did_you_mean - - -class Tweeper(object): - - def __init__(self): - auth = tweepy.OAuthHandler( - os.environ["ZOO_DAWA_KEY"], - os.environ["ZOO_DAWA_KEY_SECRET"]) - auth.set_access_token( - os.environ["ZOO_DAWA_TOKEN"], - os.environ["ZOO_DAWA_TOKEN_SECRET"]) - self.api = tweepy.API(auth) - - def tweet(self, message): - """Tweets a message after spellchecking it.""" - message = did_you_mean(message) - print("About to tweet:", message) - time.sleep(5) - self.api.update_status(message) - - -def main(): - Tweeper().tweet("le business réel de la saint-valentin") -# Nous la nuit de la renaissance j’étais la tête - -# Authenticate to Twitter -# tassepés en panel -# grands brûlés de la chine -# La nuit est belle, ma chérie salue sur la capuche -# Je suis pas étonné de dire pétrin -# Femme qui crame strasbourg -if __name__ == '__main__': - main() diff --git a/KoozDawa/tweet.py b/KoozDawa/tweet.py new file mode 100644 index 0000000..04260c8 --- /dev/null +++ b/KoozDawa/tweet.py @@ -0,0 +1,24 @@ +from glossolalia.tweeper import Tweeper + + +def tweet(): + # La nuit est belle, ma chérie salue sur la capuche + # grands brûlés de la chine + # Femme qui crame strasbourg + # le soleil est triste + # on a pas un martyr parce qu't'es la + # des neiges d'insuline + # une hypothèse qu'engendre la haine n'est qu'une prison vide + # Un jour de l'an commencé sur les autres + # Relater l'passionnel dans les casseroles d'eau de marécages + # sniff de Caravage rapide + # La nuit c'est le soleil + + # Les rues d'ma vie se terminent par des partouzes de ciel + # des glaçons pour les yeux brisées + + Tweeper("KoozDawa").tweet("tassepés en panel") + + +if __name__ == '__main__': + tweet() diff --git a/glossolalia/lstm.py b/glossolalia/lstm.py index 652b6c5..e8c00b2 100644 --- a/glossolalia/lstm.py +++ b/glossolalia/lstm.py @@ -1,18 +1,77 @@ import warnings +from typing import List import numpy as np -from keras import Sequential +from keras import Sequential, Model +from keras.callbacks import Callback, History from keras.layers import Embedding, LSTM, Dropout, Dense from keras.utils import to_categorical from keras_preprocessing.sequence import pad_sequences from keras_preprocessing.text import Tokenizer +from glossolalia.tokens import PoemTokenizer + warnings.filterwarnings("ignore") warnings.simplefilter(action='ignore', category=FutureWarning) -# 3.3 Padding the Sequences and obtain Variables : Predictors and Target -def generate_padded_sequences(input_sequences, total_words): +def debug_unrandomize(): + from numpy.random import seed + from tensorflow_core.python.framework.random_seed import set_random_seed + + # set seeds for reproducibility + set_random_seed(2) + seed(1) + + +class LisSansTaMaman(object): + """ A LSTM model adapted for french lyrical texts.""" + + def __init__(self, nb_layers: int = 100, + dropout: float = 0.1, validation_split: float = 0.0, + tokenizer=PoemTokenizer(), + debug: bool = False): + self.validation_split = validation_split + self.dropout = dropout + self.nb_layers = nb_layers + self.tokenizer = tokenizer + + # Model state + self.model: Model = None + self.predictors = None + self.label = None + self.max_sequence_len = None + + if debug: + debug_unrandomize() + + def create_model(self, corpus: List[str]): + inp_sequences, total_words = self.tokenizer.get_sequence_of_tokens(corpus) + + self.predictors, self.label, self.max_sequence_len = generate_padded_sequences(inp_sequences, total_words) + model = create_model(self.max_sequence_len, total_words, layers=self.nb_layers, dropout=self.dropout) + model.summary() + + self.model = model + + # TODO: Batch fit? splitting nb_epoch into N step + def fit(self, epochs: int, initial_epoch: int = 0, + callbacks: List[Callback] = None, + validation_split: float = 0 + ) -> History: + return self.model.fit(self.predictors, self.label, + verbose=2, + callbacks=callbacks, + validation_split=validation_split, + epochs=epochs, initial_epoch=initial_epoch) + + def predict(self, seed="", nb_words=None): + if nb_words is None: + nb_words = 20 # TODO: Guess based on model a good number of words + return generate_text(self.model, self.tokenizer, seed, nb_words, self.max_sequence_len) + + +def generate_padded_sequences(input_sequences, total_words: int): max_sequence_len = max([len(x) for x in input_sequences]) input_sequences = np.array(pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre')) predictors, label = input_sequences[:, :-1], input_sequences[:, -1] @@ -20,7 +79,7 @@ def generate_padded_sequences(input_sequences, total_words): return predictors, label, max_sequence_len -def create_model(max_sequence_len, total_words, layers=128, dropout=0.3): # TODO finetune layers/dropout +def create_model(max_sequence_len: int, total_words: int, layers: int, dropout: float): print("Creating model across %i words for %i-long seqs (%i layers, %.2f dropout):" % (total_words, max_sequence_len, layers, dropout)) input_len = max_sequence_len - 1 diff --git a/glossolalia/tweeper.py b/glossolalia/tweeper.py new file mode 100755 index 0000000..c6c9014 --- /dev/null +++ b/glossolalia/tweeper.py @@ -0,0 +1,33 @@ +#! /usr/bin/env python +import os +import time + +import tweepy +from didyoumean3.didyoumean import did_you_mean +from tweepy import Cursor + + +class Tweeper(object): + def __init__(self, name: str): + auth = tweepy.OAuthHandler( + os.environ["ZOO_DAWA_KEY"], + os.environ["ZOO_DAWA_KEY_SECRET"]) + auth.set_access_token( + os.environ["ZOO_DAWA_TOKEN"], + os.environ["ZOO_DAWA_TOKEN_SECRET"]) + self.api = tweepy.API(auth) + self.name = name + + @property + def all_tweets(self): + return [t.text for t in Cursor(self.api.user_timeline, id=self.name).items()] + + def tweet(self, message, wait_delay=5, prevent_duplicate=True): + """Tweets a message after spellchecking it.""" + if prevent_duplicate and message in self.all_tweets: + print("Was already tweeted: %s." % message) + else: + message = did_you_mean(message) + print("About to tweet:", message) + time.sleep(wait_delay) + self.api.update_status(message) \ No newline at end of file -- libgit2 0.27.0