diff --git a/KoozDawa/dawa/lstm.py b/KoozDawa/dawa/lstm.py index ed2c0f7..73e8863 100644 --- a/KoozDawa/dawa/lstm.py +++ b/KoozDawa/dawa/lstm.py @@ -41,10 +41,12 @@ def create_model(max_sequence_len, total_words, layers=100, dropout=0.1): # TOD model.compile(loss='categorical_crossentropy', optimizer='adam') + # TODO: Try alternative architectures + # https://medium.com/coinmonks/word-level-lstm-text-generator-creating-automatic-song-lyrics-with-neural-networks-b8a1617104fb#35f4 return model -def generate_text(model, tokenizer, seed_text="", nb_words=5, max_sequence_len=0): +def generate_text(model: Sequential, tokenizer: Tokenizer, seed_text="", nb_words=5, max_sequence_len=0) -> str: for _ in range(nb_words): token_list = tokenizer.texts_to_sequences([seed_text])[0] token_list = pad_sequences([token_list], maxlen=max_sequence_len - 1, padding='pre')