import warnings

import numpy as np
from keras import Sequential
from keras.layers import Embedding, LSTM, Dropout, Dense
from keras.utils import to_categorical
from keras_preprocessing.sequence import pad_sequences
from keras_preprocessing.text import Tokenizer

warnings.filterwarnings("ignore")
warnings.simplefilter(action='ignore', category=FutureWarning)


# 3.3 Pad the sequences and obtain the predictor and target variables
def generate_padded_sequences(input_sequences, total_words):
    max_sequence_len = max([len(x) for x in input_sequences])
    input_sequences = np.array(pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre'))
    predictors, label = input_sequences[:, :-1], input_sequences[:, -1]
    label = to_categorical(label, num_classes=total_words)
    return predictors, label, max_sequence_len
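

# Illustrative sketch (not part of the original pipeline): generate_padded_sequences
# expects `input_sequences` to be a list of integer n-gram prefixes produced by a
# fitted Tokenizer. The helper below is a hypothetical example of building that input.
def build_ngram_sequences(corpus, tokenizer):
    """Turn each line of `corpus` into its n-gram prefix token sequences."""
    sequences = []
    for line in corpus:
        token_list = tokenizer.texts_to_sequences([line])[0]
        for i in range(1, len(token_list)):
            sequences.append(token_list[:i + 1])
    return sequences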


def create_model(max_sequence_len, total_words, layers=128, dropout=0.3):  # TODO finetune layers/dropout
    # NOTE: `layers` is the number of units in the single LSTM layer, not a count of layers.
    print("Creating model across %i words for %i-long seqs (%i LSTM units, %.2f dropout):" %
          (total_words, max_sequence_len, layers, dropout))
    input_len = max_sequence_len - 1
    model = Sequential()

    # Add Input Embedding Layer
    model.add(Embedding(total_words, 10, input_length=input_len))

    # Add Hidden Layer 1 - LSTM Layer
    model.add(LSTM(layers))
    # model.add(Bidirectional(LSTM(layers), input_shape=(max_sequence_len, total_words)))
    model.add(Dropout(dropout))

    # Add Output Layer
    model.add(Dense(total_words, activation='softmax'))

    model.compile(optimizer='adam',  # TODO: Try RMSprop(learning_rate=0.01)
                  loss='categorical_crossentropy',  # TODO: Try sparse_categorical_crossentropy for faster training
                  metrics=['accuracy'])

    # TODO: Try alternative architectures
    #  https://medium.com/coinmonks/word-level-lstm-text-generator-creating-automatic-song-lyrics-with-neural-networks-b8a1617104fb#35f4
    return model


def generate_text(model: Sequential, tokenizer: Tokenizer, seed_text="", nb_words=5, max_sequence_len=0) -> str:
    """Append `nb_words` predicted words to `seed_text` using the trained model."""
    # Reverse word_index so predicted token indices can be mapped back to words.
    word_indices = {v: k for k, v in tokenizer.word_index.items()}
    output = seed_text

    for _ in range(nb_words):
        token_list = tokenizer.texts_to_sequences([output])[0]
        token_list = pad_sequences([token_list], maxlen=max_sequence_len - 1, padding='pre')
        # predict_classes() was removed from Sequential in newer Keras releases;
        # taking the argmax of predict() is the equivalent, version-safe replacement.
        predicted = np.argmax(model.predict(token_list, verbose=0), axis=-1)[0]
        next_word = word_indices.get(predicted, "")  # index 0 (padding) maps to no word
        if next_word:
            output += " " + next_word
    return output.capitalize()
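

# Hypothetical end-to-end usage sketch, assuming the build_ngram_sequences helper above
# and a tiny toy corpus; the corpus, epoch count, and seed text are illustrative
# assumptions, not tuned settings from the original pipeline.
if __name__ == "__main__":
    corpus = ["the cat sat on the mat", "the dog sat on the log"]  # toy data, assumption
    tokenizer = Tokenizer()
    tokenizer.fit_on_texts(corpus)
    total_words = len(tokenizer.word_index) + 1  # +1 for the padding index 0

    input_sequences = build_ngram_sequences(corpus, tokenizer)
    predictors, label, max_sequence_len = generate_padded_sequences(input_sequences, total_words)

    model = create_model(max_sequence_len, total_words)
    model.fit(predictors, label, epochs=100, verbose=0)

    print(generate_text(model, tokenizer, seed_text="the cat", nb_words=3,
                        max_sequence_len=max_sequence_len))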