
from clustertools import Computation, set_stdout_logging, ConstrainedParameterSet, Experiment
from clustertools.environment import SlurmEnvironment, InSituEnvironment

from util import infer_folder_name


class ReselectRFE(Computation):
    """Re-run RFE with a smaller feature budget derived from a previous RFE run.

    The previous run's scores are read from ``rfe_path``; the first entry whose
    score is within ``acceptable_loss`` of the best score determines the new
    number of features to keep. An extra-trees based RFE is then fitted on the
    training features, evaluated on the test features, and the results are
    saved under ``rfe_dest``.
    """

    def run(self, **dict_params):
        """Execute one re-selection task.

        Keys read from ``dict_params``:
            path, model, dataset, layer, reduction, source: forwarded to
                ``load_features_of`` / ``mk_task_filename``.
            train, test: folder names of the training and evaluation sets.
            rfe_path: directory holding the previous RFE result archive.
            rfe_dest: directory where the new results are written
                (created if missing).
            acceptable_loss: tolerated score drop w.r.t. the best previous score.
            old_step: step used by the previous RFE run (maps a score index
                back to a number of features).
            step: step of the new RFE run.
            n_estimators, n_jobs, seed: ExtraTreesClassifier settings.

        Returns:
            dict with keys "accuracy", "cm", "support", "ranking",
            "n_features" and "step"; for two-class problems also
            "precision", "recall", "f1-score" and "roc_auc".
        """
        import os
        import sys
        import numpy as np
        from sklearn.ensemble import ExtraTreesClassifier
        from sklearn.metrics import accuracy_score, confusion_matrix, f1_score, precision_score, recall_score, roc_auc_score
        from sklearn.preprocessing import LabelEncoder
        from sklearn.utils import check_random_state
        from sklearn.feature_selection import RFE
        from collections import namedtuple
        from util import load_features_of, extract_prefixes, sizeof_fmt, print_cm, mk_task_filename

        # Expose the keyword parameters as attributes (params.model, ...).
        params_type = namedtuple("params_type", dict_params.keys())
        params = params_type(**dict_params)
        print("Parameters: {}".format(params))

        # exist_ok avoids a check-then-create race when several jobs share the
        # same destination directory and start at the same time.
        os.makedirs(params.rfe_dest, exist_ok=True)

        print("Load features")
        print("> train")
        x_train_names, y_train, x_train_trans = load_features_of(
            path=params.path,
            model=params.model,
            dataset=params.dataset,
            layer=params.layer,
            folder=params.train,
            reduction=params.reduction,
            source=params.source
        )
        train_labels = extract_prefixes([os.path.basename(s) for s in x_train_names])
        print("> test")
        x_test_names, y_test, x_test_trans = load_features_of(
            path=params.path,
            model=params.model,
            dataset=params.dataset,
            layer=params.layer,
            folder=params.test,
            reduction=params.reduction,
            source=params.source
        )

        print("Load RFE data from '{}'".format(params.rfe_path))
        prev_rfe_filepath = os.path.join(params.rfe_path, mk_task_filename(
            task="rfe",
            model=params.model,
            dataset=params.dataset,
            folder=None,
            layer=params.layer,
            reduction=params.reduction,
            source=params.source
        ))
        # Read what is needed and close the archive right away instead of
        # keeping the file handle open for the whole (long) computation.
        with np.load(prev_rfe_filepath) as prev_rfe:
            rfe_scores = prev_rfe["scores"]
            old_n_features = prev_rfe["n_features"]

        print("Reselect n_features:")
        max_rfe_score = np.max(rfe_scores)
        # First score index that is within `acceptable_loss` of the best score.
        n_feat_index = np.where(rfe_scores > max_rfe_score - params.acceptable_loss)[0][0]
        # NOTE(review): assumes scores[i] was obtained with 1 + old_step * i
        # features in the previous run -- confirm against the script that wrote it.
        new_n_features = 1 + params.old_step * n_feat_index
        new_score = rfe_scores[n_feat_index]
        print("> new n_features: {}".format(new_n_features))
        print("> ratio         : {:3.2f}%".format(100 * new_n_features / x_train_trans.shape[1]))
        print("> new score     : {}".format(new_score))
        print("> old n_features: {}".format(old_n_features))
        print("> old score     : {}".format(max_rfe_score))
        print("> delta         : {}".format(max_rfe_score - new_score))

        print("Encode classes")
        encoder = LabelEncoder()
        encoder.fit(y_train)
        n_classes = encoder.classes_.shape[0]
        print("> classes: {}".format(encoder.classes_))
        print("> mapped : {}".format(np.arange(n_classes)))

        print("Loaded data")
        print("> x_train_names : {} (size: {})".format(x_train_names.shape, sizeof_fmt(x_train_names.nbytes)))
        print("> y_train       : {} (size: {})".format(y_train.shape, sizeof_fmt(y_train.nbytes)))
        print("> x_train_trans : {} (size: {})".format(x_train_trans.shape, sizeof_fmt(x_train_trans.nbytes)))
        print("> train_labels  : {} (size: {})".format(train_labels.shape, sizeof_fmt(train_labels.nbytes)))
        print("> x_test_names  : {} (size: {})".format(x_test_names.shape, sizeof_fmt(x_test_names.nbytes)))
        print("> y_test        : {} (size: {})".format(y_test.shape, sizeof_fmt(y_test.nbytes)))
        print("> x_test_trans  : {} (size: {})".format(x_test_trans.shape, sizeof_fmt(x_test_trans.nbytes)))

        print("Prepare RFE")
        random_state = check_random_state(params.seed)
        et = ExtraTreesClassifier(
            n_estimators=params.n_estimators,
            max_features=None,
            min_samples_leaf=1,
            random_state=random_state,
            n_jobs=params.n_jobs
        )
        rfe = RFE(
            et,
            n_features_to_select=new_n_features,
            step=params.step,
            verbose=10
        )
        print("Run RFE")
        # Flush before the long fit so buffered log lines reach the job output.
        sys.stdout.flush()
        enc_y_train = encoder.transform(y_train)
        rfe.fit(x_train_trans, enc_y_train)
        print(" -> n_features: {}".format(rfe.n_features_))
        print(" -> ranking:    {}".format(rfe.ranking_))
        print(" -> support:    {}".format(rfe.support_))

        # The final estimator was fitted on the selected features only, so the
        # test matrix has to be restricted to the same columns.
        best_model = rfe.estimator_
        x_test_trans = x_test_trans[:, rfe.support_]

        print("Predict test...")
        if n_classes == 2:
            print("> binary problem")
            probas = best_model.predict_proba(x_test_trans)
            # Builtin int: the `np.int` alias was removed in NumPy 1.24.
            y_pred = np.argmax(probas, axis=1).astype(int)
            print("> decision_function : {}".format(probas.shape))
        else:
            print("> multi-class problem")
            y_pred = best_model.predict(x_test_trans)
        y_pred = encoder.inverse_transform(y_pred)
        print("> y_pred            : {}".format(y_pred.shape))

        to_return = {
            "accuracy": accuracy_score(y_test, y_pred),
            "cm": confusion_matrix(y_test, y_pred),
            "support": rfe.support_,
            "ranking": rfe.ranking_,
            "n_features": rfe.n_features_,
            "step": params.step
        }

        print("Scores:")
        print("> accuracy : {}".format(to_return["accuracy"]))
        if n_classes == 2:
            # Built here, where it is used: indexing classes_[1] outside the
            # binary branch fails when only a single class is present.
            label_params = {"labels": encoder.classes_, "pos_label": encoder.classes_[1]}
            to_return["precision"] = precision_score(y_test, y_pred, **label_params)
            to_return["recall"] = recall_score(y_test, y_pred, **label_params)
            to_return["f1-score"] = f1_score(y_test, y_pred, **label_params)
            to_return["roc_auc"] = roc_auc_score(encoder.transform(y_test), probas[:, 1])
            print("> precision: {}".format(to_return["precision"]))
            print("> recall   : {}".format(to_return["recall"]))
            print("> f1-score : {}".format(to_return["f1-score"]))
            print("> roc auc  : {}".format(to_return["roc_auc"]))

        print_cm(to_return["cm"], [str(i) for i in np.unique(y_test)])

        # save importances
        filename = mk_task_filename(
            task="rfe_rerank",
            model=params.model,
            layer=params.layer,
            dataset=params.dataset,
            folder=None,
            reduction=params.reduction,
            source=params.source
        )
        filepath = os.path.join(params.rfe_dest, filename)
        importances = best_model.feature_importances_
        print("> save rfe results to '{}'".format(filepath))
        np.savez(filepath, importances=importances, **to_return)
        return to_return


