import os
import math
import numpy as np
from argparse import ArgumentParser
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt


def histogram(dest, x, title="", xlabel="", ylabel=""):
    """Save a density-normalised histogram of ``x`` to ``dest``.

    One bin is used per five samples (with a minimum of one bin) and the
    x axis is limited to ``[0, x.shape[0]]``.

    Args:
        dest: Output path handed to ``plt.savefig``.
        x: Numpy array of values; ``x.shape[0]`` is taken as the sample count.
        title: Figure title.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
    """
    n_samples = x.shape[0]
    plt.figure()
    # `normed` was removed from matplotlib's hist(); `density` is its
    # replacement. The bin count is clamped because fewer than five samples
    # would otherwise request zero bins.
    plt.hist(x, bins=max(1, n_samples // 5), density=True)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.xlim((0, n_samples))
    plt.savefig(dest)
    plt.close()


def bar(dest, x, width=1.0, xlabel="", ylabel="", title="bar", dpi=300, max_disp=500):
    """Save a bar plot of ``x`` to ``dest``.

    Bars are placed at positions ``1..n`` (left-aligned). When ``x`` holds
    more than ``max_disp`` values, only every ``n // max_disp``-th value is
    drawn. An empty ``x`` produces a figure without bars instead of failing.

    Args:
        dest: Output path handed to ``plt.savefig``.
        x: Numpy array of bar heights; ``x.shape[0]`` is the number of bars.
        width: Width of each bar.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        title: Figure title.
        dpi: Resolution passed to ``plt.savefig``.
        max_disp: Number of values above which the bars are subsampled.
    """
    n_features = x.shape[0]
    position = np.arange(n_features) + 1
    data = x
    if n_features > max_disp:
        step = n_features // max_disp
        data = x[::step]
        position = position[::step]

    plt.figure()
    if n_features > 0:
        # positioning and limits (computed on the full data, not the subsample)
        vmin, vmax = np.min(x), np.max(x)
        upper = vmax + (vmax - vmin) * 0.05
        lower = max(1e-6, vmin)
        if lower > upper:
            # The 1e-6 floor would invert the axis when all values are
            # smaller than it (e.g. all zeros); use the real minimum instead.
            lower = vmin
        plt.ylim((lower, upper))

        padding = int(n_features * 0.05)
        # Bars are edge-aligned, so the last one spans
        # [position, position + width]; make sure it is not clipped when the
        # padding rounds down to zero for small inputs.
        right = max(n_features + padding, position[-1] + width)
        plt.xlim((-padding, right))

        # actually plot
        plt.bar(position, data, width=width, align="edge")

    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.savefig(dest, dpi=dpi)
    plt.close()


def subplot_scatters(ref, cmps, ref_label, cmps_labels, dest, title="", dpi=300, logscale=False, figsize=(600, 400)):
    """Save a grid of scatter plots comparing ``ref`` with each row of ``cmps``.

    Each subplot shows one row of ``cmps`` on the x axis against ``ref`` on
    the y axis, together with a dashed identity line spanning the range of
    the plotted values. Subplots are laid out on a near-square grid.

    Args:
        ref: Numpy array of reference values (y axis of every subplot).
        cmps: Numpy array whose rows are the arrays compared against ``ref``.
        ref_label: Y-axis label used for every subplot.
        cmps_labels: Sequence of x-axis labels, one per row of ``cmps``.
        dest: Output path handed to ``plt.savefig``.
        title: Figure-level title.
        dpi: Resolution passed to ``plt.savefig``.
        logscale: If true, both axes are logarithmic and only points where
            both values are strictly positive are plotted.
        figsize: Accepted for interface compatibility; not used by the
            current implementation.
    """
    n_cmp = cmps.shape[0]
    grid_size = int(math.ceil(np.sqrt(n_cmp)))
    # Drop the rows of the square grid that would be left completely empty.
    n_rows = grid_size - (grid_size * grid_size - n_cmp) // grid_size
    fig, axes = plt.subplots(nrows=n_rows, ncols=grid_size, squeeze=False)
    fig.suptitle(title)

    for i, other in enumerate(cmps):
        # Floor division: a float row index is rejected on Python 3.
        ax = axes[i // grid_size][i % grid_size]
        ax.set_xlabel(cmps_labels[i])
        ax.set_ylabel(ref_label)

        if logscale:
            # Non-positive values cannot be drawn on a log axis.
            mask = np.logical_and(ref > 0, other > 0)
            ax.set_xscale("log")
            ax.set_yscale("log")
        else:
            # Builtin bool: the `np.bool` alias no longer exists in NumPy.
            mask = np.ones(ref.shape, dtype=bool)

        if not mask.any():
            # Nothing to plot; min/max of an empty selection would raise.
            continue

        ref_sel, other_sel = ref[mask], other[mask]
        overall_max = max(np.max(ref_sel), np.max(other_sel))
        overall_min = min(np.min(ref_sel), np.min(other_sel))
        ax.scatter(other_sel, ref_sel, alpha=0.5, linewidths=1)
        identity = np.linspace(overall_min, overall_max, num=3)
        ax.plot(identity, identity, "k--")

    plt.tight_layout(w_pad=1, h_pad=1, rect=[0, 0.03, 1, 0.95])
    plt.savefig(dest, dpi=dpi)
    plt.close()


def main(argv):
    """Plot feature-importance bar charts for every available model/dataset pair.

    For each ``<model>_<dataset>_merged_model_et.npz`` file found under
    ``--path``, the ``importances`` array is read and two bar plots are
    written to ``<dest>/bar_plots``: all importances in decreasing order, and
    the top 10% of them (at least one feature). Pairs whose file does not
    exist are skipped.

    Args:
        argv: Command-line arguments without the program name. ``--path``,
            ``--models`` and ``--datasets`` are required (the last two are
            comma-separated lists); ``--dest`` defaults to the current
            directory. Unknown options are ignored.
    """
    parser = ArgumentParser()

    # Model loading and saving
    # These three options are required: without them the code below would
    # fail on `None` with an unhelpful AttributeError/TypeError.
    parser.add_argument("--path", dest="path", required=True)
    parser.add_argument("--models", dest="models", required=True)
    parser.add_argument("--datasets", dest="datasets", required=True)
    parser.add_argument("--dest", dest="dest", default=".")
    params, _ = parser.parse_known_args(argv)
    params.models = params.models.split(",")
    params.datasets = params.datasets.split(",")
    print("Parameters: {}".format(params))

    histogram_dest = os.path.join(params.dest, "histograms")
    bar_plots_dest = os.path.join(params.dest, "bar_plots")
    scatter_dest = os.path.join(params.dest, "scatter_dest")

    for path in [histogram_dest, bar_plots_dest, scatter_dest]:
        if not os.path.exists(path):
            os.makedirs(path)

    print("> load data")
    # Maps (model, dataset) to the loaded archive; pairs without a file are absent.
    data = {}
    for model in params.models:
        for dataset in params.datasets:
            filename = "{}_{}_merged_model_et.npz".format(model, dataset)
            filepath = os.path.join(params.path, filename)
            if os.path.isfile(filepath):
                data[(model, dataset)] = np.load(filepath)

    # histograms and bars
    for model in params.models:
        for dataset in params.datasets:
            if (model, dataset) not in data:
                continue
            print("> for '{}_{}'".format(model, dataset))
            importances = data[(model, dataset)]["importances"]
            n_features = importances.shape[0]
            if n_features == 0:
                # Nothing to plot; bar() cannot compute limits for empty data.
                print("> no importances for '{}_{}', skipping".format(model, dataset))
                continue
            # histogram(dest=os.path.join(histogram_dest, "{}_{}.png".format(model, dataset)), x=importances)

            sorted_importances = importances[np.argsort(-importances)]  # sort by decreasing order
            bar(
                os.path.join(bar_plots_dest, "{}_{}_all.png".format(model, dataset)),
                x=sorted_importances,
                ylabel="importances",
                title="{} ({} feat.), {}".format(model, n_features, dataset[:50])
            )
            # Keep at least one feature: 10% rounds down to zero for fewer
            # than ten features, which would hand bar() an empty array.
            n_best = max(1, int(n_features * 0.1))
            bar(
                os.path.join(bar_plots_dest, "best_{}_{}.png".format(model, dataset)),
                x=sorted_importances[:n_best],
                ylabel="importances",
                title="best, {} ({} feat.), {}".format(model, n_features, dataset[:50])
            )

    # Disabled: per-model scatter plots of the first dataset against the
    # others, and histograms of per-feature max/mean/min ranks across datasets.
    # ref = params.datasets[0]
    # for model in params.models:
    #     print("> scatters for {}".format(model))
    #     ref_imp = data[(model, ref)]["importances"]
    #     cmp_imps = np.array([data[(model, dataset)]["importances"] for dataset in params.datasets[1:]])
    #     subplot_scatters(
    #         ref_imp, cmp_imps,
    #         ref_label=ref,
    #         cmps_labels=params.datasets[1:],
    #         dest=os.path.join(scatter_dest, "{}_{}_vs_rest_log.png".format(model, ref)),
    #         title="{} ({} feat.), {} vs rest, log".format(model, len(ref_imp), ref),
    #         logscale=True
    #     )
    #
    #     subplot_scatters(
    #         ref_imp, cmp_imps,
    #         ref_label=ref,
    #         cmps_labels=params.datasets[1:],
    #         dest=os.path.join(scatter_dest, "{}_{}_vs_rest.png".format(model, ref)),
    #         title="{} ({} feat.), {} vs rest, log".format(model, len(ref_imp), ref),
    #         logscale=False
    #     )
    #
    #     print("> rank histograms for {}".format(model))
    #     ranks = np.array([np.argsort(-data[(model, d)]["importances"]) for d in params.datasets])
    #     max_rank = np.max(ranks, axis=0)
    #     mean_rank = np.mean(ranks, axis=0)
    #     min_rank = np.min(ranks, axis=0)
    #     n_features = max_rank.shape[0]
    #
    #     histogram(
    #         os.path.join(histogram_dest, "{}_{}.png".format(model, "max_rank")),
    #         max_rank,
    #         title="Max rank ({}, n_features:{})".format(model, n_features)
    #     )
    #     histogram(
    #         os.path.join(histogram_dest, "{}_{}.png".format(model, "min_rank")),
    #         min_rank,
    #         title="Min rank ({}, n_features:{})".format(model, n_features)
    #     )
    #     histogram(
    #         os.path.join(histogram_dest, "{}_{}.png".format(model, "mean_rank")),
    #         mean_rank,
    #         title="Mean rank ({}, n_features:{})".format(model, n_features)
    #     )
    #

if __name__ == "__main__":
    # Run as a script: pass the command-line arguments, minus the program
    # name, on to main().
    from sys import argv as cli_args
    main(cli_args[1:])
