diff --git a/KoozDawa/dawa.py b/KoozDawa/dawa.py
new file mode 100644
index 0000000..dca40a6
--- /dev/null
+++ b/KoozDawa/dawa.py
@@ -0,0 +1,66 @@
+from keras.callbacks import ModelCheckpoint, EarlyStopping
+
+from glossolalia.loader import load_kawa, clean_text, load_seeds
+from glossolalia.lstm import generate_padded_sequences, create_model, generate_text
+from glossolalia.tokens import PoemTokenizer
+
+
+def main():
+    # should_train = True
+    # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
+    nb_words = 20
+    nb_epoch = 100
+    nb_layers = 128
+    dropout = .2
+    tokenizer = PoemTokenizer()
+
+    # if should_train:
+    lines = load_kawa()
+
+    corpus = [clean_text(x) for x in lines]
+    print("Corpus:", corpus[:10])
+
+    inp_sequences, total_words = tokenizer.get_sequence_of_tokens(corpus)
+    predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
+    model = create_model(max_sequence_len, total_words, layers=nb_layers, dropout=dropout)
+    model.summary()
+
+    file_path = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
+    checkpoint = ModelCheckpoint(file_path, monitor='accuracy', period=10, save_best_only=True)
+    # print_callback = LambdaCallback(on_epoch_end=on_epoch_end)
+    early_stopping = EarlyStopping(monitor='accuracy', patience=5)
+    callbacks_list = [checkpoint, early_stopping]
+
+    for i in range(0, nb_epoch, 10):
+        model.fit(predictors, label, initial_epoch=i, epochs=min(i + 10, nb_epoch), verbose=2, callbacks=callbacks_list)
+        print(generate_text(model, tokenizer, "", nb_words, max_sequence_len))
+
+    # model.save(model_file)
+    # else:  # FIXME: Load and predict, maybe reuse checkpoints?
+    #     model = load_model(model_file)
+
+    for i, seed in enumerate(load_seeds(lines, 5)):
+        output = generate_text(model, tokenizer, seed, nb_words, max_sequence_len)
+        print("%i %s -> %s" % (i, seed, output))
+
+    with open("./output/dawa.txt", "a+") as f:
+        while True:
+            input_text = input("> ")
+            text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)
+
+            print(text)
+            f.writelines("%s\n" % text)
+
+
+def debug_unrandomize():
+    from numpy.random import seed
+    from tensorflow_core.python.framework.random_seed import set_random_seed
+
+    # set seeds for reproducibility
+    set_random_seed(2)
+    seed(1)
+
+
+if __name__ == '__main__':
+    debug_unrandomize()
+    main()
diff --git a/KoozDawa/dawa/lyrics.py b/KoozDawa/lyrics.py
similarity index 95%
rename from KoozDawa/dawa/lyrics.py
rename to KoozDawa/lyrics.py
index 27426d3..1b49ff1 100644
--- a/KoozDawa/dawa/lyrics.py
+++ b/KoozDawa/lyrics.py
@@ -4,6 +4,7 @@ import lyricsgenius
 def fetch():
     genius = lyricsgenius.Genius("zUSpjfQ9ELXDqOjx9hGfAlJGYQFrNvHh3rlDV298_QSr5ScKf3qlHZtOO2KsXspQ")
     response = genius.search_artist("Dooz-kawa")
+    print(response)
 
     for hit in response["hits"]:
         print(hit)
diff --git a/KoozDawa/tweeper.py b/KoozDawa/tweeper.py
index 657b50e..2629baf 100755
--- a/KoozDawa/tweeper.py
+++ b/KoozDawa/tweeper.py
@@ -25,7 +25,7 @@ class Tweeper(object):
 
 
 def main():
-    Tweeper().tweet("un pont de paris sen souvient sur de toi")
+    Tweeper().tweet("les anges se sont fichés")
 
 
 # Authenticate to Twitter
