from argparse import ArgumentParser
from matplotlib import pyplot as plt
import numpy as np
import os


def select_by_percentage(data, drops=None, std=False):
    """Locate, for each tolerated score drop, the first score close enough to the maximum.

    For every value ``drop`` in ``drops`` this finds the lowest index ``i`` such
    that ``data[i] > max(data) - drop``.

    Args:
        data: One-dimensional sequence (or array) of scores.
        drops: Iterable of tolerated drops below the maximum score. Defaults to
            ``[0.005, 0.01, 0.02]``. The passed object is never modified.
        std: When true, the standard deviation of ``data`` is used as one
            additional drop, placed after the values of ``drops``.

    Returns:
        A tuple ``(indexes, drops)`` of two arrays of equal length:
        ``indexes[k]`` is the first index whose score exceeds
        ``max(data) - drops[k]``.

    Raises:
        IndexError: If no score is strictly greater than ``max(data) - drop``
            for some drop, which happens when that drop is not positive.
    """
    data = np.asarray(data)
    # Work on a copy: appending the standard deviation must not leak into the
    # list owned by the caller.
    drops = [0.005, 0.01, 0.02] if drops is None else list(drops)
    if std:
        drops.append(np.std(data))
    max_score = np.max(data)
    indexes = [np.where(data > max_score - drop)[0][0] for drop in drops]
    return np.array(indexes), np.array(drops)


def _draw_rfe_curve(x, scores, best, selected):
    """Draw the score curve plus the "best" and "selected" markers on the current axes."""
    plt.plot(x, scores, linewidth=0.95)
    plt.axvline(x[best], linewidth=0.5, color="k", linestyle="--", alpha=0.65, label="best")
    plt.axvline(x[selected], linewidth=0.5, color="r", linestyle="--", alpha=0.80, label="selected")


def plot_rfe(scores, n_features, step, dest, metric, loss=0.01, title="", **save_params):
    """Plot a score curve twice (zoomed and on a fixed [0, 1] scale) and save it to ``dest``.

    The x value associated with the last score is ``n_features``; each earlier
    score is placed ``step`` further to the left. Two vertical markers are
    drawn: "best" at the highest score and "selected" at the first position
    whose score is strictly greater than ``best score - loss``.

    Args:
        scores: One-dimensional array of scores.
        n_features: x value of the last score; also the denominator of the
            percentage shown in the title.
        step: Spacing on the x axis between two consecutive scores.
        dest: Destination forwarded to ``plt.savefig``.
        metric: Label of the y axis.
        loss: Tolerated score loss relative to the best score.
        title: Prefix of the figure title.
        **save_params: Extra keyword arguments forwarded to ``plt.savefig``.
    """
    # prepare data
    scores = np.asarray(scores)
    n_evaluated = scores.shape[0]
    x = n_features - np.arange(n_evaluated)[::-1] * step
    best = np.argmax(scores)
    lower_threshold = scores[best] - loss
    within_loss = np.where(scores > lower_threshold)[0]
    # With a non-positive loss no score can beat the threshold; fall back to
    # the best score instead of failing with an IndexError.
    selected = within_loss[0] if within_loss.size else best

    plt.figure()
    try:
        # lower plot: full [0, 1] score range
        plt.subplot(212)
        _draw_rfe_curve(x, scores, best, selected)
        plt.ylim((0, 1))
        plt.ylabel(metric)
        plt.xlabel("n_features")

        # upper plot: automatic y range, with threshold line, title and legend
        plt.subplot(211)
        plt.title("{} (loss:{} sel:{} ({:3.2f}%))".format(title, loss, x[selected], 100 * x[selected] / n_features))
        _draw_rfe_curve(x, scores, best, selected)
        plt.axhline(lower_threshold, linewidth=0.75, color="k", linestyle=":", label="threshold")
        plt.ylabel(metric)
        plt.legend()

        # save
        plt.savefig(dest, **save_params)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close()

def main(argv):
    """Report and plot the stored results of every (model, dataset) pair.

    For each pair the file ``<path>/<model>_image_net_<dataset>_rfe.npz`` is
    read when it exists (missing files are skipped silently). The stored
    metrics and the selection returned by ``select_by_percentage`` are printed,
    and a figure is written to ``<dest>/<model>_<dataset>_resel.png``.

    Args:
        argv: Command-line arguments without the program name. Recognised
            options are ``--path``, ``--models`` and ``--datasets`` (required;
            the latter two are comma-separated lists), ``--dest`` (default
            ``"."``), ``--step`` (default 4) and ``--loss`` (default 0.005).
            Unrecognised arguments are ignored.
    """
    parser = ArgumentParser()

    # Model loading and saving
    # The three options below are dereferenced unconditionally, so a missing
    # one is reported by argparse instead of surfacing later as an obscure
    # AttributeError/TypeError.
    parser.add_argument("--path", dest="path", required=True)
    parser.add_argument("--models", dest="models", required=True)
    parser.add_argument("--datasets", dest="datasets", required=True)
    parser.add_argument("--dest", dest="dest", default=".")
    parser.add_argument("--step", dest="step", type=int, default=4)
    parser.add_argument("--loss", dest="loss", type=float, default=0.005)
    params, _ = parser.parse_known_args(argv)
    params.models = params.models.split(",")
    params.datasets = params.datasets.split(",")
    print("Parameters: {}".format(params))

    os.makedirs(params.dest, exist_ok=True)

    print("> load data")
    # Keyed by (model, dataset) so that repeated names are processed only once.
    filepaths = {
        (model, dataset): os.path.join(params.path, "{}_image_net_{}_rfe.npz".format(model, dataset))
        for model in params.models
        for dataset in params.datasets
    }
    metric_keys = ["accuracy", "precision", "recall", "f1-score", "roc_auc", "n_features"]
    # Datasets for which the --step value is overridden with a step of 8.
    step_overrides = {"ulg_breast": 8, "ulg_lbtd_lba_new": 8}

    for (model, dataset), filepath in filepaths.items():
        if not os.path.isfile(filepath):
            continue
        step = step_overrides.get(dataset, params.step)

        # Read every needed entry once and close the archive right away: each
        # access to an open .npz archive reads the member again from disk.
        with np.load(filepath) as curr_data:
            print("> {}\t{}\t{}".format(model, dataset, "\t".join(["{:1.4f}".format(float(curr_data[metric])) for metric in metric_keys if metric in curr_data])))
            scores = curr_data["scores"]
            n_features = curr_data["n_features"]
            ranking_size = curr_data["ranking"].shape[0]

        n_evaluated = scores.shape[0]
        # NOTE(review): feature counts are computed here as 1 + i * step while
        # plot_rfe uses n_features - (n_evaluated - 1 - i) * step; the two only
        # agree when n_features - 1 == (n_evaluated - 1) * step. Confirm which
        # convention matches how the scores were produced.
        x = 1 + np.arange(n_evaluated) * step
        indexes, drops = select_by_percentage(scores, drops=[params.loss])
        print("> max score {}; n_features {}".format(np.max(scores), n_features))
        print("> {}".format([(drop, scores[index], x[index]) for drop, index in zip(drops, indexes)]))

        plot_rfe(
            scores=scores,
            n_features=ranking_size,
            step=step,
            metric="accuracy",
            dest=os.path.join(params.dest, "{}_{}_resel.png".format(model, dataset)),
            loss=params.loss,
            title="{} {}".format(model, dataset),
            dpi=600
        )



# Script entry point: hand the command-line arguments (program name excluded) to main().
if __name__ == "__main__":
    import sys

    cli_args = sys.argv[1:]
    main(cli_args)