import os
import math
from warnings import warn

import numpy as np
from argparse import ArgumentParser
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt

from plot_importance_group import dataset_name_dict

def histogram(dest, x, title="", xlabel="", ylabel=""):
    """Save a density-normalised histogram of ``x`` to ``dest``.

    One bin is used per five samples (with a minimum of one bin), and the
    x axis is limited to ``[0, x.shape[0]]``.

    Args:
        dest: Target passed to ``plt.savefig`` (typically an output file path).
        x: Array of values to bin; must expose ``shape``.
            NOTE(review): the x-limit equals the number of samples, which
            presumably means the values are indices/ranks in that range -
            confirm against callers.
        title: Figure title.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
    """
    n_samples = x.shape[0]
    # Fewer than five samples would otherwise request zero bins, which is
    # not a valid bin count.
    n_bins = max(1, n_samples // 5)

    plt.figure()
    # `density` is the replacement for the `normed` keyword, which newer
    # matplotlib releases no longer accept.
    plt.hist(x, bins=n_bins, density=True)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.xlim((0, n_samples))
    plt.savefig(dest)
    plt.close()


def bar(dest, x, width=1.0, xlabel="", ylabel="", title="bar", dpi=300, max_disp=500):
    """Save a bar chart of ``x`` to ``dest``, thinning the bars when there are too many.

    Bars are placed at 1-based positions. When ``x`` holds more than
    ``max_disp`` values, only every ``int(len / max_disp)``-th value is drawn.
    Axis limits are always derived from the complete array, not from the
    thinned selection.

    Args:
        dest: Target passed to ``plt.savefig``.
        x: One value per bar; must expose ``shape`` and support slicing.
        width: Width of each bar.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        title: Figure title.
        dpi: Resolution used when saving.
        max_disp: Bar count above which the data is subsampled.
    """
    n_features = x.shape[0]
    # Keep every value unless the bar count exceeds the display budget.
    stride = int(n_features / max_disp) if n_features > max_disp else 1
    shown_values = x[::stride]
    shown_positions = np.arange(1, n_features + 1)[::stride]

    plt.figure()

    # y range: floor slightly above zero, 5% headroom above the largest value
    lowest = np.min(x)
    highest = np.max(x)
    headroom = (highest - lowest) * 0.05
    plt.ylim((max(1e-6, lowest), highest + headroom))

    # x range: 5% margin on each side of the full feature range
    margin = int(n_features * 0.05)
    plt.xlim((-margin, n_features + margin))

    plt.bar(shown_positions, shown_values, width=width, align="edge")
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.savefig(dest, dpi=dpi)
    plt.close()


def subplot_scatters(ref, cmps, ref_label, cmps_labels, dest, title="", dpi=300, logscale=False, figsize=(600, 400)):
    """Save a grid of scatter plots comparing each row of ``cmps`` with ``ref``.

    Each panel plots one row of ``cmps`` (x axis) against ``ref`` (y axis)
    together with a dashed identity line spanning the plotted value range.
    Panels are laid out row by row on a near-square grid; grid cells without
    data are hidden.

    Args:
        ref: Reference values, plotted on the y axis of every panel.
        cmps: 2D array with at least one row; each row is compared with
            ``ref`` and must be index-compatible with it.
        ref_label: y-axis label used for every panel.
        cmps_labels: Sequence of x-axis labels, one per row of ``cmps``.
        dest: Target passed to ``savefig``.
        title: Figure-level title.
        dpi: Resolution used when saving.
        logscale: If true, use log-log axes and only plot the points where
            both values are strictly positive.
        figsize: Currently ignored; kept so existing callers keep working.
    """
    n_cmp = cmps.shape[0]
    grid_size = int(math.ceil(np.sqrt(n_cmp)))
    # Smallest number of rows of `grid_size` panels able to hold `n_cmp` plots.
    n_rows = int(math.ceil(n_cmp / float(grid_size)))
    fig, axes = plt.subplots(nrows=n_rows, ncols=grid_size, squeeze=False)
    fig.suptitle(title)

    for i, values in enumerate(cmps):
        # Integer row/column: a float index (true division) is rejected by numpy.
        row, col = divmod(i, grid_size)
        ax = axes[row][col]
        ax.set_xlabel(cmps_labels[i])
        ax.set_ylabel(ref_label)

        if logscale:
            # Non-positive values cannot be placed on a log axis.
            idx = np.logical_and(ref > 0, values > 0)
            ax.set_xscale("log")
            ax.set_yscale("log")
        else:
            # Builtin `bool`: the `np.bool` alias no longer exists in numpy.
            idx = np.ones(ref.shape, dtype=bool)

        if not np.any(idx):
            # Nothing to draw; min/max of an empty selection would raise.
            continue

        overall_max = max(np.max(ref[idx]), np.max(values[idx]))
        overall_min = min(np.min(ref[idx]), np.min(values[idx]))
        ax.scatter(values[idx], ref[idx], alpha=0.5, linewidths=1)
        identity = np.linspace(overall_min, overall_max, num=3)
        ax.plot(identity, identity, "k--")

    # Hide the grid cells left over when n_cmp does not fill the grid.
    for unused_ax in axes.ravel()[n_cmp:]:
        unused_ax.set_axis_off()

    plt.tight_layout(w_pad=1, h_pad=1, rect=[0, 0.03, 1, 0.95])
    fig.savefig(dest, dpi=dpi)
    plt.close(fig)


def main(argv):
    r"""Load RFE result files and save one grid of score curves per network.

    For every (model, dataset) pair given on the command line, the file
    ``<path>/<model>_image_net_<dataset>_rfe.npz`` is loaded when it exists.
    Then, for each hard-coded network, a 2x4 grid of subplots (one per
    hard-coded dataset) is drawn from the ``scores`` array of the matching
    file and saved as ``<dest>/rfe_<network>.png``. Two vertical lines are
    added per subplot: "best" (highest score) and "selected" (first evaluated
    point whose score is within 0.005 of the best one).

    Example arguments::

        --path D:\data\features\rfe\
        --dest D:\data\features\rfe\plots
        --models resnet50,inception_resnet_v2,inception_v3,mobile,vgg19,vgg16
        --datasets ulg_lbtd2_chimio_necrose,cells_no_aug,patterns_no_aug

    Args:
        argv: Command-line arguments, without the program name.
    """
    parser = ArgumentParser()

    # Model loading and saving
    # --path/--models/--datasets are mandatory: without them the code below
    # cannot build any file name, so fail early with a usage message.
    parser.add_argument("--path", dest="path", required=True)
    parser.add_argument("--models", dest="models", required=True)
    parser.add_argument("--datasets", dest="datasets", required=True)
    parser.add_argument("--dest", dest="dest", default=".")
    # --scale and --best are only read by the disabled plotting code below.
    parser.add_argument("--scale", dest="scale", action="store_true")
    parser.add_argument("--best", dest="best", action="store_true")
    parser.set_defaults(scale=False, best=False)
    params, unknown = parser.parse_known_args(argv)
    params.models = params.models.split(",")
    params.datasets = params.datasets.split(",")
    print("Parameters: {}".format(params))

    if not os.path.exists(params.dest):
        os.makedirs(params.dest)

    print("> load data")
    # Only the result files that actually exist are loaded; missing
    # combinations are reported later, when a plot needs them.
    data = {}
    for model in params.models:
        for dataset in params.datasets:
            filepath = os.path.join(params.path, "{}_image_net_{}_rfe.npz".format(model, dataset))
            if os.path.isfile(filepath):
                data[(model, dataset)] = np.load(filepath)
    metric_keys = ["accuracy", "precision", "recall", "f1-score", "roc_auc", "n_features"]

    # for model, dataset in data.keys():
    #     key = (model, dataset)
    #     if key not in data:
    #         warn("Missing data for {}".format(key))
    #         continue
    #     curr_data = data[key]
    #     print("> {}\t{}\t{}".format(model, dataset, "\t".join(["{:1.4f}".format(float(curr_data[metric])) for metric in metric_keys if metric in curr_data])))
    #     n_features = curr_data["support"].shape[0]
    #     n_evaluated = curr_data["scores"].shape[0]
    #     step = 4
    #     best_subset_sizes = np.argsort(curr_data["scores"])[::-1][:10]
    #     print("  {}".format(curr_data["scores"][best_subset_sizes]))
    #     print("  {}".format(n_features - best_subset_sizes * step))
    #     x = 1 + np.arange(n_evaluated) * step
    #     plt.figure(figsize=(6.4, 2.4))
    #     plt.plot(x, curr_data["scores"])
    #     # plt.title("{} - {} (features: {} / {})".format(model, dataset, curr_data["n_features"], n_features))
    #     plt.xlabel("n_features")
    #     plt.ylabel("accuracy")
    #     if params.scale:
    #         plt.ylim((0, 1))
    #     if params.best:
    #         plt.axvline(curr_data["n_features"], ymin=0, ymax=1, color="k", linestyle="--", alpha=0.5)
    #     plt.savefig(os.path.join(params.dest, "{}_{}_scores.png".format(model, dataset)), dpi=600)
    #     plt.close()

    # (dataset, subplot position code, y-axis label, show x label, show legend)
    # The position codes 241..248 address the cells of a 2x4 grid.
    datasets = [
        ("cells_no_aug", 241, "roc_auc", False, False),
        ("patterns_no_aug", 242, "roc_auc", False, False),
        ("glomeruli_no_aug", 243, "roc_auc", False, False),
        ("ulg_lbtd2_chimio_necrose", 244, "roc_auc", False, False),
        ("ulg_breast", 245, "roc_auc", True, False),
        ("ulg_lbtd_lba_new", 246, "accuracy", True, False),
        ("ulg_lbtd_tissus", 247, "accuracy", True, False),
        ("ulb_anapath_lba", 248, "accuracy", True, True)
    ]

    networks = [
        "mobile",
        "dense_net_201",
        "inception_resnet_v2",
        "resnet50",
        "inception_v3",
        "vgg19",
        "vgg16"
    ]

    # Datasets whose curves are spaced 8 features apart instead of 4.
    # NOTE(review): presumably mirrors the step used when the results were
    # produced - confirm against the script that wrote the .npz files.
    coarse_step_datasets = ("ulg_breast", "ulg_lbtd_lba_new")

    for network in networks:
        fig = plt.figure(figsize=(12.6, 3.6))
        for dataset, subplot, metric, xlabel, legend in datasets:
            key = (network, dataset)
            if key not in data:
                warn("Missing data for {}".format(key))
                continue
            step = 8 if dataset in coarse_step_datasets else 4
            scores = data[key]["scores"]
            n_evaluated = scores.shape[0]
            x = 1 + np.arange(n_evaluated) * step
            best = np.argmax(scores)
            # "selected": first evaluated point scoring within 0.005 of the best.
            lower_threshold = scores[best] - 0.005
            selected_n_features = np.where(scores > lower_threshold)[0][0]
            subplt = plt.subplot(subplot)
            plt.grid(linestyle='dotted', linewidth=1, color="k", alpha=0.3)
            subplt.set_axisbelow(True)
            plt.plot(x, scores)
            plt.axvline(x[best], color="k", linestyle="--", alpha=0.65, label="best")
            plt.axvline(x[selected_n_features], color="r", linestyle="--", alpha=0.80, label="selected")
            plt.ylabel(metric)
            plt.title(dataset_name_dict[dataset])

            # plt.ylim((0, 1))
            if xlabel:
                plt.xlabel("n_features")
            if legend:
                plt.legend()
        fig.tight_layout()
        # plt.suptitle("ResNet50 - RFE curves")
        fig.savefig(os.path.join(params.dest, "rfe_{}.png".format(network)), dpi=300)
        # Release the figure: pyplot keeps every figure alive until it is closed.
        plt.close(fig)


if __name__ == "__main__":
    # Script entry point: hand the command-line arguments (without the
    # program name) to main().
    from sys import argv as _cli_args
    main(_cli_args[1:])
