import os
import math
from warnings import warn

import numpy as np
from argparse import ArgumentParser
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt


def histogram(dest, x, title="", xlabel="", ylabel=""):
    """Save a density-normalised histogram of ``x`` to ``dest``.

    Args:
        dest: Path (or file-like object) handed to ``plt.savefig``.
        x: Array of values; ``x.shape[0]`` drives both the number of bins
            and the upper x-axis limit.
        title: Figure title.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
    """
    n_values = x.shape[0]
    # One bin per five values, but never fewer than one: ``hist`` rejects
    # ``bins=0``, which the plain division yields for fewer than five values.
    n_bins = max(1, n_values // 5)

    plt.figure()
    # ``normed`` was removed from ``hist`` in Matplotlib 3.1; ``density`` is
    # the replacement keyword with the same meaning.
    plt.hist(x, bins=n_bins, density=True)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    # NOTE(review): the x axis spans [0, number of values], which presumably
    # assumes rank-like data in that range - confirm against callers.
    plt.xlim((0, n_values))
    plt.savefig(dest)
    plt.close()


def bar(dest, x, width=1.0, xlabel="", ylabel="", title="bar", dpi=300, max_disp=500):
    """Save a bar chart of ``x`` to ``dest``, thinning out long inputs.

    When ``x`` holds more than ``max_disp`` values, only every ``step``-th
    value is drawn (``step = int(len / max_disp)``). Axis limits are always
    derived from the complete array, not from the drawn subset.

    Args:
        dest: Path (or file-like object) handed to ``plt.savefig``.
        x: Array of bar heights.
        width: Width of each bar.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        title: Figure title.
        dpi: Resolution used when saving.
        max_disp: Threshold above which the bars are subsampled.
    """
    n_features = x.shape[0]
    # A stride of 1 keeps every bar; larger strides subsample long inputs.
    stride = int(n_features / max_disp) if n_features > max_disp else 1
    left_edges = np.arange(1, n_features + 1)[::stride]  # bars start at x=1
    heights = x[::stride]

    plt.figure()

    # y range: floor slightly above zero, 5% headroom above the maximum
    lowest = np.min(x)
    highest = np.max(x)
    headroom = (highest - lowest) * 0.05
    plt.ylim((max(1e-6, lowest), highest + headroom))

    # x range: 5% margin on both sides of the full feature range
    margin = int(n_features * 0.05)
    plt.xlim((-margin, n_features + margin))

    plt.bar(left_edges, heights, width=width, align="edge")
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.savefig(dest, dpi=dpi)
    plt.close()


def subplot_scatters(ref, cmps, ref_label, cmps_labels, dest, title="", dpi=300, logscale=False, figsize=(600, 400)):
    """Save a grid of scatter plots comparing ``ref`` against each row of ``cmps``.

    Every subplot shows one compared array on the x axis and ``ref`` on the
    y axis, together with a dashed identity line spanning the plotted range.

    Args:
        ref: Reference array, plotted on the y axis of every subplot.
        cmps: Array whose first axis enumerates the arrays to compare
            against ``ref``.
        ref_label: y-axis label used on every subplot.
        cmps_labels: x-axis labels, one per entry of ``cmps``.
        dest: Path (or file-like object) handed to ``savefig``.
        title: Title of the whole figure.
        dpi: Resolution used when saving.
        logscale: If true, use log-log axes and keep only the points where
            both values are strictly positive.
        figsize: Accepted for backward compatibility; currently ignored.

    Raises:
        ValueError: If ``cmps`` contains no array to compare.
    """
    n_cmp = cmps.shape[0]
    if n_cmp == 0:
        raise ValueError("subplot_scatters needs at least one array to compare")

    grid_size = int(math.ceil(np.sqrt(n_cmp)))
    # Start from a square grid and drop the rows that would stay fully empty.
    n_rows = grid_size - (grid_size * grid_size - n_cmp) // grid_size
    fig, axes = plt.subplots(nrows=n_rows, ncols=grid_size, squeeze=False)
    fig.suptitle(title)

    for i, other in enumerate(cmps):
        # Integer division: a float row index raises under Python 3.
        ax = axes[i // grid_size][i % grid_size]
        ax.set_xlabel(cmps_labels[i])
        ax.set_ylabel(ref_label)

        if logscale:
            # Non-positive values cannot be shown on a log axis.
            mask = np.logical_and(ref > 0, other > 0)
            ax.set_xscale("log")
            ax.set_yscale("log")
        else:
            # Built-in ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
            mask = np.ones(ref.shape, dtype=bool)

        if not np.any(mask):
            # Nothing left to draw; keep the labelled but empty subplot.
            continue

        overall_max = max(np.max(ref[mask]), np.max(other[mask]))
        overall_min = min(np.min(ref[mask]), np.min(other[mask]))
        ax.scatter(other[mask], ref[mask], alpha=0.5, linewidths=1)
        identity = np.linspace(overall_min, overall_max, num=3)
        ax.plot(identity, identity, "k--")

    # Leave room at the top for the figure-level title.
    fig.tight_layout(w_pad=1, h_pad=1, rect=[0, 0.03, 1, 0.95])
    fig.savefig(dest, dpi=dpi)
    plt.close(fig)


def _load_arrays(folder, models, datasets, suffix, key):
    """Load array ``key`` from every existing ``<model>_image_net_<dataset>_<suffix>.npz``.

    Returns a dict mapping ``(model, dataset)`` to the loaded array; pairs
    whose file does not exist in ``folder`` are left out.
    """
    arrays = {}
    for model in models:
        for dataset in datasets:
            filename = "{}_image_net_{}_{}.npz".format(model, dataset, suffix)
            path = os.path.join(folder, filename)
            if not os.path.isfile(path):
                continue
            # Read the array eagerly so the archive handle is closed right away.
            with np.load(path) as archive:
                arrays[(model, dataset)] = archive[key]
    return arrays


def main(argv):
    """Print feature-importance statistics for the requested models and datasets.

    Two reports are written to standard output:

    1. For each model, the "subset size": features are ranked by importance
       within every dataset, and the largest value of the per-feature minimum
       rank across datasets is printed (with its share of the feature count).
    2. For each model, a table of the pairwise overlap between the
       ``support`` masks of the datasets, as a percentage of the number of
       selected entries of the column dataset.

    Args:
        argv: Command-line arguments without the program name. Expects
            ``--path_imp`` (folder of ``*_importances.npz`` files),
            ``--path_rer`` (folder of ``*_rfe_rerank.npz`` files),
            ``--models`` and ``--datasets`` (comma-separated lists).
            Unknown arguments are ignored.
    """
    parser = ArgumentParser()

    # All four options are needed below; make argparse report a missing one
    # instead of failing later on ``None``.
    parser.add_argument("--path_imp", dest="path_imp", required=True)
    parser.add_argument("--path_rer", dest="path_rer", required=True)
    parser.add_argument("--models", dest="models", required=True)
    parser.add_argument("--datasets", dest="datasets", required=True)
    params, _ = parser.parse_known_args(argv)
    params.models = params.models.split(",")
    params.datasets = params.datasets.split(",")
    print("Parameters: {}".format(params))

    print("> load data")
    importances_by_key = _load_arrays(
        params.path_imp, params.models, params.datasets, "importances", "importances"
    )

    # per-model subset size derived from the importance ranks
    for model in params.models:
        missing = [dataset for dataset in params.datasets if (model, dataset) not in importances_by_key]
        if missing:
            warn("{} are missing for network: {}".format(missing, model))
            continue
        # one column per dataset, one row per feature
        importances = np.array(
            [importances_by_key[(model, dataset)] for dataset in params.datasets]
        ).transpose()
        n_features, n_datasets = importances.shape
        # Built-in ``int``: the ``np.int`` alias was removed in NumPy 1.24.
        ranks = np.empty_like(importances, dtype=int)
        for col in range(n_datasets):
            order = np.argsort(importances[:, col])
            # invert the sort permutation: the feature at sorted position k gets rank k
            ranks[order, col] = np.arange(order.shape[0])
        size = np.max(np.min(ranks, axis=1))
        print("For '{}', subset size: {} ({:3.2f}%)".format(model, size, 100 * size / float(n_features)))

    print("> load data (cross dataset)")
    supports = _load_arrays(
        params.path_rer, params.models, params.datasets, "rfe_rerank", "support"
    )

    sorted_dataset = ['cells_no_aug', 'patterns_no_aug', 'glomeruli_no_aug', 'ulg_lbtd2_chimio_necrose', 'ulg_breast', 'ulg_lbtd_lba_new', 'ulg_lbtd_tissus', 'ulb_anapath_lba']
    sorted_names = ["C", "P", "G", "N", "B", "M", "L", "H"]
    # Keep the canonical display order, restricted to the requested datasets
    # (nothing was loaded for the others).
    requested = set(params.datasets)
    selected = [(dataset, name) for dataset, name in zip(sorted_dataset, sorted_names) if dataset in requested]
    if not selected:
        warn("none of the requested datasets is known: cross-dataset table skipped")
        return
    datasets, names = zip(*selected)

    for model in params.models:
        missing = [dataset for dataset in datasets if (model, dataset) not in supports]
        if missing:
            warn("{} are missing for network: {} (cross dataset)".format(missing, model))
            continue
        print("{}\t".format(model) + "\t".join(names))
        for _from in datasets:
            print(end="\t")
            for _to in datasets:
                if _from == _to:
                    # the diagonal is left blank
                    print(end=" \t")
                    continue
                from_support = supports[(model, _from)]
                to_support = supports[(model, _to)]
                count = np.count_nonzero(np.logical_and(from_support, to_support))
                total = np.count_nonzero(to_support)
                # an empty column mask has no defined overlap ratio
                perc = float(count) / total if total else float("nan")
                print(end="{:3.1f}\t".format(perc * 100))
            print()
        print()

# Script entry point: when run directly, forward the command-line arguments
# (without the program name) to main().
if __name__ == "__main__":
    import sys
    main(sys.argv[1:])
