from clustertools import Computation, ParameterSet, Experiment, ConstrainedParameterSet, set_stdout_logging
from clustertools.environment import Environment, SlurmEnvironment, InSituEnvironment, BashEnvironment


class PyxitCV(Computation):
    """Grouped k-fold cross-validation of a Pyxit feature extractor
    (extra-trees on random subwindows) followed by a linear SVM."""

    def run(self, **dict_params):
        """Run a GroupKFold cross-validation of PyxitClassifier + LinearSVC.

        Keyword parameters read by this method (via ``params.<name>``):
        - path, dataset: dataset directory is ``os.path.join(path, dataset)``
        - train: name of the training sub-directory inside the dataset
        - seed: seed given to ``check_random_state``
        - n_estimators: number of extra-trees
        - min_samples_leaf: ``min_samples_leaf`` of the extra-trees
        - total_subwindows: total number of subwindows for the training set;
          each image gets ``ceil(total_subwindows / n_samples)`` subwindows
        - min_size, max_size: subwindow size bounds given to PyxitClassifier
        - colorspace: colorspace identifier given to PyxitClassifier
        - n_jobs: parallelism for the extra-trees and for PyxitClassifier
        - n_splits: number of GroupKFold folds
        - c: regularization parameter ``C`` of the LinearSVC

        The extra-trees ``max_features`` is not a parameter: it is fixed to
        half the number of pixels of a ``target_size x target_size`` window.

        Returns
        -------
        dict
            Per-fold accuracies and relative test sizes, the test-size weighted
            accuracy (``accuracy``) and the hyper-parameters that were used.
        """
        import os
        import sys
        import time
        import math
        from sklearn import clone
        from sklearn.svm import LinearSVC
        from sklearn.metrics import accuracy_score
        from sklearn.model_selection import GroupKFold
        from dataset import ImageClassificationDataset
        from pyxit import PyxitClassifier
        from sklearn.ensemble import ExtraTreesClassifier
        from sklearn.utils import check_random_state
        from collections import namedtuple
        from util import sizeof_fmt, extract_prefixes
        import numpy as np

        def window_indexes(n, idx, count):
            """Return the row indexes, in the stacked subwindow arrays, of the
            ``count`` subwindows of every sample selected by ``idx``.

            Assumes (TODO confirm against PyxitClassifier.extract_subwindows)
            that the subwindows of sample ``k`` occupy rows
            ``k * count .. k * count + count - 1``.
            """
            increment = np.tile(np.arange(count), (idx.shape[0], 1)).flatten()
            starts = np.repeat(np.arange(n)[idx], count)
            return starts * count + increment

        def fit(pyxit, svm, x, y, _x, _y):
            """Fit the Pyxit trees, then the SVM on the Pyxit-transformed x.
            Subwindows are extracted here only if ``_x``/``_y`` are missing."""
            if _x is None or _y is None:
                _x, _y = pyxit.extract_subwindows(x, y)
            pyxit.fit(x, y, _X=_x, _y=_y)
            Xt = pyxit.transform(x, _X=_x, reset=True)
            svm.fit(Xt, y)

        def predict(pyxit, svm, x, _x):
            """Predict labels of ``x`` with the fitted Pyxit + SVM pipeline.
            Subwindows are extracted (with dummy labels) if ``_x`` is None."""
            if _x is None:
                y = np.zeros(x.shape[0])
                _x, _ = pyxit.extract_subwindows(x, y)
            xt = pyxit.transform(x, _X=_x)
            return svm.predict(xt)

        params_type = namedtuple("params_type", dict_params.keys())
        params = params_type(**dict_params)
        print("Parameters: {}".format(params))

        print("Load dataset...")
        dataset_path = os.path.join(params.path, params.dataset)
        dataset = ImageClassificationDataset(dataset_path, dirs=[params.train])
        x_train, y_train = dataset.all(dirs=[params.train], one_hot=False)
        # CV groups derived from the file name prefixes (presumably one group
        # per patient/slide, so that related images never straddle folds).
        labels = extract_prefixes([os.path.basename(s) for s in x_train])
        n_samples = x_train.shape[0]

        print("> x_train: {}".format(x_train.shape))
        print("> y_train: {}".format(y_train.shape))

        print("Build model...")
        random_state = check_random_state(params.seed)
        target_size = 16
        # Computed once so the reported value is exactly the one used (int).
        max_features = (target_size * target_size) // 2
        et = ExtraTreesClassifier(
            n_estimators=params.n_estimators,
            min_samples_leaf=params.min_samples_leaf,
            max_features=max_features,
            random_state=check_random_state(random_state.tomaxint() % (2 ** 32)),
            n_jobs=params.n_jobs
        )
        # float() avoids Python 2 floor division before the ceil.
        n_subwindows = int(math.ceil(float(params.total_subwindows) / n_samples))
        pyxit = PyxitClassifier(
            base_estimator=et,
            n_subwindows=n_subwindows,
            min_size=params.min_size,
            max_size=params.max_size,
            target_height=target_size,
            target_width=target_size,
            n_jobs=params.n_jobs,
            colorspace=params.colorspace,
            random_state=check_random_state(random_state.tomaxint() % (2 ** 32))
        )
        svm = LinearSVC(C=params.c)

        # Subwindows are extracted once for the whole training set and then
        # sliced per fold with window_indexes.
        print("Extract subwindows...")
        sys.stdout.flush()
        _x_train, _y_train = pyxit.extract_subwindows(x_train, y_train)

        print("_X_train: {} (size: {})".format(_x_train.shape, sizeof_fmt(_x_train.nbytes)))
        print("_y_train: {} (size: {})".format(_y_train.shape, sizeof_fmt(_y_train.nbytes)))

        # Fit
        print("Start cross-validation...")
        cv = GroupKFold(n_splits=params.n_splits)
        accuracies = np.zeros(params.n_splits)
        test_sizes = np.zeros(params.n_splits)

        # CV loop
        for i, (train, test) in enumerate(cv.split(x_train, y_train, labels)):
            print("Loop {}/{}... ".format(i + 1, params.n_splits))
            _pyxit = clone(pyxit)
            _svm = clone(svm)
            print("> train (size: train/test {}/{})...".format(train.shape[0], test.shape[0]))
            w_train = window_indexes(n_samples, train, _pyxit.n_subwindows)
            w_test = window_indexes(n_samples, test, _pyxit.n_subwindows)
            fit_start = time.time()
            fit(_pyxit, _svm, x_train[train], y_train[train], _x_train[w_train], _y_train[w_train])
            fit_time = time.time() - fit_start
            print("> predict...")
            pred_start = time.time()
            y_pred = predict(_pyxit, _svm, x_train[test], _x_train[w_test])
            pred_time = time.time() - pred_start
            accuracies[i] = accuracy_score(y_train[test], y_pred)
            test_sizes[i] = test.shape[0] / float(n_samples)
            print("> accuracy: {}".format(accuracies[i]))
            print("> time    : fit:{}s pred:{}s total:{}s".format(fit_time, pred_time, fit_time + pred_time))
            sys.stdout.flush()

        results = {
            "accuracies": accuracies,
            "test_sizes": test_sizes,
            # test_sizes sum to 1, so this is the test-size weighted mean.
            "accuracy": np.sum(accuracies * test_sizes),
            "n_estimators": params.n_estimators,
            "min_samples_leaf": params.min_samples_leaf,
            "max_features": max_features,
            "total_subwindows": params.total_subwindows,
            "n_subwindows": n_subwindows,
            "min_size": params.min_size,
            "max_size": params.max_size,
            "target_height": target_size,
            "colorspace": params.colorspace,
            "C": params.c
        }
        print(results)
        return results


