from clustertools import Computation, ConstrainedParameterSet, set_stdout_logging, Experiment
from clustertools.environment import SlurmEnvironment, InSituEnvironment

from util import infer_folder_name


class PyxitTrainEval(Computation):
    """Train a Pyxit + LinearSVC pipeline and score it on a held-out folder.

    The subwindow settings (``min_size``, ``max_size``, ``colorspace``) are
    not parameters of this computation: they are looked up in the results of
    a previous experiment (``exp_name``) by picking the combination with the
    highest recorded accuracy for the current dataset.
    """

    def run(self, **dict_params):
        """Fit the model on the train folder and evaluate it on the test one.

        Parameters (all read from ``dict_params``)
        ------------------------------------------
        exp_name: name of the experiment whose datacube holds the accuracies
        dataset: dataset name (datacube slice and sub-folder of ``path``)
        path: root folder containing the datasets
        train, test: names of the train and test folders of the dataset
        seed: seed from which the estimators' random states are derived
        total_subwindows: total number of subwindows to extract over the
            training set (divided by the number of training samples)
        n_estimators, min_samples_leaf, n_jobs: extra-trees settings
            (``n_jobs`` is also given to the Pyxit classifier)
        c: ``C`` value of the linear SVM

        Returns
        -------
        dict
            Always contains "accuracy", "cm", "cv_best_score", "min_size",
            "max_size" and "colorspace". When the test set holds exactly two
            classes, it also contains "precision", "recall", "f1-score" and
            "roc_auc".

        Raises
        ------
        ValueError
            If the datacube holds no accuracy for the requested dataset.
        """
        # Imports are kept local to the method, as in the original code
        # (presumably so they resolve on the node executing the computation).
        import os
        import sys
        import math
        import numpy as np
        from util import sizeof_fmt, pyxit_fit, pyxit_decision_function, pyxit_predict, print_cm
        from pyxit import PyxitClassifier
        from collections import namedtuple
        from clustertools import build_datacube
        from sklearn.svm import LinearSVC
        from dataset import ImageClassificationDataset
        from sklearn.ensemble import ExtraTreesClassifier
        from sklearn.utils import check_random_state
        from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix

        def get_best_params(cube):
            """Given a dataset cube extracts:
                - best min_size
                - best max_size
                - best colorspace
                - best accuracy
                - all the accuracies that were found (in iteration order)

            Combinations for which the cube holds no accuracy (``None``) are
            skipped. Raises ValueError when no combination has an accuracy.
            """
            params = list()
            accuracies = list()
            for min_size, min_cube in cube.iter_dimensions("min_size"):
                for max_size, max_cube in min_cube.iter_dimensions("max_size"):
                    for colorspace, color_cube in max_cube.iter_dimensions("colorspace"):
                        accuracy = color_cube("accuracy")
                        if accuracy is not None:
                            params.append((float(min_size[0]), float(max_size[0]), int(colorspace[0])))
                            accuracies.append(accuracy)
            if not accuracies:
                # np.argmax on an empty sequence fails with an unhelpful
                # message; report the actual cause instead.
                raise ValueError(
                    "no accuracy available in the datacube: cannot select the best parameters"
                )
            # On ties, np.argmax keeps the first combination encountered.
            best = np.argmax(accuracies)
            best_params = params[best]
            return best_params[0], best_params[1], best_params[2], accuracies[best], accuracies

        # Expose the parameters as attributes (params.seed, params.dataset, ...).
        params_type = namedtuple("params_type", dict_params.keys())
        params = params_type(**dict_params)
        print("Parameters: {}".format(params))

        # Restrict the results of the previous experiment to the current dataset.
        cube = build_datacube(params.exp_name)
        cube = cube(dataset=params.dataset)

        print("Diagnose:")
        print(cube.diagnose())

        print("Look for best parameter set...")
        min_size, max_size, colorspace, best_acc, _ = get_best_params(cube)
        print("> min_size  : {}".format(min_size))
        print("> max_size  : {}".format(max_size))
        print("> colorspace: {}".format(colorspace))
        print("> accuracy  : {}".format(best_acc))

        print("Load dataset...")
        dataset_path = os.path.join(params.path, params.dataset)
        dataset = ImageClassificationDataset(dataset_path, dirs=[params.train, params.test])
        x_train, y_train = dataset.all(dirs=[params.train], one_hot=False)
        x_test, y_test = dataset.all(dirs=[params.test], one_hot=False)
        n_samples = x_train.shape[0]
        # The binary/multi-class decision is based on the labels of the test set.
        classes = np.unique(y_test)
        n_classes = classes.shape[0]

        print("Build model...")
        random_state = check_random_state(params.seed)
        target_size = 16
        # Each estimator gets its own random state, drawn from the seeded one.
        et = ExtraTreesClassifier(
            n_estimators=params.n_estimators,
            min_samples_leaf=params.min_samples_leaf,
            max_features=(target_size * target_size) // 2,
            random_state=check_random_state(random_state.tomaxint() % (2 ** 32)),
            n_jobs=params.n_jobs
        )
        # Spread the subwindow budget over the training images, rounding up.
        n_subwindows = int(math.ceil(params.total_subwindows / n_samples))
        pyxit = PyxitClassifier(
            base_estimator=et,
            n_subwindows=n_subwindows,
            min_size=min_size,
            max_size=max_size,
            target_height=target_size,
            target_width=target_size,
            n_jobs=params.n_jobs,
            colorspace=colorspace,
            random_state=check_random_state(random_state.tomaxint() % (2 ** 32))
        )
        svm = LinearSVC(C=params.c)

        print("Fit...")
        sys.stdout.flush()
        _x_train, _y_train = pyxit.extract_subwindows(x_train, y_train)
        print("_X_train: {} (size: {})".format(_x_train.shape, sizeof_fmt(_x_train.nbytes)))
        print("_y_train: {} (size: {})".format(_y_train.shape, sizeof_fmt(_y_train.nbytes)))

        pyxit_fit(pyxit, svm, x_train, y_train, _x_train, _y_train)

        print("Predict test...")
        if n_classes == 2:
            print("> binary problem")
            # Keep the raw decision values: they are needed for the ROC AUC.
            thresholds = pyxit_decision_function(pyxit, svm, x_test)
            # Non-negative decision values map to svm.classes_[1], others to
            # svm.classes_[0]. The builtin int replaces the removed np.int alias.
            y_pred = svm.classes_.take((thresholds >= 0).astype(int))
            print("> decision_function : {}".format(thresholds.shape))
        else:
            print("> multi-class problem")
            y_pred = pyxit_predict(pyxit, svm, x_test)
        print("> y_pred            : {}".format(y_pred.shape))

        to_return = {
            "accuracy": accuracy_score(y_test, y_pred),
            "cm": confusion_matrix(y_test, y_pred),
            "cv_best_score": best_acc,
            "min_size": min_size,
            "max_size": max_size,
            "colorspace": colorspace
        }

        print("Scores:")
        print("> accuracy : {}".format(to_return["accuracy"]))
        if n_classes == 2:
            # np.unique returns sorted labels: the larger one is the positive class.
            label_params = {"labels": classes, "pos_label": classes[1]}
            to_return["precision"] = precision_score(y_test, y_pred, **label_params)
            to_return["recall"] = recall_score(y_test, y_pred, **label_params)
            to_return["f1-score"] = f1_score(y_test, y_pred, **label_params)
            # roc_auc_score is fed 0/1 targets (1 = positive class).
            y_test_bin = np.zeros(y_test.shape, dtype=int)
            y_test_bin[y_test == label_params["pos_label"]] = 1
            to_return["roc_auc"] = roc_auc_score(y_test_bin, thresholds)
            print("> precision: {}".format(to_return["precision"]))
            print("> recall   : {}".format(to_return["recall"]))
            print("> f1-score : {}".format(to_return["f1-score"]))
            print("> roc auc  : {}".format(to_return["roc_auc"]))

        print_cm(to_return["cm"], [str(i) for i in np.unique(y_test)])
        print(to_return)
        return to_return


