from clustertools import Experiment, ConstrainedParameterSet, set_stdout_logging, Computation
from clustertools.environment import SlurmEnvironment


class SVMPerLayer(Computation):
    """Evaluate a linear SVM on the features extracted at one layer of a network.

    The train and test features are loaded with ``util.load_features_of``. The
    SVM constant ``C`` is tuned by grouped cross-validation on the training
    set, and the refitted best model is scored on the test folder. The scores
    are saved with ``np.savez`` under ``<dest>/<model>/`` and also returned.
    """

    def run(self, **dict_params):
        """Run one SVM-per-layer computation.

        Parameters
        ----------
        **dict_params
            Experiment parameters. This method reads ``dest``, ``path``,
            ``model``, ``dataset``, ``train``, ``test``, ``layer``,
            ``reduction``, ``n_jobs`` and ``source``. Any other key (e.g.
            ``method``, ``n_estimators``) is only printed.

        Returns
        -------
        dict
            Always contains ``accuracy``, ``cm`` (confusion matrix),
            ``cv_best_score`` and the best grid-search parameters (``C``).
            For two-class problems it also contains ``precision``,
            ``recall``, ``f1-score`` and ``roc_auc``.
        """
        # Imports are kept local to ``run`` as in the original code
        # (presumably so they only happen where the computation executes --
        # TODO confirm).
        import os
        import numpy as np
        from collections import namedtuple
        from sklearn.metrics import accuracy_score, make_scorer
        from sklearn.metrics import confusion_matrix
        from sklearn.metrics import f1_score
        from sklearn.metrics import precision_score
        from sklearn.metrics import recall_score
        from sklearn.metrics import roc_auc_score
        from sklearn.model_selection import GroupKFold, GridSearchCV
        from sklearn.preprocessing import LabelEncoder
        from sklearn.svm import LinearSVC
        from util import load_features_of, extract_prefixes, sizeof_fmt, print_cm, mk_task_filename

        # Attribute-style access to the parameters (params.model, ...).
        params_type = namedtuple("params_type", dict_params.keys())
        params = params_type(**dict_params)
        print("Parameters: {}".format(params))

        # Results are written to <dest>/<model>/.
        dest = os.path.join(params.dest, params.model)
        if not os.path.exists(dest):
            os.makedirs(dest)

        # Root directory handed to load_features_of for this model.
        path = os.path.join(params.path, params.model)

        print("Load features")
        print("> train")
        x_train_names, y_train, x_train_trans = load_features_of(
            path=path,
            model=params.model,
            dataset=params.dataset,
            folder=params.train,
            layer=params.layer,
            reduction=params.reduction
        )
        # Cross-validation groups derived from the training file names (see
        # util.extract_prefixes). Samples of the same group are never split
        # between CV train and validation folds.
        train_labels = extract_prefixes([os.path.basename(s) for s in x_train_names])
        print("> test")
        x_test_names, y_test, x_test_trans = load_features_of(
            path=path,
            model=params.model,
            dataset=params.dataset,
            folder=params.test,
            layer=params.layer,
            reduction=params.reduction
        )

        print("Encode classes")
        # Classes are encoded as 0..n_classes-1 from the training labels only.
        encoder = LabelEncoder()
        encoder.fit(y_train)
        n_classes = encoder.classes_.shape[0]
        print("> classes: {}".format(encoder.classes_))
        print("> mapped : {}".format(np.arange(n_classes)))

        print("Loaded data")
        print("> x_train_names : {} (size: {})".format(x_train_names.shape, sizeof_fmt(x_train_names.nbytes)))
        print("> y_train       : {} (size: {})".format(y_train.shape, sizeof_fmt(y_train.nbytes)))
        print("> x_train_trans : {} (size: {})".format(x_train_trans.shape, sizeof_fmt(x_train_trans.nbytes)))
        print("> train_labels  : {} (size: {})".format(train_labels.shape, sizeof_fmt(train_labels.nbytes)))
        print("> x_test_names  : {} (size: {})".format(x_test_names.shape, sizeof_fmt(x_test_names.nbytes)))
        print("> y_test        : {} (size: {})".format(y_test.shape, sizeof_fmt(y_test.nbytes)))
        print("> x_test_trans  : {} (size: {})".format(x_test_trans.shape, sizeof_fmt(x_test_trans.nbytes)))

        print("Prepare cross-validation")
        n_samples, n_features = x_train_trans.shape
        unique_labels = np.unique(train_labels)
        # GroupKFold cannot use more folds than there are distinct groups.
        n_splits = min(10, unique_labels.shape[0])
        kfold = GroupKFold(n_splits=n_splits)
        print("> n_splits: {}".format(n_splits))
        grid = {"C": [1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10.0]}
        print("> grid: {}".format(grid))

        # The dual formulation is preferred when there are fewer samples than
        # features (scikit-learn LinearSVC recommendation).
        grid_search = GridSearchCV(
            LinearSVC(dual=n_samples <= n_features),
            grid, cv=kfold, scoring=make_scorer(accuracy_score),
            verbose=10, n_jobs=params.n_jobs, refit=True
        )
        print("Run cross-validation")
        # ``groups`` must be passed by keyword: recent scikit-learn releases no
        # longer accept it positionally.
        grid_search.fit(x_train_trans, encoder.transform(y_train), groups=train_labels)
        print(" -> best_params: {}".format(grid_search.best_params_))
        print(" -> best_score: {}".format(grid_search.best_score_))
        print(" -> results: {}".format(grid_search.cv_results_))

        # refit=True: the best estimator is retrained on the whole training set.
        best_model = grid_search.best_estimator_

        print("Predict test...")
        if n_classes == 2:
            print("> binary problem")
            try:
                # LinearSVC has no predict_proba, so this normally falls back to
                # decision_function below; kept for estimators that provide it.
                probas = best_model.predict_proba(x_test_trans)
                thresholds = probas[:, 1]
                # Builtin int: the ``np.int`` alias was removed in NumPy 1.24.
                y_pred = np.argmax(probas, axis=1).astype(int)
            except (AttributeError, NotImplementedError):
                thresholds = best_model.decision_function(x_test_trans)
                # Non-negative decision values are predicted as encoded class 1.
                y_pred = (thresholds >= 0).astype(int)
            print("> decision_function : {}".format(thresholds.shape))
        else:
            print("> multi-class problem")
            y_pred = best_model.predict(x_test_trans)
        # Back to the original class labels.
        y_pred = encoder.inverse_transform(y_pred)
        print("> y_pred            : {}".format(y_pred.shape))

        # confusion_matrix orders rows/columns by the sorted union of the labels
        # found in y_test and y_pred. Compute that union explicitly so that the
        # names printed below match the matrix even if the model predicts a
        # class missing from the test set.
        cm_labels = np.union1d(y_test, y_pred)
        to_return = {
            "accuracy": accuracy_score(y_test, y_pred),
            "cm": confusion_matrix(y_test, y_pred, labels=cm_labels),
            "cv_best_score": float(grid_search.best_score_),
        }

        # Store the selected hyper-parameters (here "C") next to the scores.
        to_return.update(grid_search.best_params_)

        print("Scores:")
        print("> accuracy : {}".format(to_return["accuracy"]))
        if n_classes == 2:
            # The second encoded class is treated as the positive class.
            label_params = {"labels": encoder.classes_, "pos_label": encoder.classes_[1]}
            to_return["precision"] = precision_score(y_test, y_pred, **label_params)
            to_return["recall"] = recall_score(y_test, y_pred, **label_params)
            to_return["f1-score"] = f1_score(y_test, y_pred, **label_params)
            to_return["roc_auc"] = roc_auc_score(encoder.transform(y_test), thresholds)
            print("> precision: {}".format(to_return["precision"]))
            print("> recall   : {}".format(to_return["recall"]))
            print("> f1-score : {}".format(to_return["f1-score"]))
            print("> roc auc  : {}".format(to_return["roc_auc"]))

        print_cm(to_return["cm"], [str(label) for label in cm_labels])

        # save results
        filename = mk_task_filename(
            task="svm",
            model=params.model,
            layer=params.layer,
            dataset=params.dataset,
            folder=None,
            reduction=params.reduction,
            source=params.source
        )
        filepath = os.path.join(dest, filename)
        print("> save results to '{}'".format(filepath))
        np.savez(filepath, **to_return)
        return to_return