diff --git a/KoozDawa/dawa/loader.py b/glossolalia/loader.py
similarity index 100%
rename from KoozDawa/dawa/loader.py
rename to glossolalia/loader.py
diff --git a/KoozDawa/dawa/lstm.py b/glossolalia/lstm.py
similarity index 68%
rename from KoozDawa/dawa/lstm.py
rename to glossolalia/lstm.py
index 55bd633..973f017 100644
--- a/KoozDawa/dawa/lstm.py
+++ b/glossolalia/lstm.py
@@ -2,15 +2,11 @@ import warnings
 
 import numpy as np
 from keras import Sequential
-from keras.callbacks import ModelCheckpoint, EarlyStopping
-from keras.layers import Embedding, LSTM, Dropout, Dense, Bidirectional
+from keras.layers import Embedding, LSTM, Dropout, Dense
 from keras.utils import to_categorical
 from keras_preprocessing.sequence import pad_sequences
 from keras_preprocessing.text import Tokenizer
 
-from KoozDawa.dawa.loader import load_kawa, clean_text, load_seeds
-from KoozDawa.dawa.tokens import PoemTokenizer
-
 warnings.filterwarnings("ignore")
 warnings.simplefilter(action='ignore', category=FutureWarning)
 
@@ -59,55 +55,4 @@ def generate_text(model: Sequential, tokenizer: Tokenizer, seed_text="", nb_word
         token_list = pad_sequences([token_list], maxlen=max_sequence_len - 1, padding='pre')
         predicted = model.predict_classes(token_list, verbose=2)[0]
         output += " " + word_indices[predicted]
-    return output.capitalize()
-
-
-def main():
-    # should_train = True
-    # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
-    nb_words = 20
-    nb_epoch = 100
-    nb_layers = 128
-    dropout = .2
-    tokenizer = PoemTokenizer()
-
-    # if should_train:
-    lines = load_kawa()
-
-    corpus = [clean_text(x) for x in lines]
-    print("Corpus:", corpus[:10])
-
-    inp_sequences, total_words = tokenizer.get_sequence_of_tokens(corpus)
-    predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
-    model = create_model(max_sequence_len, total_words, layers=nb_layers, dropout=dropout)
-    model.summary()
-
-    file_path = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
-    checkpoint = ModelCheckpoint(file_path, monitor='accuracy', period=10, save_best_only=True)
-    # print_callback = LambdaCallback(on_epoch_end=on_epoch_end)
-    early_stopping = EarlyStopping(monitor='accuracy', patience=5)
-    callbacks_list = [checkpoint, early_stopping]
-
-    for i in range(0, nb_epoch, 10):
-        model.fit(predictors, label, initial_epoch=i, epochs=min(i + 10, nb_epoch), verbose=2, callbacks=callbacks_list)
-        print(generate_text(model, tokenizer, "", nb_words, max_sequence_len))
-
-    # model.save(model_file)
-    # else:  # FIXME: Load and predict, maybe reuse checkpoints?
-    #     model = load_model(model_file)
-
-    for i, seed in enumerate(load_seeds(lines, 5)):
-        output = generate_text(model, tokenizer, seed, nb_words, max_sequence_len)
-        print("%i %s -> %s" % (i, seed, output))
-
-    with open("./output/dawa.txt", "a+") as f:
-        while True:
-            input_text = input("> ")
-            text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)
-
-            print(text)
-            f.writelines("%s\n" % text)
-
-
-if __name__ == '__main__':
-    main()
+    return output.capitalize()
\ No newline at end of file
diff --git a/KoozDawa/dawa/tokens.py b/glossolalia/tokens.py
similarity index 96%
rename from KoozDawa/dawa/tokens.py
rename to glossolalia/tokens.py
index 1e05230..88d7c51 100644
--- a/KoozDawa/dawa/tokens.py
+++ b/glossolalia/tokens.py
@@ -1,6 +1,6 @@
 from keras_preprocessing.text import Tokenizer
 
-from KoozDawa.dawa.loader import load_kawa
+from glossolalia.loader import load_kawa
 
 
 class PoemTokenizer(Tokenizer):