from clustertools import Computation, ParameterSet, Experiment, ConstrainedParameterSet, set_stdout_logging
from clustertools.environment import Environment, SlurmEnvironment, InSituEnvironment


class RFOnMergedComputation(Computation):
    """Tune, refit and evaluate a classifier on merged features of a network.

    Despite the "RF" in the name, the classifier is either an
    ``ExtraTreesClassifier`` (``method="et"``) or a ``LinearSVC``
    (``method="svm"``). Hyper-parameters are selected by group-aware
    cross-validation on the training split.
    """

    def run(self, **dict_params):
        """Run one experiment and save its results with ``np.savez``.

        Parameters
        ----------
        dict_params:
            Experiment parameters. Keys read here: ``method`` ("et" or "svm"),
            ``dest``, ``path``, ``model``, ``dataset``, ``train``, ``test``,
            ``reduction``, ``source``, ``n_jobs`` and ``n_estimators`` (only
            used by "et").

        Returns
        -------
        dict
            The returned dict contains:

            - test "accuracy" and confusion matrix "cm";
            - "precision", "recall", "f1-score" and "roc_auc" for binary
              problems;
            - "cv_best_score" and one entry per best hyper-parameter;
            - the loaded "features" description, plus "importances" when the
              model exposes ``feature_importances_``.

            The same dict is written to ``<dest>/<task>/<model>/``.

        Raises
        ------
        ValueError
            If ``method`` is neither "et" nor "svm".
        """
        import os
        import numpy as np
        from collections import namedtuple
        from sklearn.metrics import accuracy_score, make_scorer
        from sklearn.metrics import confusion_matrix
        from sklearn.metrics import f1_score
        from sklearn.metrics import precision_score
        from sklearn.metrics import recall_score
        from sklearn.metrics import roc_auc_score
        from sklearn.model_selection import GroupKFold, GridSearchCV
        from sklearn.preprocessing import LabelEncoder
        from sklearn.ensemble import ExtraTreesClassifier
        from sklearn.utils import check_random_state
        from sklearn.svm import LinearSVC
        from util import extract_prefixes, sizeof_fmt, print_cm, mk_task_filename, load_merged_features

        params_type = namedtuple("params_type", dict_params.keys())
        params = params_type(**dict_params)
        print("Parameters: {}".format(params))

        task = "merged_layer_" + params.method
        dest = os.path.join(params.dest, task, params.model)
        # Several cluster jobs can share the same destination folder, so tolerate
        # the folder being created concurrently by another job.
        try:
            os.makedirs(dest)
        except OSError:
            if not os.path.isdir(dest):
                raise
        path = os.path.join(params.path, params.model)

        print("Load features")
        print("> train")
        x_train_names, y_train, x_train_trans, features = load_merged_features(
            path=path,
            model=params.model,
            dataset=params.dataset,
            folder=params.train,
            reduction=params.reduction,
            source=params.source
        )
        # Group identifiers for GroupKFold, derived from the sample file names.
        train_labels = extract_prefixes([os.path.basename(s) for s in x_train_names])
        print("> test")
        x_test_names, y_test, x_test_trans, _ = load_merged_features(
            path=path,
            model=params.model,
            dataset=params.dataset,
            folder=params.test,
            reduction=params.reduction,
            source=params.source
        )

        print("Encode classes")
        encoder = LabelEncoder()
        encoder.fit(y_train)
        n_classes = encoder.classes_.shape[0]
        print("> classes: {}".format(encoder.classes_))
        print("> mapped : {}".format(np.arange(n_classes)))

        print("Loaded data")
        print("> x_train_names : {} (size: {})".format(x_train_names.shape, sizeof_fmt(x_train_names.nbytes)))
        print("> y_train       : {} (size: {})".format(y_train.shape, sizeof_fmt(y_train.nbytes)))
        print("> x_train_trans : {} (size: {})".format(x_train_trans.shape, sizeof_fmt(x_train_trans.nbytes)))
        print("> train_labels  : {} (size: {})".format(train_labels.shape, sizeof_fmt(train_labels.nbytes)))
        print("> x_test_names  : {} (size: {})".format(x_test_names.shape, sizeof_fmt(x_test_names.nbytes)))
        print("> y_test        : {} (size: {})".format(y_test.shape, sizeof_fmt(y_test.nbytes)))
        print("> x_test_trans  : {} (size: {})".format(x_test_trans.shape, sizeof_fmt(x_test_trans.nbytes)))

        print("Prepare cross-validation")
        _, n_features = x_train_trans.shape
        unique_labels = np.unique(train_labels)

        if params.method == "et":
            # GroupKFold cannot use more folds than there are groups.
            n_splits = min(5, unique_labels.shape[0])
            random_state = check_random_state(42)
            model = ExtraTreesClassifier(n_estimators=params.n_estimators, random_state=random_state, n_jobs=params.n_jobs)
            grid = {
                "min_samples_leaf": [1],
                "max_features": [1, int(np.sqrt(n_features)), n_features // 2, n_features]
            }
        elif params.method == "svm":
            n_splits = min(10, unique_labels.shape[0])
            model = LinearSVC()
            grid = {"C": [1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10.0]}
        else:
            raise ValueError("Unknown method '{}'".format(params.method))

        kfold = GroupKFold(n_splits=n_splits)
        print("> method  : {}".format(params.method))
        print("> n_splits: {}".format(n_splits))
        print("> grid    : {}".format(grid))

        # Parallelise over the grid for the SVM. For extra-trees the forest itself
        # is already parallel (n_jobs), so the search runs sequentially.
        grid_search = GridSearchCV(
            model, grid, cv=kfold, scoring=make_scorer(accuracy_score),
            verbose=10, n_jobs=params.n_jobs if params.method == "svm" else 1, refit=False
        )
        print("Run cross-validation")
        y_train_trans = encoder.transform(y_train)
        # `groups` is keyword-only in recent scikit-learn releases.
        grid_search.fit(x_train_trans, y_train_trans, groups=train_labels)
        print(" -> best_params: {}".format(grid_search.best_params_))
        print(" -> best_score: {}".format(grid_search.best_score_))
        print(" -> results: {}".format(grid_search.cv_results_))

        print("> refit")
        # The search used refit=False, so apply the selected hyper-parameters
        # explicitly before fitting on the whole training set.
        model.set_params(**grid_search.best_params_)
        if "n_jobs" in model.get_params():
            model.set_params(n_jobs=params.n_jobs)
        model.fit(x_train_trans, y_train_trans)
        best_model = model

        print("Predict test...")
        if n_classes == 2:
            print("> binary problem")
            # Only the predict_proba call is guarded. Estimators without it
            # (e.g. LinearSVC) fall back to decision_function.
            try:
                probas = best_model.predict_proba(x_test_trans)
            except (AttributeError, NotImplementedError):
                probas = None
            if probas is not None:
                thresholds = probas[:, 1]
                y_pred = np.argmax(probas, axis=1).astype(int)
            else:
                thresholds = best_model.decision_function(x_test_trans)
                y_pred = (thresholds >= 0).astype(int)
            print("> decision_function : {}".format(thresholds.shape))
        else:
            print("> multi-class problem")
            y_pred = best_model.predict(x_test_trans)
        y_pred = encoder.inverse_transform(y_pred)
        print("> y_pred            : {}".format(y_pred.shape))

        # Sorted union of true and predicted labels. This is confusion_matrix's
        # default order, made explicit so the printed header matches the matrix.
        cm_labels = np.union1d(y_test, y_pred)
        to_return = {
            "accuracy": accuracy_score(y_test, y_pred),
            "cm": confusion_matrix(y_test, y_pred, labels=cm_labels),
            "cv_best_score": float(grid_search.best_score_),
        }
        to_return.update(grid_search.best_params_)

        print("Scores:")
        print("> accuracy : {}".format(to_return["accuracy"]))
        if n_classes == 2:
            label_params = {"labels": encoder.classes_, "pos_label": encoder.classes_[1]}
            to_return["precision"] = precision_score(y_test, y_pred, **label_params)
            to_return["recall"] = recall_score(y_test, y_pred, **label_params)
            to_return["f1-score"] = f1_score(y_test, y_pred, **label_params)
            to_return["roc_auc"] = roc_auc_score(encoder.transform(y_test), thresholds)
            print("> precision: {}".format(to_return["precision"]))
            print("> recall   : {}".format(to_return["recall"]))
            print("> f1-score : {}".format(to_return["f1-score"]))
            print("> roc auc  : {}".format(to_return["roc_auc"]))

        print_cm(to_return["cm"], [str(i) for i in cm_labels])

        # add features information
        to_return["features"] = features
        if hasattr(best_model, "feature_importances_"):
            to_return["importances"] = best_model.feature_importances_

        # save results
        filename = mk_task_filename(
            task=task,
            model=params.model,
            layer="merged",
            dataset=params.dataset,
            folder=None,
            reduction=params.reduction,
            source=params.source
        )
        filepath = os.path.join(dest, filename)
        print("> save results for task '{}' to '{}'".format(task, filepath))
        np.savez(filepath, **to_return)
        return to_return


def infer_folder_name(dataset, test, **kwargs):
    """Tell whether `test` is the evaluation split to use for `dataset`.

    Used as a parameter-set constraint. The listed datasets are evaluated on
    their "val" folder and every other dataset on its "test" folder. Extra
    keyword arguments are accepted and ignored.
    """
    val_split_datasets = {"ulg_lbtd2_chimio_necrose", "ulg_lbtd_lba", "ulb_anapath_lba", "ulg_lbtd_tissus"}
    if dataset in val_split_datasets:
        expected_split = "val"
    else:
        expected_split = "test"
    return test == expected_split


if __name__ == "__main__":
    set_stdout_logging()
    n_jobs = 8

    # Base grid: dataset x evaluation split x method. The split is filtered by
    # infer_folder_name so each dataset keeps only its evaluation folder.
    param_set = ConstrainedParameterSet()
    datasets = [
        "cells_no_aug",
        "patterns_no_aug",
        "glomeruli_no_aug",
        "ulg_lbtd2_chimio_necrose",
        "ulg_lbtd_lba",
        "ulb_anapath_lba",
        "ulg_lbtd_tissus"
    ]
    param_set.add_parameters(dataset=datasets)
    param_set.add_parameters(test=["val", "test"])
    param_set.add_constraints(test_set_cons=infer_folder_name)
    param_set.add_parameters(method=["et", "svm"])

    # Fixed settings shared by the whole experiment.
    param_set.add_parameters(
        path="[...]/features/per_layer/",
        dest="[...]/features/per_layer/",
        train="train", model="resnet50", source=None, reduction=None,
        n_jobs=n_jobs, n_estimators=1000
    )

    # Additional parameter blocks, each introduced by a separator.
    # NOTE(review): exact separator semantics come from clustertools; confirm.
    separated_blocks = [
        {"model": ["dense_net_201"]},
        {"model": ["inception_v3"]},
        {"model": ["inception_resnet_v2"]},
        {"dataset": ["ulg_breast"]},
        {"dataset": ["ulg_lbtd_lba_new"]},
    ]
    for block_params in separated_blocks:
        param_set.add_separator()
        param_set.add_parameters(**block_params)

    experiment = Experiment("merged_layers", param_set, RFOnMergedComputation)
    slurm = SlurmEnvironment(
        time="8-8",
        memory="23900M",
        partition="[...]",
        n_proc=n_jobs
    )
    slurm.run(experiment)
