import datetime
from argparse import ArgumentParser

import os
import numpy as np
import re
from keras import Model
from keras.applications.densenet import DenseNet201, preprocess_input as preproc_dense_net_201
from keras.applications.inception_resnet_v2 import InceptionResNetV2, preprocess_input as preproc_inception_resnet_v2
from keras.applications.resnet50 import ResNet50, preprocess_input as preproc_resnet50
from keras.callbacks import TerminateOnNaN, ModelCheckpoint
from keras.layers import GlobalAveragePooling2D, Dense, Activation
from keras.optimizers import adam
from keras.preprocessing.image import ImageDataGenerator

from dataset import ImageClassificationDataset
from util import conservative_flow_from_directory


def freeze_up_to(model, layer_name=None, layer_index=None, all_trainable=False):
    """Freeze the first layers of `model` and make the following ones trainable.

    Layers are visited in the order of `model.layers`. Every layer up to and
    including the first one matching `layer_name` or `layer_index` gets
    `trainable = False`; every layer after it gets `trainable = True`. If no
    layer matches, all layers end up frozen.

    model:
        Object exposing a `layers` sequence whose items have `name` and
        `trainable` attributes
    layer_name: str
        Name of the last layer to freeze (None to ignore names)
    layer_index: int
        Index of the last layer to freeze (None to ignore indexes)
    all_trainable: bool
        If true, name and index are ignored and every layer is made trainable

    Raises ValueError when neither a name, an index nor `all_trainable` is given.
    """
    nothing_selected = layer_name is None and layer_index is None
    if nothing_selected and not all_trainable:
        raise ValueError("You should select at least a name or an index.")

    if all_trainable:
        # no boundary to look for: unfreeze the whole network
        for current in model.layers:
            current.trainable = True
        return

    boundary_passed = False
    for position, current in enumerate(model.layers):
        if boundary_passed:
            current.trainable = True
            continue
        # the boundary layer itself stays frozen, only its successors are released
        is_boundary = (layer_name is not None and current.name == layer_name) or \
                      (layer_index is not None and position == layer_index)
        current.trainable = False
        boundary_passed = is_boundary


def get_model(model_name, n_classes, cut=None, freeze=None, avg=True):
    """Return an ImageNet-pretrained network topped with a softmax classifier

    The backbone is built without its original top, optionally frozen up to a
    layer, optionally cut after a layer, optionally followed by a global average
    pooling layer (named "global_avg_pooling"), and finally extended with a dense
    layer (named "fc") of `n_classes` units and a softmax activation.

    model_name: str
        One of "resnet50", "inception_resnet_v2" or "dense_net_201"
    n_classes: int
        Number of units of the final dense layer
    cut: str
        Name of the layer after which to cut (None for taking the whole network)
    freeze: str
        Name of the layer up to which the weights must be frozen
        (None for leaving the trainable flags untouched)
    avg: bool
        Whether to insert the global average pooling before the dense layer

    Returns a tuple (in_shape, preproc, model): the input shape used to build
    the backbone, the matching `preprocess_input` function and the final model.

    Raises ValueError if `model_name`, `freeze` or `cut` is unknown.
    """
    if model_name == "resnet50":
        in_shape = (224, 224, 3)
        model = ResNet50(weights="imagenet", input_shape=in_shape, include_top=False, pooling=None)
        preproc = preproc_resnet50
    elif model_name == "inception_resnet_v2":
        in_shape = (299, 299, 3)
        model = InceptionResNetV2(weights="imagenet", input_shape=in_shape, include_top=False, pooling=None)
        preproc = preproc_inception_resnet_v2
    elif model_name == "dense_net_201":
        in_shape = (224, 224, 3)
        model = DenseNet201(weights="imagenet", input_shape=in_shape, include_top=False, pooling=None)
        preproc = preproc_dense_net_201
    else:
        raise ValueError("Unknown model '{}'".format(model_name))
    layer_dict = {l.name: l for l in model.layers}

    # freeze up to a given layer
    if freeze is not None:
        if freeze not in layer_dict:
            raise ValueError("Unknown layer '{}' (to freeze)".format(freeze))
        freeze_up_to(model, layer_name=freeze)

    # cut after the given layer
    if cut is not None:
        # validate the layer that is actually looked up below (`cut`, not `freeze`)
        if cut not in layer_dict:
            raise ValueError("Unknown layer '{}' (to cut)".format(cut))
        model = Model(inputs=model.inputs, outputs=[layer_dict[cut].output])

    # add averaging
    if avg:
        # use the output tensor of the (possibly cut) network directly rather than
        # relying on the last entry of `model.layers` being the output layer
        glb_avg = GlobalAveragePooling2D(name="global_avg_pooling")(model.output)
        model = Model(inputs=model.inputs, outputs=[glb_avg])

    fc = Dense(n_classes, name='fc')(model.output)
    softmax = Activation("softmax")(fc)
    return in_shape, preproc, Model(inputs=model.inputs, outputs=[softmax])


