import os
import math
import numpy as np
from argparse import ArgumentParser
import matplotlib
import re

from natsort import natsorted

matplotlib.use('Agg')
from matplotlib import pyplot as plt

# Short display names for the layers whose results are plotted, keyed by the
# full layer name. The deepest layer of each network maps to "last".
# Entries are inserted in the same order as the groups below.
LAYER_NAME_DICT = {}

# activation_<k> layers (presumably resnet50, whose layers are natsorted in
# get_layer_order): every third activation gets an "acti_<k>" short name.
LAYER_NAME_DICT.update(
    ("activation_{}".format(k), "acti_{}".format(k)) for k in range(1, 47, 3)
)
LAYER_NAME_DICT["activation_49"] = "last"

# dense_net_201 layers (same names as its ordering in get_layer_order).
LAYER_NAME_DICT.update([
    ("pool1", "pool1"),
    ("conv2_block6_concat", "conv2_block6"),
    ("pool2_pool", "pool2_pool"),
    ("conv3_block12_concat", "conv3_block12"),
    ("pool3_pool", "pool3_pool"),
    ("conv4_block48_concat", "conv4_block48"),
    ("pool4_pool", "pool4_pool"),
    ("conv5_block32_concat", "conv5_block32"),
    ("bn", "last"),
])

# inception_resnet_v2 layers; "max_pooling2d_2" also appears in the
# inception_v3 ordering. All but the first and last keep their own name.
LAYER_NAME_DICT["max_pooling2d_2"] = "max_pool2d_2"
LAYER_NAME_DICT.update(
    (name, name) for name in (
        "mixed_5b",
        "block35_1_ac", "block35_4_ac", "block35_7_ac", "block35_10_ac",
        "mixed_6a",
        "block17_5_ac", "block17_10_ac", "block17_15_ac", "block17_20_ac",
        "block8_3_ac", "block8_6_ac", "block8_9_ac",
        "mixed_7a",
    )
)
LAYER_NAME_DICT["conv_7b_ac"] = "last"

# inception_v3 layers: mixed0..mixed9 keep their name, mixed10 is the last.
LAYER_NAME_DICT.update(("mixed{}".format(k), "mixed{}".format(k)) for k in range(10))
LAYER_NAME_DICT["mixed10"] = "last"


def max_n_char(str, n=15):
    """Abbreviate ``str`` to at most ``n`` characters for display.

    Strings of length ``<= n`` are returned unchanged. Longer strings keep
    their beginning and end joined by ``"..."``. When the remaining
    character budget is odd, the head gets the extra character.

    Args:
        str: text to abbreviate. The name shadows the builtin but is kept so
            keyword callers keep working.
        n: maximum length of the result.

    Returns:
        The possibly abbreviated string, never longer than ``max(n, 0)``.
    """
    text = str  # local alias so the body doesn't read like the builtin
    ellipsis = "..."
    if len(text) <= n:
        return text
    budget = n - len(ellipsis)
    if budget <= 0:
        # No room for an ellipsis plus context: hard truncate instead.
        return text[:max(n, 0)]
    head = int(math.ceil(budget / 2))
    tail = budget - head
    # text[-0:] would be the whole string, so an empty tail is special-cased.
    return text[:head] + ellipsis + (text[-tail:] if tail else "")


def get_layer_order(layers, model, convert=False):
    """Return indices that arrange ``layers`` from shallowest to deepest.

    Args:
        layers: sequence of layer names (e.g. parsed from result file names).
        model: one of "resnet50", "inception_resnet_v2", "dense_net_201",
            "inception_v3".
        convert: if True, the built-in layer ordering is first mapped through
            LAYER_NAME_DICT before being matched against ``layers``. It is
            ignored for resnet50.

    Returns:
        Integer numpy array ``idx`` such that ``[layers[i] for i in idx]`` is in
        network order.
        - For resnet50, every layer is kept and ordered by a natural
          (numeric-aware) sort of its name.
        - For the other models, the built-in ordering is followed. Ordered
          layers absent from ``layers`` are skipped, and names outside that
          ordering are not included.

    Raises:
        ValueError: if ``model`` is not supported.
    """
    if model == "resnet50":
        # dtype=int keeps the result usable as an index array even when empty.
        return np.array(natsorted(range(len(layers)), key=lambda i: layers[i]), dtype=int)

    if model == "inception_resnet_v2":
        ordered = [
            "max_pooling2d_2", "mixed_5b", "block35_1_ac", "block35_4_ac", "block35_7_ac", "block35_10_ac", "mixed_6a",
            "block17_5_ac", "block17_10_ac", "block17_15_ac", "block17_20_ac", "mixed_7a", "block8_3_ac", "block8_6_ac",
            "block8_9_ac", "conv_7b_ac"
        ]
    elif model == "dense_net_201":
        ordered = [
            "pool1", "conv2_block6_concat", "pool2_pool", "conv3_block12_concat", "pool3_pool", "conv4_block48_concat",
            "pool4_pool", "conv5_block32_concat", "bn"
        ]
    elif model == "inception_v3":
        ordered = [
            "max_pooling2d_2", "mixed0", "mixed1", "mixed2", "mixed3", "mixed4", "mixed5", "mixed6", "mixed7", "mixed8",
            "mixed9", "mixed10"
        ]
    else:
        raise ValueError("No such model '{}'".format(model))

    if convert:
        ordered = [LAYER_NAME_DICT[name] for name in ordered]

    # Position of each name in `layers` (the last occurrence wins on duplicates).
    position = {name: i for i, name in enumerate(layers)}
    # Walk the canonical order so every emitted index is valid. Missing layers
    # are skipped rather than raising or leaving placeholder entries.
    return np.array([position[name] for name in ordered if name in position], dtype=int)