if __name__ == "__main__":
    set_stdout_logging()

    # Datasets of the first group of the parameter set.
    datasets = [
        "ulg_lbtd2_chimio_necrose",
        "ulg_lbtd_lba",
        "ulb_anapath_lba",
        "ulg_lbtd_tissus",
        "cells_no_aug",
        "patterns_no_aug",
        "glomeruli_no_aug"
    ]

    # Number of jobs, used both by the computation and for the Slurm request.
    n_jobs = 8

    # Single-valued settings shared by all the computations.
    fixed_settings = dict(
        path="[...]/datasets",
        seed=42,
        train="train",
        total_subwindows=1000000,
        exp_name="pyxit_cv",
        n_estimators=20,
        n_jobs=n_jobs,
        min_samples_leaf=1000,
        c=0.01
    )

    param_set = ConstrainedParameterSet()
    param_set.add_parameters(dataset=datasets)
    param_set.add_parameters(test=["test", "val"])
    param_set.add_constraints(test_folder_constrain=infer_folder_name)
    param_set.add_parameters(**fixed_settings)

    # Datasets added afterwards, each one behind its own separator.
    for extra_dataset in ("ulg_breast", "ulg_lbtd_lba_new"):
        param_set.add_separator()
        param_set.add_parameters(dataset=[extra_dataset])

    experiment = Experiment("pyxit_train_eval", param_set, PyxitTrainEval)

    environment = SlurmEnvironment(
        time="16-23",
        memory="63900M",
        partition="[...]",
        n_proc=n_jobs
    )
    environment.run(experiment)

    # Alternatives to the Slurm submission, kept for reference:
    # InSituEnvironment(stdout=True).run(experiment)
    # for i, params in enumerate(param_set):
    #     PyxitTrainEval("pyxit_train_eval", str(i))(**params)