def custom_iso(clean=''):
    """Return the current local time in ISO format with ':' and '-' replaced.

    clean: str
        Replacement passed to `re.sub` for every ':' and '-' of the timestamp
        (removed by default)
    """
    timestamp = datetime.datetime.now().isoformat()
    separators = re.compile('[:-]')
    return separators.sub(clean, timestamp)


def _build_arg_parser():
    """Create the command-line parser holding every option read by `main`."""
    parser = ArgumentParser()

    # Model loading and saving
    parser.add_argument("--path", dest="path")
    parser.add_argument("--model", dest="model")
    parser.add_argument("--dataset", dest="dataset")
    parser.add_argument("--train", dest="train", default="train")
    parser.add_argument("--val", dest="val", default="val")
    parser.add_argument("--test", dest="test", default="test")
    parser.add_argument("--cut", dest="cut", default=None)
    parser.add_argument("--freeze", dest="freeze", default=None)
    parser.add_argument("--working_path", dest="working_path")
    parser.add_argument("--n_jobs", dest="n_jobs", type=int, default=1)
    parser.add_argument("--augment", dest="augment", action="store_true")
    parser.add_argument("--training_dense_lr", dest="dense_lr", type=float, default=0.01)
    parser.add_argument("--training_end_lr", dest="end_lr", type=float, default=0.0001)
    parser.add_argument("--training_dense_epochs", dest="dense_epochs", type=int, default=2)
    parser.add_argument("--training_end_epochs", dest="end_epochs", type=int, default=23)
    parser.add_argument("--training_seed", dest="seed", type=int, default=42)
    parser.add_argument("--training_batch_size", dest="batch_size", type=int, default=64)
    parser.add_argument("--training_beta_1", dest="beta_1", type=float, default=0.9)
    parser.add_argument("--training_beta_2", dest="beta_2", type=float, default=0.999)
    parser.add_argument("--training_decay", dest="decay", type=float, default=None)
    parser.set_defaults(augment=False)
    return parser


