From 8210fef0915bdade46373c9e6911b9a1711c07f2 Mon Sep 17 00:00:00 2001
From: Paul-Louis NECH
Date: Wed, 27 Nov 2019 10:14:21 +0100
Subject: [PATCH] feat(tokens): Lowercase made optional

---
 KoozDawa/dawa.py       | 6 ++++--
 LeBoulbiNet/boulbi.py  | 8 +++++---
 glossolalia/cleaner.py | 1 +
 glossolalia/lstm.py    | 1 +
 glossolalia/tokens.py  | 8 ++++----
 5 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/KoozDawa/dawa.py b/KoozDawa/dawa.py
index 76bbef9..9c9e23a 100644
--- a/KoozDawa/dawa.py
+++ b/KoozDawa/dawa.py
@@ -8,13 +8,13 @@ from glossolalia.lstm import LisSansTaMaman
 
 
 def train():
     # should_train = True
-    # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
     nb_words = 20
     nb_epoch = 100
     nb_layers = 100
     dropout = .3  # TODO finetune layers/dropout
     validation_split = 0.2
     lstm = LisSansTaMaman(nb_layers, dropout, validation_split, debug=True)
+    filename_model = "../models/dawa/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
     filename_output = "./output/dawa_%i-d%.1f_%s.txt" % (
         nb_layers, dropout, datetime.now().strftime("%y%m%d_%H%M"))
@@ -31,7 +31,9 @@ def train():
                  callbacks=callbacks_list,
                  validation_split=validation_split)
 
-        print(lstm.predict_seeds(nb_words))
+        for output in lstm.predict_seeds(nb_words):
+            print(output)
+            f.writelines(output)
 
         for i, seed in enumerate(load_seeds(corpus, 5)):
             output = lstm.predict(seed, nb_words)
diff --git a/LeBoulbiNet/boulbi.py b/LeBoulbiNet/boulbi.py
index d9e4e2a..7468558 100644
--- a/LeBoulbiNet/boulbi.py
+++ b/LeBoulbiNet/boulbi.py
@@ -11,10 +11,10 @@ def train():
 
     nb_words = 20
     nb_epoch = 50
     nb_layers = 64
-    dropout = .2
-    # TODO finetune layers/dropout
+    dropout = .2  # TODO finetune layers/dropout
     validation_split = 0.2
     lstm = LisSansTaMaman(nb_layers, dropout, validation_split, debug=True)
+# filename_model = "../models/boulbi/boulbi_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (
         nb_layers, dropout, nb_epoch)
     filename_output = "./output/boulbi_%i-d%.1f_%s.txt" % (
@@ -32,7 +32,9 @@ def train():
                  callbacks=callbacks_list,
                  validation_split=validation_split)
 
-        print(lstm.predict_seeds(nb_words))
+        for output in lstm.predict_seeds(nb_words):
+            print(output)
+            f.writelines(output)
 
         for i, seed in enumerate(load_seeds(corpus, 5)):
             output = lstm.predict(seed, nb_words)
diff --git a/glossolalia/cleaner.py b/glossolalia/cleaner.py
index 4abb5cc..38686fe 100644
--- a/glossolalia/cleaner.py
+++ b/glossolalia/cleaner.py
@@ -2,6 +2,7 @@ from glossolalia import loader
 
 
 def clean(text):
+    # TODO: Remove lines with ???
     # Replace literal newlines
     # Remove empty lines
     # Replace ’ by '
diff --git a/glossolalia/lstm.py b/glossolalia/lstm.py
index a695627..a807689 100644
--- a/glossolalia/lstm.py
+++ b/glossolalia/lstm.py
@@ -44,6 +44,7 @@ class LisSansTaMaman(object):
 
         model.summary()
         self.model = model
+        print("Max sequence length:", self.max_sequence_len)
 
     # TODO: Batch fit? splitting nb_epoch into N step
     def fit(self, epochs: int, initial_epoch: int = 0,
diff --git a/glossolalia/tokens.py b/glossolalia/tokens.py
index 4e10861..916059f 100644
--- a/glossolalia/tokens.py
+++ b/glossolalia/tokens.py
@@ -4,10 +4,10 @@ from glossolalia.loader import load_texts
 
 
 class PoemTokenizer(Tokenizer):
-    def __init__(self, **kwargs) -> None:
-        super().__init__(lower=True,  # TODO: Better generalization without?
-                         filters='$%&()*+/<=>@[\\]^_`{|}~\t\n', oov_token="😢",
-                         **kwargs)
+    def __init__(self, lower: bool = True, **kwargs) -> None:
+        super().__init__(lower=lower,  # TODO: Better generalization without?
+                         filters='$%&*+/<=>@[\\]^_`{|}~\t\n', oov_token="😢",
+                         **kwargs)  # TODO: keep newlines
 
     def get_sequence_of_tokens(self, corpus):
         self.fit_on_texts(corpus)
-- 
libgit2 0.27.0
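Note for reviewers: a minimal usage sketch of the new optional flag. This is not part of the patch; the corpus lines below are made-up placeholders, and PoemTokenizer is assumed importable from glossolalia.tokens as in this repo.

    from glossolalia.tokens import PoemTokenizer

    # Default behaviour is unchanged: lower=True, text is lowercased.
    lowercased = PoemTokenizer()

    # New in this patch: pass lower=False to keep the original casing.
    cased = PoemTokenizer(lower=False)

    corpus = ["Wesh alors", "WESH ALORS"]  # placeholder lines, not from the repo
    cased.fit_on_texts(corpus)  # inherited from keras' Tokenizer
    print(cased.word_index)    # with lower=False, "Wesh" and "WESH" stay distinct

With the default lower=True those two lines would collapse onto the same tokens; exposing the flag lets callers experiment with the "Better generalization without?" TODO without editing the class.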