def histogram(dest, x, title="", xlabel="", ylabel=""):
    """Save a density-normalized histogram of ``x`` to the image file ``dest``.

    Args:
        dest: output image path passed to ``plt.savefig``.
        x: 1-D numpy array of values. One bin is used per five values (at
            least one bin). The x-axis spans ``[0, len(x)]``, which suits
            rank-like values.
        title: plot title.
        xlabel: x-axis label.
        ylabel: y-axis label.
    """
    n_values = x.shape[0]
    plt.figure()
    # `normed` was removed in Matplotlib 3.1; `density` is its replacement.
    # Clamp to one bin: fewer than five values would otherwise request zero bins.
    plt.hist(x, bins=max(1, n_values // 5), density=True)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.xlim((0, n_values))
    plt.savefig(dest)
    plt.close()


def bar(dest, x, width=1.0, xlabel="", ylabel="", title="bar", dpi=300, max_disp=500):
    """Save a bar chart of the 1-D array ``x`` to ``dest``.

    Bars are placed at positions ``1..len(x)``. When ``x`` has more than
    ``max_disp`` entries, only every ``step``-th entry is drawn. ``step`` is
    chosen so that at most ``max_disp`` bars are drawn. The y-limits are
    computed from all of ``x`` (bottom clamped to at least 1e-6, 5% headroom
    on top). The x-limits add 5% padding on both sides.

    Args:
        dest: output image path passed to ``plt.savefig``.
        x: 1-D numpy array of bar heights.
        width: bar width.
        xlabel: x-axis label.
        ylabel: y-axis label.
        title: plot title.
        dpi: resolution of the saved image.
        max_disp: maximum number of bars to draw. Values <= 0 disable
            subsampling.
    """
    n_features = x.shape[0]
    position = np.arange(n_features) + 1
    if 0 < max_disp < n_features:
        # Round the stride up: flooring it let up to ~2 * max_disp bars through.
        step = int(math.ceil(n_features / max_disp))
        data = x[::step]
        position = position[::step]
    else:
        data = x

    # positioning and limits (computed on the full data, not the subsample)
    plt.figure()
    vmin, vmax = np.min(x), np.max(x)
    plt.ylim((max(1e-6, vmin), vmax + (vmax - vmin) * 0.05))
    padding = int(n_features * 0.05)
    plt.xlim((-padding, n_features + padding))

    # actually plot
    plt.bar(position, data, width=width, align="edge")
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.savefig(dest, dpi=dpi)
    plt.close()


def subplot_scatters(ref, cmps, ref_label, cmps_labels, dest, title="", dpi=300, logscale=False, figsize=(600, 400)):
    """Save a grid of scatter plots comparing ``ref`` against each row of ``cmps``.

    Each subplot plots ``cmps[i]`` (x-axis) against ``ref`` (y-axis) with a
    dashed identity line spanning the shown values. Subplots are laid out in
    a near-square grid; grid rows that would be entirely empty are dropped,
    and leftover cells in the last row are hidden.

    Args:
        ref: 1-D numpy array of reference values.
        cmps: 2-D numpy array; each row is compared against ``ref``.
        ref_label: y-axis label used for every subplot.
        cmps_labels: x-axis label for each row of ``cmps`` (indexed like ``cmps``).
        dest: output image path passed to ``plt.savefig``.
        title: figure-level title.
        dpi: resolution of the saved image.
        logscale: if True, use log axes and keep only points where both
            values are strictly positive.
        figsize: accepted for backward compatibility but currently unused.

    Raises:
        ValueError: if ``cmps`` has no rows.
    """
    n_cmp = cmps.shape[0]
    if n_cmp == 0:
        raise ValueError("subplot_scatters needs at least one comparison array")
    grid_size = int(math.ceil(np.sqrt(n_cmp)))
    # Drop rows of a grid_size x grid_size layout that would be completely empty.
    n_rows = grid_size - (grid_size * grid_size - n_cmp) // grid_size
    fig, axes = plt.subplots(nrows=n_rows, ncols=grid_size, squeeze=False)
    fig.suptitle(title)

    for i, cmp in enumerate(cmps):
        # Integer division: a float row index raises TypeError on Python 3.
        ax = axes[i // grid_size][i % grid_size]
        # `np.bool` was removed in NumPy 1.24; the builtin bool is equivalent.
        mask = np.ones(ref.shape, dtype=bool)
        if logscale:
            mask = np.logical_and(ref > 0, cmp > 0)
            ax.set_xscale("log")
            ax.set_yscale("log")
        ax.set_xlabel(cmps_labels[i])
        ax.set_ylabel(ref_label)
        if not mask.any():
            # Nothing to draw (e.g. no strictly positive pairs on log axes);
            # min/max of an empty selection would raise.
            continue
        xs, ys = cmp[mask], ref[mask]
        overall_max = max(np.max(ys), np.max(xs))
        overall_min = min(np.min(ys), np.min(xs))
        ax.scatter(xs, ys, alpha=0.5, linewidths=1)
        identity = np.linspace(overall_min, overall_max, num=3)
        ax.plot(identity, identity, "k--")

    # Hide the grid cells that received no comparison.
    for unused_ax in axes.flat[n_cmp:]:
        unused_ax.set_visible(False)

    plt.tight_layout(w_pad=1, h_pad=1, rect=[0, 0.03, 1, 0.95])
    plt.savefig(dest, dpi=dpi)
    plt.close()


def main(argv):
    """Report and plot per-layer SVM scores for each (model, dataset) pair.

    For every model/dataset combination, the files in ``--path`` named
    ``<model>_<layer>_image_net_<dataset>_svm.npz`` are loaded.
    - Each file's score is its ``roc_auc`` entry, or ``accuracy`` when
      ``roc_auc`` is absent.
    - Layers are arranged with get_layer_order.
    - The available metrics of the best-scoring layer are printed
      tab-separated.
    - A score-per-layer line plot is saved to
      ``<dest>/plots/<model>_<dataset>.png``.

    Args:
        argv: command-line arguments without the program name. Unknown
            options are ignored.
    """
    parser = ArgumentParser()

    # Inputs and output location. --path, --models and --datasets have no
    # usable default, so argparse reports them as missing instead of the
    # script failing later on None.
    parser.add_argument("--path", dest="path", required=True)
    parser.add_argument("--models", dest="models", required=True)
    parser.add_argument("--datasets", dest="datasets", required=True)
    parser.add_argument("--dest", dest="dest", default=".")
    params, _unknown = parser.parse_known_args(argv)
    params.models = params.models.split(",")
    params.datasets = params.datasets.split(",")
    print("Parameters: {}".format(params))

    dest = os.path.join(params.dest, "plots")
    os.makedirs(dest, exist_ok=True)

    all_files = [
        name for name in os.listdir(params.path)
        if os.path.isfile(os.path.join(params.path, name))
    ]
    metric_keys = ["cv_best_score", "accuracy", "precision", "recall", "f1-score", "roc_auc"]
    for model in params.models:
        for dataset in params.datasets:
            # Escape the user-supplied names (and the '.' of the extension) so
            # they match literally rather than as regex syntax.
            regex = re.compile(r"^{}_([_a-z0-9]*)_image_net_{}_svm\.npz$".format(
                re.escape(model), re.escape(dataset)))
            matches = [m for m in (regex.match(name) for name in all_files) if m is not None]
            if not matches:
                print("{} {} no matching result files, skipping".format(model, dataset))
                continue
            files = [m.group(0) for m in matches]
            layers = [m.group(1) for m in matches]

            data = []
            for filename in files:
                # Copy the arrays out so each .npz file handle is closed right away.
                with np.load(os.path.join(params.path, filename)) as npz:
                    data.append({key: npz[key] for key in npz.files})

            idx = get_layer_order(layers, model)
            if idx.size == 0:
                print("{} {} no known layers among {}, skipping".format(model, dataset, layers))
                continue
            x = np.array(layers)
            y = np.array([d["roc_auc"] if "roc_auc" in d else d["accuracy"] for d in data])
            ordered_y = y[idx]
            # `best` is a position along the depth-ordered axis; idx maps it
            # back to the file index used by `data` and `x`.
            best = int(np.argmax(ordered_y))
            best_data = data[idx[best]]
            scores = "\t".join(
                "{:1.4f}".format(float(best_data[metric])) for metric in metric_keys if metric in best_data
            )

            print("{} {} {}".format(model, dataset, scores))
            plt.figure()
            plt.title("{} - {} (best: {} with {:1.4f})".format(
                model, max_n_char(dataset, n=10), x[idx[best]], float(ordered_y[best])))
            label_pos = np.arange(idx.shape[0])
            plt.plot(label_pos, ordered_y, "x-")
            plt.xticks(label_pos, [max_n_char(v) for v in x[idx]], rotation=55, ha="right")
            plt.axvline(x=best, ymin=0, ymax=1, color="k", linestyle="--", alpha=0.5)
            plt.ylim((0, 1))
            plt.xlabel("layer")
            plt.ylabel("roc_auc" if "roc_auc" in data[0] else "accuracy")
            plt.tight_layout()
            plt.savefig(os.path.join(dest, "{}_{}.png".format(model, dataset)))
            plt.close()

    #
    # print("> load data")
    # model_dataset = {(model, dataset): "{}_{}".format(model, dataset) for model in params.models for dataset in params.datasets}
    # filenames = {k: "{}_importances.npz".format(name) for k, name in model_dataset.items()}
    # data = {k: np.load(os.path.join(params.path, filename)) for k, filename in filenames.items()}

    # # histograms and bars
    # for model in params.models:
    #     for dataset in params.datasets:
    #         print("> for '{}_{}'".format(model, dataset))
    #         importances = data[(model, dataset)]["importances"]
    #         n_features = importances.shape[0]
    #         # histogram(dest=os.path.join(histogram_dest, "{}_{}.png".format(model, dataset)), x=importances)
    #
    #         idx = np.argsort(-importances)  # sort by decreasing order
    #         bar(
    #             os.path.join(bar_plots_dest, "{}_{}_all.png".format(model, dataset)),
    #             x=importances[idx],
    #             ylabel="importances",
    #             title="{} ({} feat.), {}".format(model, n_features, dataset[:50])
    #         )
    #         bar(
    #             os.path.join(bar_plots_dest, "best_{}_{}.png".format(model, dataset)),
    #             x=importances[idx][:int(len(importances) * 0.1)],
    #             ylabel="importances",
    #             title="best, {} ({} feat.), {}".format(model, n_features, dataset[:50])
    #         )

    # ref = params.datasets[0]
    # for model in params.models:
    #     print("> scatters for {}".format(model))
    #     ref_imp = data[(model, ref)]["importances"]
    #     cmp_imps = np.array([data[(model, dataset)]["importances"] for dataset in params.datasets[1:]])
    #     subplot_scatters(
    #         ref_imp, cmp_imps,
    #         ref_label=ref,
    #         cmps_labels=params.datasets[1:],
    #         dest=os.path.join(scatter_dest, "{}_{}_vs_rest_log.png".format(model, ref)),
    #         title="{} ({} feat.), {} vs rest, log".format(model, len(ref_imp), ref),
    #         logscale=True
    #     )
    #
    #     subplot_scatters(
    #         ref_imp, cmp_imps,
    #         ref_label=ref,
    #         cmps_labels=params.datasets[1:],
    #         dest=os.path.join(scatter_dest, "{}_{}_vs_rest.png".format(model, ref)),
    #         title="{} ({} feat.), {} vs rest, log".format(model, len(ref_imp), ref),
    #         logscale=False
    #     )
    #
    #     print("> rank histograms for {}".format(model))
    #     ranks = np.array([np.argsort(-data[(model, d)]["importances"]) for d in params.datasets])
    #     max_rank = np.max(ranks, axis=0)
    #     mean_rank = np.mean(ranks, axis=0)
    #     min_rank = np.min(ranks, axis=0)
    #     n_features = max_rank.shape[0]
    #
    #     histogram(
    #         os.path.join(histogram_dest, "{}_{}.png".format(model, "max_rank")),
    #         max_rank,
    #         title="Max rank ({}, n_features:{})".format(model, n_features)
    #     )
    #     histogram(
    #         os.path.join(histogram_dest, "{}_{}.png".format(model, "min_rank")),
    #         min_rank,
    #         title="Min rank ({}, n_features:{})".format(model, n_features)
    #     )
    #     histogram(
    #         os.path.join(histogram_dest, "{}_{}.png".format(model, "mean_rank")),
    #         mean_rank,
    #         title="Mean rank ({}, n_features:{})".format(model, n_features)
    #     )
    #

if __name__ == "__main__":
    # Drop argv[0] (the script path); main() parses only the options.
    from sys import argv
    main(argv[1:])