def _steps_per_epoch(n_samples, batch_size):
    """Return the number of batches needed to cover `n_samples` once (at least 1)."""
    # ceil division: an exact multiple of the batch size must not get an extra batch
    return max(1, (n_samples + batch_size - 1) // batch_size)


def _make_optimizer(params, lr):
    """Build the Adam optimizer from the parsed hyper-parameters and the given learning rate."""
    optimizer_kwargs = {"lr": lr, "beta_1": params.beta_1, "beta_2": params.beta_2}
    # decay defaults to None on the command line: only forward it when provided
    if params.decay is not None:
        optimizer_kwargs["decay"] = params.decay
    return adam(**optimizer_kwargs)


def _make_flow(datagen, directory, in_shape, batch_size, seed, shuffle, random_crop):
    """Call `conservative_flow_from_directory` with the sizes derived from `in_shape`."""
    side = in_shape[0]
    return conservative_flow_from_directory(
        datagen,
        directory,
        load_size_range=(side, side + 1),
        target_size=side if side == in_shape[1] else in_shape[:2],
        batch_size=batch_size,
        seed=seed,
        shuffle=shuffle,
        random_crop=random_crop,
        follow_links=True
    )


def _compile_and_fit(model, params, lr, train_generator, validation_gen, steps, val_steps,
                     callbacks, epochs, initial_epoch=0):
    """Compile `model` with a fresh optimizer, print its summary, fit it and return the history."""
    # the model must be (re)compiled after the trainable flags have been changed
    model.compile(
        optimizer=_make_optimizer(params, lr),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )

    print(">  - summary")
    model.summary()

    # train
    print(">  - fit")
    return model.fit_generator(
        train_generator,
        epochs=epochs,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps,
        validation_data=validation_gen,
        validation_steps=val_steps,
        workers=params.n_jobs,
        callbacks=callbacks,
        verbose=1
    )


def main(argv):
    """Fine-tune a pretrained network on an image classification dataset in two rounds.

    First round: everything up to the "global_avg_pooling" layer is frozen and
    the remaining layers are trained for `--training_dense_epochs` epochs with
    learning rate `--training_dense_lr`. Second round: the network is frozen up
    to `--freeze` (or fully trainable when `--freeze` is not given) and trained
    for `--training_end_epochs` more epochs with learning rate
    `--training_end_lr`. Both rounds use Adam with `--training_beta_1`,
    `--training_beta_2` and, when given, `--training_decay`.

    A checkpoint is written to "<working dir>/weights" through a
    `ModelCheckpoint` callback and the training histories are saved in
    "<working dir>/history.npz".

    argv: list of str
        Command-line arguments, without the program name. Arguments unknown to
        the parser are ignored.
    """
    params, _ = _build_arg_parser().parse_known_args(argv)
    print("Parameters: {}".format(params))

    # create useful directories
    working_dir = os.path.join(params.working_path, "{}_{}_{}".format(params.model, params.dataset, custom_iso()))
    weight_dir = os.path.join(working_dir, "weights")
    if not os.path.exists(weight_dir):
        os.makedirs(weight_dir)

    # load dataset data
    dataset_path = os.path.join(params.path, params.dataset)
    dataset = ImageClassificationDataset(dataset_path, dirs=[params.train, params.val, params.test])

    # n_classes
    _, y = dataset.all()
    classes = np.unique(y)
    n_classes = classes.shape[0]
    x_train, x_val = dataset.all(dirs=[params.train])[0], dataset.all(dirs=[params.val])[0]
    n_samples_train, n_samples_val = len(x_train), len(x_val)

    print("> dataset:")
    print(">  - n_classes: {}".format(n_classes))
    print(">  - classes  : {}".format(classes))
    print(">  - n_samples: ({}, {})".format(n_samples_train, n_samples_val))

    # load model
    in_shape, preproc, model = get_model(params.model, n_classes, cut=params.cut, freeze=params.freeze)

    # prepare data access objects
    print("Prepare dataset generation...")
    if params.augment:
        train_datagen = ImageDataGenerator(
            rotation_range=180.0,
            horizontal_flip=True,
            vertical_flip=True,
            preprocessing_function=preproc
        )
    else:  # no augmentation
        train_datagen = ImageDataGenerator(preprocessing_function=preproc)
    val_datagen = ImageDataGenerator(preprocessing_function=preproc)

    train_directory = os.path.join(dataset_path, params.train)
    print(" -> training dataset: {}".format(train_directory))
    val_directory = os.path.join(dataset_path, params.val)
    print(" -> val dataset: {}".format(val_directory))

    train_generator = _make_flow(
        train_datagen, train_directory, in_shape, params.batch_size,
        seed=params.seed, shuffle=True, random_crop=params.augment
    )
    validation_gen = _make_flow(
        val_datagen, val_directory, in_shape, params.batch_size,
        seed=None, shuffle=False, random_crop=False
    )

    steps = _steps_per_epoch(n_samples_train, params.batch_size)
    val_steps = _steps_per_epoch(n_samples_val, params.batch_size)

    # register training callbacks
    weights_pattern = "weights.{epoch:03d}-{val_loss:.5f}.hdf5"
    callbacks = [
        TerminateOnNaN(),
        ModelCheckpoint(os.path.join(weight_dir, weights_pattern)),
    ]

    print("> train 1st round: ")
    print(">  - compile")
    # only the layers added after the pooling layer are trained in this round
    freeze_up_to(model, layer_name="global_avg_pooling")
    history_last = _compile_and_fit(
        model, params, params.dense_lr, train_generator, validation_gen, steps, val_steps,
        callbacks, epochs=params.dense_epochs
    )

    print("> train 2nd round: ")
    print(">  - compile")
    freeze_up_to(model, layer_name=params.freeze, all_trainable=params.freeze is None)
    # epoch numbering continues where the first round stopped
    history_whole = _compile_and_fit(
        model, params, params.end_lr, train_generator, validation_gen, steps, val_steps,
        callbacks, epochs=params.end_epochs + params.dense_epochs, initial_epoch=params.dense_epochs
    )

    def _merged(key):
        # concatenate the per-epoch values of both rounds (empty when the key is absent)
        return history_last.history.get(key, []) + history_whole.history.get(key, [])

    save_dict = {
        "working_path": working_dir,
        "epochs": {
            "indexes": history_last.epoch + history_whole.epoch,
            "dense": params.dense_epochs,
            "full": params.end_epochs
        },
        "scores": {
            "acc": _merged("acc"),
            "loss": _merged("loss"),
            "val_acc": _merged("val_acc"),
            "val_loss": _merged("val_loss"),
        },
        "params": {
            "last": history_last.params,
            "whole": history_whole.params
        },
        "weights_file": weights_pattern
    }
    np.savez(os.path.join(working_dir, "history.npz"), **save_dict)



if __name__ == "__main__":
    # script entry point: forward the command-line arguments (program name stripped) to main()
    import sys

    cli_args = sys.argv[1:]
    main(cli_args)