if __name__ == "__main__":
    # Script entry point: build the PyxitCV parameter grid and run it on Slurm.
    set_stdout_logging()
    param_set = ConstrainedParameterSet()

    # First dataset group, together with the swept hyper-parameters.
    param_set.add_parameters(dataset=["ulg_lbtd2_chimio_necrose"])
    param_set.add_parameters(colorspace=[1, 2])  # TRGB and HSV
    param_set.add_parameters(
        min_size=[0.0, 0.25, 0.5, 0.75],
        max_size=[0.25, 0.5, 0.75, 1.0]
    )
    # Keep only (min_size, max_size) pairs with min_size < max_size.
    param_set.add_constraints(
        wsize_constrain=lambda min_size, max_size, **kwargs: min_size < max_size
    )

    # Fixed parameters shared by every computation.
    n_jobs = 8
    param_set.add_parameters(
        path="[...]/datasets",
        seed=42, train="train", total_subwindows=1000000,
        n_estimators=20, n_jobs=n_jobs, min_samples_leaf=1000,
        n_splits=5, c=0.01
    )

    # Remaining dataset groups, each introduced by a separator.
    later_dataset_groups = (
        [
            "ulg_lbtd_lba",
            "ulb_anapath_lba",
            "ulg_lbtd_tissus",
            "cells_no_aug",
            "patterns_no_aug",
            "glomeruli_no_aug"
        ],
        ["ulg_breast"],
        ["ulg_lbtd_lba_new"],
    )
    for dataset_group in later_dataset_groups:
        param_set.add_separator()
        param_set.add_parameters(dataset=dataset_group)

    experiment = Experiment("pyxit_cv", param_set, PyxitCV)
    slurm = SlurmEnvironment(
        time="16-23",
        memory="15900M",
        partition="[...]",
        n_proc=n_jobs
    )
    slurm.run(experiment)
