import numpy as np
import math
import os
from argparse import ArgumentParser

import re
from warnings import warn
from keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.preprocessing import LabelEncoder

from dataset import ImageClassificationDataset
from fine_tune import get_model
from get_features import ImageLoader
from util import print_cm, mk_task_filename, conservative_flow_from_directory


def find_weight_folder(path, model, dataset):
    """Return the most recent weights folder for a model/dataset pair.

    Sub-folders of ``path`` named ``<model>_<dataset>_<suffix>`` are
    considered. The suffix is expected to be an ISO 8601 timestamp, so the
    lexicographically greatest suffix designates the most recent folder.

    Parameters
    ----------
    path: str
        Directory containing the weights folders.
    model: str
        Model name, matched literally.
    dataset: str
        Dataset name, matched literally.

    Returns
    -------
    str
        ``path`` joined with the name of the selected folder.

    Raises
    ------
    ValueError
        If no sub-folder matches the model and dataset.
    """
    # Escape both names so that regex metacharacters they may contain
    # (e.g. '.', '+') are matched literally instead of being interpreted.
    pattern = re.compile(r"^{}_{}_(.*)$".format(re.escape(model), re.escape(dataset)))

    candidates = []
    for name in os.listdir(path):
        if not os.path.isdir(os.path.join(path, name)):
            continue  # regular files are never weights folders
        match = pattern.match(name)
        if match is not None:
            candidates.append((match.group(1), name))

    if not candidates:
        raise ValueError("No folder found for model and dataset: '{}' and '{}'".format(model, dataset))

    # Greatest suffix first; the prefix is fixed so suffixes are unique per folder.
    _, latest = max(candidates)
    return os.path.join(path, latest)


def get_best_epoch_from_history(history):
    """Locate the epoch with the lowest validation loss in a training history.

    Parameters
    ----------
    history: mapping
        Mapping whose values expose ``.item()``. The following entries are read:
        - ``"scores"``: item is a dict holding the ``"val_loss"`` sequence
        - ``"weights_file"``: item is a format string accepting the
          ``epoch`` and ``val_loss`` fields
        - ``"epochs"``: item is a dict holding the ``"indexes"`` sequence,
          aligned with ``"val_loss"``

    Returns
    -------
    tuple
        ``(best_loss, best_epoch, weights_filename)`` where ``best_epoch`` is
        the stored index of the best epoch plus one and ``weights_filename``
        is the format string filled in with these two values.
    """
    losses = history["scores"].item()["val_loss"]
    name_template = history["weights_file"].item()

    best_position = np.argmin(losses)
    # stored indexes are shifted by one to obtain the epoch number
    best_epoch = history["epochs"].item()["indexes"][best_position] + 1
    best_loss = losses[best_position]

    weights_filename = name_template.format(epoch=best_epoch, val_loss=best_loss)
    return best_loss, best_epoch, weights_filename


