

import os
import math
import numpy as np
from argparse import ArgumentParser
import matplotlib
import re

from natsort import natsort, natsorted

from per_layer_to_plot import get_layer_order, LAYER_NAME_DICT
from util import mk_task_filename

matplotlib.use('Agg')
from matplotlib import pyplot as plt

# Short display names keyed by dataset identifier.
# NOTE(review): "ulg_lbtd_lba" maps to the placeholder "xxx" — confirm whether a
# real display name is intended or whether the entry is obsolete.
dataset_name_dict = dict(
    glomeruli_no_aug="Glomeruli",
    patterns_no_aug="ProliferativePatterns",
    cells_no_aug="CellInclusion",
    ulg_lbtd2_chimio_necrose="Necrosis",
    ulg_breast="Breast",
    ulg_lbtd_tissus="Lung",
    ulg_lbtd_lba="xxx",
    ulg_lbtd_lba_new="MouseLba",
    ulb_anapath_lba="HumanLba",
)

# Short display names keyed by network identifier; main() uses it to label
# feature groups when features are grouped along an axis other than "layer".
network_name_dict = dict(
    mobile="Mobile",
    dense_net_201="DenseNet",
    inception_resnet_v2="IncResV2",
    resnet50="ResNet",
    inception_v3="IncV3",
    vgg19="VGG19",
    vgg16="VGG16",
)