if __name__ == "__main__":
    # Script entry point: build the parameter grid and submit the
    # "rfe_rerank" experiment through the Slurm environment.
    set_stdout_logging()

    # Shared by the classifier (n_jobs) and the Slurm allocation (n_proc).
    n_jobs = 16

    initial_datasets = [
        "ulg_lbtd2_chimio_necrose",
        "cells_no_aug",
        "patterns_no_aug",
        "ulg_lbtd_lba",
        "ulb_anapath_lba"
    ]
    models = [
        "dense_net_201",
        "resnet50",
        "inception_resnet_v2",
        "inception_v3",
        "mobile",
        "vgg16",
        "vgg19"
    ]
    # Datasets registered later on, each one after its own separator call.
    later_datasets = [
        "glomeruli_no_aug",
        "ulg_lbtd_tissus",
        "ulg_breast",
        "ulg_lbtd_lba_new"
    ]

    param_set = ConstrainedParameterSet()
    param_set.add_parameters(dataset=initial_datasets)
    param_set.add_parameters(model=models)
    param_set.add_parameters(test=["test", "val"])
    param_set.add_constraints(test_folder_constrain=infer_folder_name)
    param_set.add_parameters(
        path="[...]/features/",
        rfe_path="[...]/features/rfe",
        rfe_dest="[...]/features/rfe_rerank",
        train="train",
        n_estimators=1000,
        acceptable_loss=0.005,
        source="image_net",
        step=8,
        old_step=8,
        reduction=None,
        layer=None,
        n_jobs=n_jobs,
        seed=42
    )
    for later_dataset in later_datasets:
        param_set.add_separator()
        param_set.add_parameters(dataset=[later_dataset])

    experiment = Experiment("rfe_rerank", param_set, ReselectRFE)
    environment = SlurmEnvironment(
        time="16-23",
        memory="31900M",
        partition="[...]",
        n_proc=n_jobs
    )
    environment.run(experiment)
