diff --git a/KoozDawa/dawa/lstm.py b/KoozDawa/dawa/lstm.py
index 73e8863..c553801 100644
--- a/KoozDawa/dawa/lstm.py
+++ b/KoozDawa/dawa/lstm.py
@@ -2,13 +2,14 @@ import warnings
 
 import numpy as np
 from keras import Sequential
+from keras.callbacks import ModelCheckpoint, EarlyStopping
 from keras.layers import Embedding, LSTM, Dropout, Dense
 from keras.utils import to_categorical
 from keras_preprocessing.sequence import pad_sequences
 from keras_preprocessing.text import Tokenizer
 
 from KoozDawa.dawa.loader import load_kawa, clean_text, load_seeds
-from KoozDawa.dawa.tokens import get_sequence_of_tokens
+from KoozDawa.dawa.tokens import PoemTokenizer
 
 warnings.filterwarnings("ignore")
 warnings.simplefilter(action='ignore', category=FutureWarning)
@@ -25,7 +26,7 @@ def generate_padded_sequences(input_sequences, total_words):
     return predictors, label, max_sequence_len
 
 
-def create_model(max_sequence_len, total_words, layers=100, dropout=0.1):  # TODO finetune
+def create_model(max_sequence_len, total_words, layers=128, dropout=0.2):  # TODO finetune layers/dropout
     input_len = max_sequence_len - 1
     model = Sequential()
 
@@ -39,7 +40,9 @@ def create_model(max_sequence_len, total_words, layers=100, dropout=0.1):  # TOD
     # Add Output Layer
     model.add(Dense(total_words, activation='softmax'))
 
-    model.compile(loss='categorical_crossentropy', optimizer='adam')
+    model.compile(optimizer='adam',  # TODO: Try RMSprop(learning_rate=0.01)
+                  loss='categorical_crossentropy',  # TODO: Try sparse_categorical_crossentropy for faster training
+                  metrics=['accuracy'])
 
     # TODO: Try alternative architectures
     #  https://medium.com/coinmonks/word-level-lstm-text-generator-creating-automatic-song-lyrics-with-neural-networks-b8a1617104fb#35f4
@@ -47,53 +50,61 @@ def create_model(max_sequence_len, total_words, layers=100, dropout=0.1):  # TOD
 
 
 def generate_text(model: Sequential, tokenizer: Tokenizer, seed_text="", nb_words=5, max_sequence_len=0) -> str:
+    word_indices = {v: k for k, v in tokenizer.word_index.items()}
+    output = seed_text
+
     for _ in range(nb_words):
-        token_list = tokenizer.texts_to_sequences([seed_text])[0]
+        token_list = tokenizer.texts_to_sequences([output])[0]
         token_list = pad_sequences([token_list], maxlen=max_sequence_len - 1, padding='pre')
-        predicted = model.predict_classes(token_list, verbose=2)
-
-        output_word = ""
-        for word, index in tokenizer.word_index.items():
-            if index == predicted:
-                output_word = word
-                break
-        seed_text += " " + output_word
-    return seed_text.capitalize()
+        predicted = model.predict_classes(token_list, verbose=2)[0]
+        output += " " + word_indices[predicted]
+    return output.capitalize()
 
 
 def main():
-    should_train = True
+    # should_train = True
     # model_file = "../models/dawa_lstm_%i.hd5" % nb_epoch
+    nb_words = 20
     nb_epoch = 100
-    nb_words = 200
-    tokenizer = Tokenizer()
+    nb_layers = 128
+    dropout = .2
+    tokenizer = PoemTokenizer()
 
     # if should_train:
     lines = load_kawa()
 
     corpus = [clean_text(x) for x in lines]
-    print("Corpus:", corpus[:5])
+    print("Corpus:", corpus[:10])
 
-    inp_sequences, total_words = get_sequence_of_tokens(corpus, tokenizer)
+    inp_sequences, total_words = tokenizer.get_sequence_of_tokens(corpus)
     predictors, label, max_sequence_len = generate_padded_sequences(inp_sequences, total_words)
-    model = create_model(max_sequence_len, total_words)
+    model = create_model(max_sequence_len, total_words, layers=nb_layers, dropout=dropout)
     model.summary()
 
-    model.fit(predictors, label, epochs=nb_epoch, verbose=5)
+    file_path = "../models/dawa_lstm%i-d%.1f-{epoch:02d}_%i-{accuracy:.4f}.hdf5" % (nb_layers, dropout, nb_epoch)
+    checkpoint = ModelCheckpoint(file_path, monitor='accuracy', save_best_only=True)
+    # print_callback = LambdaCallback(on_epoch_end=on_epoch_end)
+    early_stopping = EarlyStopping(monitor='accuracy', patience=5)
+    callbacks_list = [checkpoint, early_stopping]
+
+    for i in range(nb_epoch):
+        model.fit(predictors, label, initial_epoch=i, epochs=i + 1, verbose=2, callbacks=callbacks_list)
+
     # model.save(model_file)
-    # else: # FIXME: Load and predict
+    # else: # FIXME: Load and predict, maybe reuse checkpoints?
     # model = load_model(model_file)
 
-    for sample in load_seeds(lines):
-        print(generate_text(model, tokenizer, sample, nb_words, max_sequence_len))
+    for i, seed in enumerate(load_seeds(lines, 3)):
+        output = generate_text(model, tokenizer, seed, nb_words, max_sequence_len)
+        print("%i %s -> %s" % (i, seed, output))
 
-    with open("./output/lstm.txt", "a+") as f:
+    with open("./output/dawa.txt", "a+") as f:
         while True:
             input_text = input("> ")
             text = generate_text(model, tokenizer, input_text, nb_words, max_sequence_len)
 
             print(text)
-            f.writelines(text)
+            f.writelines("%s\n" % text)
 
 
 if __name__ == '__main__':