def histogram(dest, x, title="", xlabel="", ylabel=""):
    """Save a density-normalized histogram of the 1-D array ``x`` to ``dest``.

    Parameters
    ----------
    dest : str
        Output image path.
    x : np.ndarray
        Values to histogram; roughly one bin is used per five samples
        (always at least one bin).
    title, xlabel, ylabel : str
        Figure title and axis labels.
    """
    n_samples = x.shape[0]
    plt.figure()
    # `normed` was removed from Matplotlib's hist(); `density` is its replacement.
    # Floor the bin count at 1: fewer than 5 samples would otherwise give bins=0.
    plt.hist(x, bins=max(1, n_samples // 5), density=True)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    # NOTE(review): the x-range is tied to the number of samples rather than to
    # the data range; this only frames the data if the values lie in [0, len(x)].
    plt.xlim((0, n_samples))
    plt.savefig(dest)
    plt.close()


def bar(dest, x, width=1.0, xlabel="", ylabel="", title="bar", dpi=300, max_disp=500):
    """Save a bar chart with one bar per entry of the 1-D array ``x``.

    Bars sit at positions 1..len(x). When ``x`` has more than ``max_disp``
    entries, only every ``int(len(x) / max_disp)``-th entry is drawn. Axis
    limits are still computed from the whole array.

    Parameters
    ----------
    dest : str
        Output image path.
    x : np.ndarray
        Bar heights.
    width : float
        Bar width (bars are left-aligned on their position).
    xlabel, ylabel, title : str
        Axis labels and figure title.
    dpi : int
        Resolution passed to ``savefig``.
    max_disp : int
        Threshold above which the bars are subsampled.
    """
    count = x.shape[0]
    ticks = np.arange(1, count + 1)
    if count <= max_disp:
        values = x
    else:
        stride = int(count / max_disp)
        values, ticks = x[::stride], ticks[::stride]

    plt.figure()
    # y-range: bottom is max(1e-6, min(x)), top leaves 5% headroom over max(x).
    # NOTE(review): the 1e-6 floor hides negative values — confirm x is non-negative.
    low, high = np.min(x), np.max(x)
    plt.ylim((max(1e-6, low), high + 0.05 * (high - low)))
    # x-range: 5% horizontal margin on each side of the full feature range
    margin = int(0.05 * count)
    plt.xlim((-margin, count + margin))

    plt.bar(ticks, values, width=width, align="edge")
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.savefig(dest, dpi=dpi)
    plt.close()


def subplot_scatters(ref, cmps, ref_label, cmps_labels, dest, title="", dpi=300, logscale=False, figsize=(600, 400)):
    """Save a grid of scatter plots comparing ``ref`` against each row of ``cmps``.

    Subplot ``i`` plots ``cmps[i]`` (x axis) against ``ref`` (y axis),
    together with the dashed black identity line ``y = x``.

    Parameters
    ----------
    ref : np.ndarray
        Reference values; same shape as each ``cmps[i]``.
    cmps : np.ndarray
        Arrays to compare with ``ref``, stacked along the first axis.
    ref_label : str
        y-axis label of every subplot.
    cmps_labels : sequence of str
        x-axis label of subplot ``i``.
    dest : str
        Output image path.
    title : str
        Figure title.
    dpi : int
        Resolution passed to ``savefig``.
    logscale : bool
        Use log axes. Only pairs where both values are strictly positive are
        drawn.
    figsize : tuple
        Currently unused; kept for backward compatibility.
    """
    n_cmp = cmps.shape[0]
    # near-square layout: grid_size columns and just enough rows for n_cmp plots
    grid_size = int(math.ceil(np.sqrt(n_cmp)))
    n_rows = grid_size - int((grid_size * grid_size - n_cmp) / grid_size)
    fig, axes = plt.subplots(nrows=n_rows, ncols=grid_size, squeeze=False)
    fig.suptitle(title)

    for i, cmp in enumerate(cmps):
        # integer division: `i / grid_size` is a float in Python 3, and numpy
        # rejects float indices
        ax = axes[i // grid_size][i % grid_size]
        if logscale:
            idx = np.logical_and(ref > 0, cmp > 0)
            ax.set_xscale("log")
            ax.set_yscale("log")
        else:
            # builtin bool: the np.bool alias was removed in NumPy 1.24
            idx = np.ones(ref.shape, dtype=bool)
        ax.set_xlabel(cmps_labels[i])
        ax.set_ylabel(ref_label)
        if not np.any(idx):
            # nothing to draw (e.g. no strictly-positive pairs on log axes);
            # np.max/np.min on an empty selection would raise
            continue
        overall_max = max(np.max(ref[idx]), np.max(cmp[idx]))
        overall_min = min(np.min(ref[idx]), np.min(cmp[idx]))
        ax.scatter(cmp[idx], ref[idx], alpha=0.5, linewidths=1)
        identity = np.linspace(overall_min, overall_max, num=3)
        ax.plot(identity, identity, "k--")

    # hide the grid cells left empty when n_cmp does not fill the last row
    for j in range(n_cmp, n_rows * grid_size):
        axes[j // grid_size][j % grid_size].axis("off")

    plt.tight_layout(w_pad=1, h_pad=1, rect=[0, 0.03, 1, 0.95])
    plt.savefig(dest, dpi=dpi)
    plt.close()


def group_importances(importances, groups=None):
    """Split features into buckets according to their cumulated importance.

    Features are sorted by decreasing importance and their importances are
    accumulated. Bucket ``i`` holds the features whose cumulated importance
    lies in ``(groups[i-1], groups[i]]`` (``(-inf, groups[0]]`` for the first
    bucket). Features whose cumulated importance exceeds ``groups[-1]`` fall
    in no bucket.

    Parameters
    ----------
    importances : np.ndarray
        1-D array of per-feature importances. These are presumably normalized
        to sum to 1, since the default thresholds assume it.
    groups : sequence of float, optional
        Increasing cumulated-importance thresholds. Defaults to
        ``[0.1, 0.25, 0.5, 1.0]``.

    Returns
    -------
    grouped : np.ndarray
        Shape ``(len(groups), len(importances))``. Row ``i`` holds the
        importances of the features in bucket ``i`` and zeros elsewhere.
    groups : sequence of float
        The thresholds actually used.
    """
    if groups is None:
        groups = [0.1, 0.25, 0.5, 1.0]
    idx = np.argsort(importances)[::-1]
    cumulated = np.cumsum(importances[idx])
    # Compare with a small relative tolerance. Without it, floating-point
    # rounding in the running sum (e.g. a total of 1.0000000000000002 against
    # a 1.0 threshold) silently drops the least important features from
    # every bucket.
    tol = 1e-9 * abs(cumulated[-1]) if cumulated.size else 0.0
    thresholds = np.asarray(groups, dtype=float) + tol
    nb_features = np.searchsorted(cumulated, thresholds, side="right")
    grouped = np.zeros((len(groups), importances.shape[0]), dtype=importances.dtype)
    prev_nb_features = 0
    for i, nb in enumerate(nb_features):
        # features ranked prev_nb_features..nb-1 belong to bucket i
        selected = idx[prev_nb_features:nb]
        grouped[i, selected] = importances[selected]
        prev_nb_features = nb
    return grouped, groups


def add_prop_bar(x, container, group_labels, unique_labels):
    """Overlay thin black bars showing each group's share of the features.

    Parameters
    ----------
    x : sequence
        Bar positions, aligned with ``unique_labels``.
    container : matplotlib BarContainer
        Previously drawn bars. The overlay bars are one fifth of the width of
        its first rectangle.
    group_labels : sequence
        Group label of every feature.
    unique_labels : sequence
        Labels in the order they are plotted along ``x``.

    The height of bar ``k`` is the fraction of features whose label is
    ``unique_labels[k]``. A label with no feature gets height 0. The bars are
    added with legend label "features".
    """
    rect_width = container[0].get_width()
    labels, counts = np.unique(group_labels, return_counts=True)
    count_dict = dict(zip(labels, counts))
    # follow the plotted label order. The builtin float is used because the
    # np.float alias was removed in NumPy 1.24.
    ordered_counts = np.array([count_dict.get(lab, 0) for lab in unique_labels], dtype=float)
    total = np.sum(ordered_counts)
    # avoid 0/0 when there are no features at all
    proportions = ordered_counts / total if total > 0 else ordered_counts
    plt.bar(x, proportions, width=rect_width / 5, color="black", label="features")


def simple_bar_plot(importances, features, along, dataset, task, model=None):
    """Draw (without saving) one bar per feature group, plus the feature-share overlay.

    Parameters
    ----------
    importances : sequence of float
        One value per group, drawn in the order of the x tick labels.
    features : iterable of (descriptor, index) pairs
        ``descriptor`` is an iterable of (key, value) pairs. A feature's group
        is the value of the first pair whose key equals ``along``.
    along : str
        Grouping key. For ``"layer"``, groups are ordered with
        ``get_layer_order``; otherwise they are naturally sorted.
    dataset, task, model : str
        Used to build the plot title (``model`` only when ``along == "layer"``).

    NOTE(review): bar heights are drawn in ``importances`` order while the tick
    labels use this function's own order. Callers must supply importances in
    that same order; verify against the caller.
    """
    group_labels = []
    for descriptor, _ in features:
        group_labels.append(next(pair[1] for pair in descriptor if pair[0] == along))
    candidates = list(set(group_labels))

    if along != "layer":
        plot_title = "{} - {} - total imp.".format(dataset, task)
        unique_labels = natsorted(candidates)
    else:
        plot_title = "{} - {} - {} - total imp.".format(dataset, task, model)
        order = get_layer_order(list(candidates), model)
        unique_labels = np.array(candidates)[order]

    plt.figure()
    positions = range(len(unique_labels))
    plt.xticks(positions, unique_labels, rotation=55, ha="right")
    bars = plt.bar(positions, importances, label="importances")
    add_prop_bar(positions, bars, group_labels=group_labels, unique_labels=unique_labels)
    plt.title(plot_title)
    plt.legend()
    plt.tight_layout()


def group_bar_plot(importances, groups, unique_groups, ylabel="", percentages=None, disp_features=True, legend_kw=None, fig_kw=None, ncol=1, grid=False):
    """Draw (without saving) a stacked bar chart of importance per group.

    Features are first bucketed by cumulated importance with
    ``group_importances``. For every group, each bucket's importance mass is
    stacked on top of the previous buckets.

    Parameters
    ----------
    importances : np.ndarray
        1-D per-feature importances.
    groups : sequence
        Group label of every feature (same length as ``importances``).
    unique_groups : sequence
        Group labels in the order they are drawn along the x axis.
    ylabel : str
        y-axis label.
    percentages : sequence of float, optional
        Cumulated-importance thresholds forwarded to ``group_importances``.
    disp_features : bool
        Overlay each group's share of the features (see ``add_prop_bar``).
    legend_kw, fig_kw : dict, optional
        Extra keyword arguments for ``plt.legend`` / ``plt.figure``. They are
        not modified.
    ncol : int
        Number of legend columns.
    grid : bool
        Draw a dotted background grid.

    Returns
    -------
    The legend artist, e.g. for ``savefig(bbox_extra_artists=...)``.
    """
    if fig_kw is None:
        fig_kw = dict()
    # copy so that setting "ncol" below never mutates the caller's dict
    legend_kw = dict() if legend_kw is None else dict(legend_kw)
    # an array is needed for the element-wise `groups == label` mask below
    groups = np.asarray(groups)
    imp_groups, percentages = group_importances(importances, groups=percentages)
    # one column per plotted group, in the order of `unique_groups`
    sum_aggr = np.zeros((imp_groups.shape[0], len(unique_groups)))
    for i, label in enumerate(unique_groups):
        sum_aggr[:, i] = np.sum(imp_groups[:, groups == label], axis=1)

    fig = plt.figure(**fig_kw)
    if grid:
        plt.grid(linestyle='dotted', linewidth=1, color="k", alpha=0.3)
        fig.get_axes()[0].set_axisbelow(True)
    x = range(len(unique_groups))
    container = None
    for i, perc in enumerate(percentages):
        # each bucket is stacked on the sum of the previous buckets; round so
        # that e.g. 0.29 is labelled 29% rather than 28%
        container = plt.bar(x, sum_aggr[i, :], bottom=np.sum(sum_aggr[:i, :], axis=0),
                            label="{:3d}%".format(int(round(perc * 100))))

    if container is not None and disp_features:
        add_prop_bar(x, container, group_labels=groups, unique_labels=unique_groups)

    plt.xticks(x, unique_groups, rotation=55, ha="right")
    legend_kw["ncol"] = ncol
    legend = plt.legend(**legend_kw)
    plt.ylabel(ylabel)
    plt.tight_layout()
    return legend


def print_best_source(importances, groups, n=10, perc=None):
    """Print the most important features, their group, and per-group counts.

    Each printed line shows a feature's importance, the cumulated importance up
    to it (both in %), its group, and its index within the group. The index is
    relative to the group's first occurrence in ``groups``, so features of one
    group are assumed to be stored contiguously.

    Parameters
    ----------
    importances : np.ndarray
        1-D per-feature importances.
    groups : sequence
        Group label of every feature (same length as ``importances``).
    n : int or None
        Number of top features to print; capped at the number of features.
    perc : float or None
        Used only when ``n`` is None. Prints the top features up to and
        including the one at which the cumulated importance first exceeds
        ``perc``. If both ``n`` and ``perc`` are None, every feature is printed.
    """
    groups = np.asarray(groups)
    print("> best:")
    feature_cnt_offset = {grp: np.where(groups == grp)[0][0] for grp in np.unique(groups)}
    idx = np.argsort(importances)[::-1]
    sorted_imp = importances[idx]
    sorted_groups = groups[idx]
    cumsum = np.cumsum(sorted_imp)
    n_total = idx.shape[0]
    if n is None:
        if perc is None:
            n = n_total
        else:
            n = np.searchsorted(cumsum, [perc], side="right")[0] + 1
    # never index past the last feature
    n = min(int(n), n_total)
    for i in range(n):
        print("> {:2.2f} % - {:2.2f} % - from {} ({})".format(
            100 * sorted_imp[i],
            100 * cumsum[i],
            sorted_groups[i],
            idx[i] - feature_cnt_offset[sorted_groups[i]]
        ))
    print("Appearing in top {}:".format(n))
    uniqs, cnts = np.unique(sorted_groups[:n], return_counts=True)
    for uni, cnt in zip(uniqs, cnts):
        print("> {: >19} : {}".format(uni, cnt))


def _load_result(filepath):
    """Load one result file and return its "features" and "importances" entries.

    For an ``.npz`` archive, both arrays are read eagerly and the archive is
    closed, so no file handle stays open. Anything else returned by
    ``np.load`` is returned unchanged.
    """
    # allow_pickle: the "features" descriptors appear to be stored as object
    # arrays, which NumPy >= 1.16.3 refuses to load otherwise. Unpickling can
    # execute arbitrary code, so only use this on trusted result files.
    loaded = np.load(filepath, allow_pickle=True)
    if not hasattr(loaded, "close"):
        return loaded
    try:
        return {key: loaded[key] for key in ("features", "importances")}
    finally:
        loaded.close()


def _feature_group_names(features, along):
    """Return the display name of every feature's group along ``along``."""
    # a feature's raw group is the value of the first (key, value) pair of its
    # descriptor whose key equals `along`
    raw = np.array([next(filter(lambda t: t[0] == along, descriptor))[1] for descriptor, _ in features])
    name_map = LAYER_NAME_DICT if along == "layer" else network_name_dict
    return np.array([name_map[e] for e in raw])


def _ordered_groups(unique_groups, along, model):
    """Order group names for plotting: network layer order, or natural sort."""
    if along == "layer":
        idx = get_layer_order(unique_groups, model, convert=True)
        return np.array(unique_groups)[idx]
    return natsorted(unique_groups)


def main(argv):
    """Plot feature-importance summaries for one or more datasets.

    Loads one result file per dataset (named by ``mk_task_filename`` with the
    task name suffixed by ``"_et"``) and writes into ``<path>/<dest>``:

    * ``all_merged.png``: total importance per group, averaged over datasets;
    * ``<model>_all_merged_perc.png``: the same averages, stacked by
      cumulated-importance bucket;
    * ``<model>_<dataset>_<task>_total.png``: the stacked plot for each
      dataset.

    It also prints the most important features of the averaged importances.

    Parameters
    ----------
    argv : list of str
        Command-line arguments without the program name. Unknown arguments
        are ignored.
    """
    parser = ArgumentParser()

    # Model loading and saving
    parser.add_argument("--path", dest="path", required=True)
    parser.add_argument("--task", dest="task", required=True)
    parser.add_argument("--model", dest="model")
    parser.add_argument("--datasets", dest="datasets", required=True, help="comma-separated dataset names")
    parser.add_argument("--layer", dest="layer", default=None)
    parser.add_argument("--source", dest="source", default=None)
    parser.add_argument("--reduction", dest="reduction", default=None)
    parser.add_argument("--along", dest="along", required=True)
    parser.add_argument("--dest", dest="dest", default=".")
    parser.add_argument("--disp_features", dest="disp_features", action="store_true")
    parser.set_defaults(disp_features=False)
    params, _unknown = parser.parse_known_args(argv)
    params.datasets = params.datasets.split(",")
    print("Parameters: {}".format(params))

    dest = os.path.join(params.path, params.dest)
    if not os.path.exists(dest):
        os.makedirs(dest)

    filenames = [
        mk_task_filename(
            task=params.task + "_et",
            model=params.model,
            layer=params.layer,
            dataset=dataset,
            folder=None,
            reduction=params.reduction,
            source=params.source
        ) for dataset in params.datasets
    ]

    all_data = [_load_result(os.path.join(params.path, filename)) for filename in filenames]

    # Average over datasets. Groups are derived from the first dataset's
    # features, which assumes every dataset shares the same feature layout.
    # np.mean also requires equal lengths.
    features = np.array(all_data[0]["features"])
    groups = _feature_group_names(features, params.along)
    unique_groups = np.unique(groups)
    all_importances = np.mean([data["importances"] for data in all_data], axis=0)
    avg_aggr_imp = np.array([np.sum(all_importances[groups == label]) for label in unique_groups])
    # NOTE(review): avg_aggr_imp follows np.unique order of the display names,
    # while simple_bar_plot labels bars in its own order (layer order / natural
    # sort of the raw names). Verify the two orders agree.
    simple_bar_plot(avg_aggr_imp, features, along=params.along, dataset="avg", task=params.task, model=params.model)
    plt.savefig(os.path.join(dest, "all_merged.png"), dpi=600)
    plt.close()

    extra_artist = group_bar_plot(
        all_importances,
        groups,
        _ordered_groups(unique_groups, params.along, params.model),
        "Average relative importance",
        disp_features=params.disp_features,
        fig_kw={"figsize": (6.4, 3.2)},
        grid=True
    )
    plt.savefig(os.path.join(dest, "{}_all_merged_perc.png".format(params.model)),
                bbox_extra_artists=(extra_artist,), dpi=300)
    plt.close()

    print_best_source(all_importances, groups, n=None, perc=0.1)

    # one stacked plot per dataset
    for dataset, data in zip(params.datasets, all_data):
        importances = data["importances"]
        groups = _feature_group_names(data["features"], params.along)
        unique_groups_sorted = _ordered_groups(np.unique(groups), params.along, params.model)
        if params.along == "layer":
            plot_title = "{} - {} - {} - total imp.".format(dataset, params.task, params.model)
        else:
            plot_title = "{} - {} - total imp.".format(dataset, params.task)

        extra_artist = group_bar_plot(
            importances,
            groups,
            unique_groups_sorted,
            "Relative importance",
            disp_features=params.disp_features
        )
        # group_bar_plot only sets a y-label, so the title is added here and
        # the layout re-run so that it fits in the figure
        plt.title(plot_title)
        plt.tight_layout()
        plt.savefig(os.path.join(dest, "{}_{}_{}_total.png".format(params.model, dataset, params.task)),
                    bbox_extra_artists=(extra_artist,))
        plt.close()


if __name__ == "__main__":
    # script entry point: forward the command line (minus the program name)
    from sys import argv
    main(argv[1:])