def main(argv):
    """Evaluate a fine-tuned network on a test folder and save the scores.

    The weights of the epoch with the lowest validation loss are looked up in
    the ``history.npz`` file of the weights folder returned by
    ``find_weight_folder``. The model is then run on the test folder, the
    scores are printed and saved with ``np.savez`` in the destination folder
    and, when ``--extract`` is given, the feature extraction script is invoked
    with the selected weights.

    Parameters
    ----------
    argv: list of str
        Command line arguments (without the program name). Unknown arguments
        are ignored.
    """
    parser = ArgumentParser()

    # Model loading and saving
    parser.add_argument("--path", dest="path")
    parser.add_argument("--dest", dest="dest", default="ft")
    parser.add_argument("--model", dest="model")
    parser.add_argument("--dataset", dest="dataset")
    parser.add_argument("--test", dest="test", default="test")
    parser.add_argument("--n_jobs", dest="n_jobs", default=1, type=int)
    parser.add_argument("--cut", dest="cut", default=None)
    parser.add_argument("--freeze", dest="freeze", default=None)
    parser.add_argument("--weights_folders_dir", dest="weights_folders_dir")
    parser.add_argument("--testing_batch_size", dest="batch_size", type=int, default=128)
    parser.add_argument("--extract", dest="extract", action="store_true")
    parser.set_defaults(extract=False)
    params, _unknown = parser.parse_known_args(argv)
    print("Parameters: {}".format(params))

    # create useful directories
    if not os.path.exists(params.dest):
        os.makedirs(params.dest)

    # load dataset data
    dataset_path = os.path.join(params.path, params.dataset)
    print("Load dataset from '{}'".format(dataset_path))
    dataset = ImageClassificationDataset(dataset_path, dirs=[params.test])

    # n_classes
    _x_test, _y_test = dataset.all(dirs=[params.test], one_hot=False)
    n_samples = _x_test.shape[0]
    n_classes = np.unique(_y_test).shape[0]
    print("Loaded data")
    print("> x_test  : {}".format(_x_test.shape))
    print("> y_test  : {}".format(_y_test.shape))
    print("> n_class : {}".format(n_classes))

    print("Infer weight file...")
    weights_folder = find_weight_folder(params.weights_folders_dir, params.model, params.dataset)
    # The history stores python objects (dicts), which requires pickle support.
    # NOTE: unpickling executes arbitrary code, only load history files from a trusted source.
    # The context manager makes sure the underlying archive is closed.
    with np.load(os.path.join(weights_folder, "history.npz"), allow_pickle=True) as history:
        min_val_loss, min_epoch, weights_file = get_best_epoch_from_history(history)
        last_epoch = history["epochs"].item()["indexes"][-1] + 1
    weights_path = os.path.join(weights_folder, "weights", weights_file)

    print("Best epoch:")
    print("> epoch: {}".format(min_epoch))
    print("> loss : {:0.4f}".format(min_val_loss))
    print("> file : {}".format(weights_file))
    print("> path : {}".format(weights_path))

    if min_epoch == last_epoch:
        warn("Best epoch is the last one with loss {:0.4f}.".format(min_val_loss))

    # load model
    print("Load model...")
    (height_in, width_in, chan_in), preproc, model = get_model(
        params.model,
        n_classes,
        cut=params.cut,
        freeze=params.freeze
    )
    model.load_weights(weights_path, by_name=True)

    # prepare inference
    print("Inference...")
    test_datagen = ImageDataGenerator(preprocessing_function=preproc)
    test_directory = os.path.join(dataset_path, params.test)
    print(" -> test dataset: {}".format(test_directory))

    test_flow = conservative_flow_from_directory(
        test_datagen,
        test_directory,
        load_size_range=(height_in, height_in + 1),
        target_size=height_in if height_in == width_in else (height_in, width_in),
        batch_size=params.batch_size,
        seed=None,
        shuffle=False,
        follow_links=True
    )
    # map the indexes assigned by the generator back to the (integer) class folder names
    idx_to_cls = {i: c for c, i in test_flow.class_indices.items()}
    classes = np.array([int(idx_to_cls[i]) for i in range(n_classes)])
    # number of batches needed to cover all samples (ceiling division)
    n_batches = (n_samples + params.batch_size - 1) // params.batch_size

    # builtin float/int are used as dtypes: the np.float/np.int aliases were removed from NumPy
    probas = np.zeros([n_samples, n_classes], dtype=float)
    y_test = np.zeros([n_samples], dtype=int)
    y_test_encoded = np.zeros([n_samples], dtype=int)
    for batch_idx in range(n_batches):
        start = batch_idx * params.batch_size
        end = min(start + params.batch_size, n_samples)
        x, _y = next(test_flow)
        probas[start:end] = model.predict(x)
        y_test_encoded[start:end] = np.argmax(_y, axis=1)
        y_test[start:end] = np.take(classes, y_test_encoded[start:end])

    print("Predict test...")
    y_pred_encoded = np.argmax(probas, axis=1).astype(int)
    if n_classes == 2:
        print("> binary problem")
        # probability of the class with generator index 1, used as score for the ROC AUC
        thresholds = probas[:, 1]
        print("> decision_function : {}".format(thresholds.shape))
    else:
        print("> multi-class problem")
    y_pred = np.take(classes, y_pred_encoded)
    print("> y_pred            : {}".format(y_pred.shape))

    to_return = {
        "accuracy": accuracy_score(y_test, y_pred),
        "cm": confusion_matrix(y_test, y_pred)
    }

    print("Scores:")
    print("> accuracy : {}".format(to_return["accuracy"]))
    if n_classes == 2:
        # only built for binary problems: classes[1] does not exist with a single class
        label_params = {"labels": classes, "pos_label": classes[1]}
        to_return["precision"] = precision_score(y_test, y_pred, **label_params)
        to_return["recall"] = recall_score(y_test, y_pred, **label_params)
        to_return["f1-score"] = f1_score(y_test, y_pred, **label_params)
        to_return["roc_auc"] = roc_auc_score(y_test_encoded, thresholds)
        print("> precision: {}".format(to_return["precision"]))
        print("> recall   : {}".format(to_return["recall"]))
        print("> f1-score : {}".format(to_return["f1-score"]))
        print("> roc auc  : {}".format(to_return["roc_auc"]))

    print_cm(to_return["cm"], [str(i) for i in classes])

    # save results
    task = "ftnet_inference"
    filename = mk_task_filename(
        task=task,
        model=params.model,
        layer=None,
        dataset=params.dataset,
        folder=None,
        reduction=None,
        source="img_net_fine_tuned"
    )
    filepath = os.path.join(params.dest, filename)
    print("> save results for task '{}' to '{}'".format(task, filepath))
    np.savez(filepath, **to_return)

    # extract features
    if params.extract:
        # strip the "_val" suffix, if any, from the dataset name
        extract_dataset = params.dataset[:-4] if params.dataset.endswith("_val") else params.dataset
        # datasets for which features are extracted from "train,val" instead of "train,test"
        val_datasets = {"ulg_lbtd_tissus", "ulg_lbtd_lba", "ulb_anapath_lba", "ulg_lbtd2_chimio_necrose"}
        folders = "train,val" if extract_dataset in val_datasets else "train,test"
        extract_params = [
            "--path", params.path,
            "--model", params.model,
            "--dataset", extract_dataset,
            "--folders", folders,
            "--source", "img_net_fine_tuned",
            "--weights", weights_path,
            "--dest", os.path.join(params.path, "features", "fine_tuned_nodata")
        ]
        from get_features import main as get_features_main
        get_features_main(extract_params)


if __name__ == "__main__":
    # Script entry point: forward the command line arguments,
    # without the program name, to main().
    import sys

    cli_args = sys.argv[1:]
    main(cli_args)