# Layers whose outputs are evaluated, one list per network. Each list is
# matched against the "model" parameter by infer_layer_name.

# DenseNet-201 layers.
dense_layers = [
    "pool1",
    "conv2_block6_concat",
    "pool2_pool",
    "conv3_block12_concat",
    "pool3_pool",
    "conv4_block48_concat",
    "pool4_pool",
    "conv5_block32_concat",
    "bn",
]

# Inception-v3 layers: the second max-pooling followed by mixed0 .. mixed10.
incep_v3_layers = ["max_pooling2d_2"] + ["mixed{}".format(idx) for idx in range(11)]

# Inception-ResNet-v2 layers.
incep_v2_layers = [
    "max_pooling2d_2",
    "mixed_5b",
    "block35_1_ac",
    "block35_4_ac",
    "block35_7_ac",
    "block35_10_ac",
    "mixed_6a",
    "block17_5_ac",
    "block17_10_ac",
    "block17_15_ac",
    "block17_20_ac",
    "mixed_7a",
    "block8_3_ac",
    "block8_6_ac",
    "block8_9_ac",
    "conv_7b_ac",
]

# ResNet-50 layers: every third activation, activation_1 .. activation_49.
resnet50_layers = ["activation_{}".format(idx) for idx in range(1, 50, 3)]


def infer_layer_name(model, layer, **kwargs):
    """Tell whether ``layer`` is one of the evaluated layers of ``model``.

    Used as a ConstrainedParameterSet constraint. Extra keyword arguments
    (the remaining experiment parameters) are ignored. An unknown ``model``
    yields ``False``.
    """
    model_layers = (
        ("inception_v3", incep_v3_layers),
        ("inception_resnet_v2", incep_v2_layers),
        ("dense_net_201", dense_layers),
        ("resnet50", resnet50_layers),
    )
    return any(
        model == model_name and layer in set(layer_names)
        for model_name, layer_names in model_layers
    )


def infer_folder_name(dataset, test, **kwargs):
    """Tell whether ``test`` is the evaluation folder used for ``dataset``.

    The datasets listed below are evaluated on their ``"val"`` folder; every
    other dataset is evaluated on its ``"test"`` folder. Used as a
    ConstrainedParameterSet constraint, so extra keyword arguments (the
    remaining experiment parameters) are ignored.
    """
    val_datasets = {"ulg_lbtd2_chimio_necrose", "ulg_lbtd_lba", "ulb_anapath_lba", "ulg_lbtd_tissus"}
    expected_folder = "val" if dataset in val_datasets else "test"
    return test == expected_folder


if __name__ == "__main__":
    set_stdout_logging()

    datasets = [
        "cells_no_aug",
        "patterns_no_aug",
        "glomeruli_no_aug",
        "ulg_lbtd2_chimio_necrose",
        "ulg_lbtd_lba",
        "ulb_anapath_lba",
        "ulg_lbtd_tissus",
        "ulg_breast",
    ]
    models = [
        "resnet50",
        "dense_net_201",
        "inception_resnet_v2",
        "inception_v3",
    ]
    # Union of the layers of all models; layer_constrain (infer_layer_name)
    # keeps only the pairs where the layer belongs to the selected model.
    # NOTE(review): "max_pooling2d_2" is in both incep_v2_layers and
    # incep_v3_layers, so it appears twice in this list. Confirm whether
    # ConstrainedParameterSet deduplicates values; otherwise those jobs run twice.
    candidate_layers = incep_v2_layers + incep_v3_layers + dense_layers + resnet50_layers

    param_set = ConstrainedParameterSet()
    param_set.add_parameters(dataset=datasets)
    param_set.add_parameters(model=models)
    param_set.add_parameters(layer=candidate_layers)
    param_set.add_parameters(method=["svm"])
    param_set.add_parameters(test=["test", "val"])
    # Keep only the evaluation folder each dataset uses, and only valid layers.
    param_set.add_constraints(test_folder_constrain=infer_folder_name)
    param_set.add_constraints(layer_constrain=infer_layer_name)

    # Settings shared by every computation.
    # NOTE(review): n_estimators is not read by SVMPerLayer.run (only printed).
    param_set.add_parameters(
        path="[...]/features/per_layer",
        dest="[...]/features/per_layer/results",
        train="train",
        n_estimators=1000,
        source="image_net",
        reduction=None,
        n_jobs=1,
    )

    # Extra dataset added after a separator.
    # NOTE(review): "ulg_lbtd_lba_new" is not in infer_folder_name's "val"
    # datasets, so only its "test" folder passes the constraint. Confirm this
    # is intended.
    param_set.add_separator()
    param_set.add_parameters(dataset=["ulg_lbtd_lba_new"])

    experiment = Experiment("svm_per_layer", param_set, SVMPerLayer)
    slurm = SlurmEnvironment(
        time="8-8",
        memory="3900M",
        partition="[...]",
        n_proc=1,
    )
    slurm.run(experiment)
