import os
import sys
import numpy as np
import re

from keras.preprocessing.image import DirectoryIterator, img_to_array, array_to_img
from keras import backend as K
from PIL import Image as pil_image
from math import floor


def mk_features_filename(model, layer, dataset, folder=None, reduction=None, source="image_net"):
    """Build the file name of a features archive.

    The name has the form
    ``<model>[_<layer>][_<reduction>][_<source>]_<dataset>[_<folder>].npz``
    where a bracketed component is left out when its value is None or empty.

    # Arguments
        model: model identifier, always the first component
        layer: layer identifier, optional
        dataset: dataset identifier, always present
        folder: dataset folder, optional
        reduction: reduction identifier, optional
        source: source identifier, optional

    # Returns
        The file name, as a string.
    """
    def _is_set(value):
        return value is not None and len(value) > 0

    if _is_set(folder):
        tail = "{}_{}".format(dataset, folder)
    else:
        tail = "{}".format(dataset)
    optional = [field for field in (layer, reduction, source) if _is_set(field)]
    pieces = [model] + optional + [tail]
    return "_".join("{}".format(piece) for piece in pieces) + ".npz"


def load_features_of(path, model, dataset, folder=None, reduction=None, layer=None, source=None):
    """Load the features archive of `path` described by the given fields.

    The file name is derived with `mk_features_filename` and the archive is read
    with `load_features`, whose return value is passed through unchanged.
    """
    archive_name = mk_features_filename(
        model, layer, dataset, folder,
        reduction=reduction,
        source=source
    )
    archive_path = os.path.join(path, archive_name)
    return load_features(archive_path)


def mk_task_filename(task, model, layer, dataset, folder=None, reduction=None, source="image_net"):
    """Build the name of a task file derived from a features file name.

    The task identifier is inserted right before the extension of the name
    produced by `mk_features_filename` for the same fields, i.e.
    ``<features name>_<task>.npz``.

    # Arguments
        task: task identifier appended to the base name
        model, layer, dataset, folder, reduction, source: forwarded to
            `mk_features_filename`

    # Returns
        The file name, as a string.
    """
    filename = mk_features_filename(model, layer, dataset, folder=folder, reduction=reduction, source=source)
    # split on the LAST dot: a component such as a dataset named "v1.0" may
    # itself contain dots and must stay in the base name, not in the extension
    name, ext = filename.rsplit(".", 1)
    return "{}_{}.{}".format(name, task, ext)


def sizeof_fmt(num, suffix='B'):
    """Format a size with a binary (1024-based) unit prefix.

    The value is divided by 1024 until its magnitude drops below 1024 and is
    printed with one decimal, e.g. ``1536`` -> ``'1.5KiB'``. Values too large
    for the 'Zi' prefix are expressed with 'Yi'.
    """
    prefixes = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi')
    position = 0
    while position < len(prefixes):
        if abs(num) < 1024.0:
            return "%3.1f%s%s" % (num, prefixes[position], suffix)
        num /= 1024.0
        position += 1
    return "%.1f%s%s" % (num, 'Yi', suffix)


def load_features(filepath):
    """Read a features archive saved in numpy format.

    # Arguments
        filepath: path of the archive, which must hold the arrays "x_names",
            "y" and "x_trans"

    # Returns
        The tuple (x_names, y, x_trans) of arrays read from the archive.
    """
    with open(filepath, "rb") as handle:
        archive = np.load(handle)
        # the arrays are read while the file is still open
        return tuple(archive[key] for key in ("x_names", "y", "x_trans"))


def extract_prefixes(data, sep="_"):
    """Return, as a numpy array, the integer found before the first `sep` of each string of `data`."""
    prefixes = []
    for item in data:
        head = item.split(sep, 1)[0]
        prefixes.append(int(head))
    return np.array(prefixes)


def print_cm(cm, labels, hide_zeroes=False, hide_diagonal=False, hide_threshold=None, out=None):
    """Pretty-print a confusion matrix as an aligned text table.

    # Arguments
        cm: confusion matrix, indexed as `cm[i, j]`
        labels: list of strings, the class labels (one per row/column of `cm`)
        hide_zeroes: Boolean, whether cells equal to zero are left blank
        hide_diagonal: Boolean, whether diagonal cells are left blank
        hide_threshold: number or None, when given the cells that are not
            strictly greater than this value are left blank (0 is a valid threshold)
        out: stream the table is written to, `sys.stdout` when None
    """
    if out is None:
        # resolved at call time so that a redirected sys.stdout is honoured
        out = sys.stdout
    columnwidth = max([len(x) for x in labels] + [5])  # 5 is value length
    cell_format = "%{0}.1f ".format(columnwidth)
    # a blank cell must be as wide as a printed one (value + trailing space),
    # otherwise the cells that follow it shift to the left
    empty_cell = " " * (columnwidth + 1)
    # Print header
    out.write("    " + " " * columnwidth)
    for label in labels:
        out.write(" %{0}s".format(columnwidth) % label)
    # "\n" and not os.linesep: text streams translate the newline themselves
    out.write("\n")
    # Print rows
    for i, label1 in enumerate(labels):
        out.write("    %{0}s ".format(columnwidth) % label1)
        for j in range(len(labels)):
            value = cm[i, j]
            hidden = (
                (hide_zeroes and float(value) == 0)
                or (hide_diagonal and i == j)
                or (hide_threshold is not None and not value > hide_threshold)
            )
            out.write(empty_cell if hidden else cell_format % value)
        out.write("\n")


def get_features_files(path, dataset, folders, reduction=None, layer=None, source=None):
    """List the features files of `path` that belong to a dataset and to some folders.

    File names are expected to follow the layout produced by `mk_features_filename`:
    ``<prefix>[_<layer>][_<reduction>][_<source>]_<dataset>_<folder>.npz``.
    All the fields are inserted as is in a regular expression.

    # Arguments
        path: directory to scan
        dataset: dataset identifier the file name must contain
        folders: list of accepted folder identifiers; when empty, only names
            without folder component are accepted
        reduction, layer, source: optional components accepted before the dataset

    # Returns
        The list of paths (joined with `path`) of the matching files.
    """
    files = os.listdir(path)
    optional = "".join(
        "(_{})?".format(field) for field in (layer, reduction, source) if field is not None
    )
    # the folder alternatives get their own placeholder, they must not be dropped
    folder_part = "_({})".format("|".join(folders)) if folders else ""
    # the extension dot is escaped and the pattern is anchored at the end of the name
    regex = r"[_a-z0-9]+{}_({}){}\.npz$".format(optional, dataset, folder_part)
    pattern = re.compile(regex)
    return [os.path.join(path, file) for file in files if pattern.match(file) is not None]


def get_matching_feature_files(path, **fields):
    """List the features files of `path` whose name matches the given fields.

    The file name pattern is built with `mk_features_filename`, each field being
    replaced by a named regular expression group. A field that is not given
    matches any value made of lowercase letters, digits and underscores.

    # Arguments
        path: directory to scan
        fields: any of "model", "layer", "dataset", "folder", "reduction" and
            "source". Values are inserted as is in the regular expression.
            "folder" may also be a list (or tuple) of accepted values.
            "layer", "reduction", "source" and "folder" may be None, in which case
            the component is expected to be absent from the file name.

    # Returns
        A list of (file path, dict) tuples, one per matching file. The dict maps
        the name of each field present in the pattern to the text it matched.
    """
    files = os.listdir(path)
    item_regex = r"[_a-z0-9]+"

    folder = fields.get("folder", item_regex)
    if isinstance(folder, (list, tuple)):
        folder = "|".join(folder)

    patterns = {
        "layer": fields.get("layer", item_regex),
        "reduction": fields.get("reduction", item_regex),
        "source": fields.get("source", item_regex),
        "folder": folder,
    }
    # model and dataset are always part of a file name: None means "any value"
    for mandatory in ("model", "dataset"):
        value = fields.get(mandatory, item_regex)
        patterns[mandatory] = item_regex if value is None else value

    # Named groups tie each captured value to its field whatever the order in
    # which mk_features_filename lays the fields out, and whichever optional
    # fields are left out of the pattern.
    groups = {
        name: None if pattern is None else "(?P<{}>{})".format(name, pattern)
        for name, pattern in patterns.items()
    }
    regex = mk_features_filename(
        model=groups["model"],
        layer=groups["layer"],
        dataset=groups["dataset"],
        folder=groups["folder"],
        reduction=groups["reduction"],
        source=groups["source"]
    )
    # mk_features_filename ends the name with ".npz": escape the dot and anchor the end
    regex = regex[:-len(".npz")] + r"\.npz$"
    compiled = re.compile(regex)
    present = [name for name, group in groups.items() if group is not None]

    matching = []
    for file in files:
        match = compiled.match(file)
        if match is not None:
            matching.append((os.path.join(path, file), {name: match.group(name) for name in present}))
    return matching


def get_all_features_files(path, model, dataset, folders, reduction=None, layer=None, source=None):
    """List the features files of a model for a dataset, together with their layer part.

    File names are expected to follow the layout produced by `mk_features_filename`:
    ``<model>_<layer>[_<reduction>][_<source>]_<dataset>_<folder>.npz``.
    All the fields are inserted as is in a regular expression.

    # Arguments
        path: directory to scan
        model: model identifier the file name must start with
        dataset: dataset identifier the file name must contain
        folders: list of accepted folder identifiers; when empty, only names
            without folder component are accepted
        reduction, source: optional components accepted between the layer part
            and the dataset; when given they are not included in the layer part
        layer: unused, the layer part is what this function extracts

    # Returns
        A list of (file path, layer part) tuples, one per matching file.
    """
    files = os.listdir(path)
    optional = "".join(
        "(?:_{})?".format(field) for field in (reduction, source) if field is not None
    )
    # the folder alternatives get their own placeholder, they must not be dropped
    folder_part = "_(?:{})".format("|".join(folders)) if folders else ""
    # The layer part is matched lazily so that the optional reduction/source
    # components are not swallowed by it.
    regex = r"{}_(?P<layer>[_a-z0-9]+?){}_(?:{}){}\.npz$".format(model, optional, dataset, folder_part)
    pattern = re.compile(regex)
    found = []
    for file in files:
        match = pattern.match(file)
        if match is not None:
            found.append((os.path.join(path, file), match.group("layer")))
    return found


def _load_features_sorted_by_name(filepath):
    """Load a features file and order its samples by name."""
    names, y, x = load_features(filepath)
    order = np.argsort(names)
    return names[order], y[order], x[order, :]


def load_merged_features(path, **fields):
    """Merge features matching all the given fields. The missing fields are matched whatever their value.

    The matching files are processed in the order of their sorted paths and the samples
    of each file are sorted by name, so that the feature matrices can be stacked side by side.

    # Arguments
        path: directory containing the features files
        fields: fields forwarded to `get_matching_feature_files`

    # Returns
        A tuple (x_names, y, x_trans, features) where x_names and y are the sorted
        sample names and outputs, x_trans is the horizontal concatenation of the
        feature matrices of all the matching files, and features holds one
        (missing fields info, column index in the source file) item per column of
        x_trans. The missing fields info is the list of (field, value) pairs of the
        source file for the fields that were not given.

    # Raises
        ValueError: if no file matches, or if the matching files do not hold the
            same sample names and outputs.
    """
    matching_data = get_matching_feature_files(path, **fields)
    if len(matching_data) == 0:
        # .get(): the dataset field is optional, its absence must not hide this error
        raise ValueError("No matching file found in path '{}' (dataset: {})".format(path, fields.get("dataset")))

    # sort to make sure features are ordered the same way across selection
    matching_data = sorted(matching_data, key=lambda item: item[0])

    files, data_dict = zip(*matching_data)
    missing_fields = set(data_dict[0].keys()).difference(fields.keys())

    print("> matched files: {}".format([os.path.basename(file) for file in files]))
    print("> fields       : {}".format(fields))
    print("> miss. fields : {}".format(missing_fields))

    # the first file is the reference the other ones are checked against
    x_names, y, first_x = _load_features_sorted_by_name(files[0])
    blocks = [first_x]
    features = [([(k, v) for k, v in data_dict[0].items() if k in missing_fields], i) for i in range(first_x.shape[1])]

    for file, metadata in zip(files[1:], data_dict[1:]):
        other_names, other_y, other_x = _load_features_sorted_by_name(file)
        # array_equal also reports files with a different number of samples as a mismatch
        if not np.array_equal(x_names, other_names) or not np.array_equal(y, other_y):
            raise ValueError("mismatch between feature files '{}' and '{}' ({} and {} -> diff:{}) ({} and {})".format(
                os.path.basename(files[0]),
                os.path.basename(file),
                x_names, other_names, np.setdiff1d(x_names, other_names),
                y, other_y
            ))
        blocks.append(other_x)
        features.extend([([(k, v) for k, v in metadata.items() if k in missing_fields], i) for i in range(other_x.shape[1])])

    # stack once at the end rather than copying the growing matrix for every file
    x_trans = np.hstack(blocks) if len(blocks) > 1 else first_x
    return x_names, y, x_trans, features


def infer_folder_name(dataset, test, **kwargs):
    """Tell whether `test` is the folder name expected for `dataset`.

    The expected name is "val" for the datasets listed below and "test" for any
    other dataset. Extra keyword arguments are ignored.

    # Returns
        A boolean, True when `test` equals the expected name.
    """
    val_datasets = {"ulg_lbtd2_chimio_necrose", "ulg_lbtd_lba", "ulb_anapath_lba", "ulg_lbtd_tissus"}
    expected = "val" if dataset in val_datasets else "test"
    return test == expected


def pyxit_fit(pyxit, svm, x, y, _x=None, _y=None):
    """Fit `pyxit` on (x, y), then fit `svm` on the data transformed by `pyxit`.

    # Arguments
        pyxit: object providing `extract_subwindows`, `fit` and `transform`
        svm: object providing `fit`
        x, y: training inputs and outputs
        _x, _y: already extracted subwindows; they are extracted from (x, y)
            with `pyxit.extract_subwindows` unless both are given
    """
    subwindows_given = _x is not None and _y is not None
    if not subwindows_given:
        _x, _y = pyxit.extract_subwindows(x, y)
    pyxit.fit(x, y, _X=_x, _y=_y)
    transformed = pyxit.transform(x, _X=_x, reset=True)
    svm.fit(transformed, y)


def pyxit_predict(pyxit, svm, x, _x=None):
    """Predict the outputs of `x` with `svm`, on the data transformed by `pyxit`.

    When the subwindows `_x` are not given, they are extracted from `x` with
    `pyxit.extract_subwindows`, using a vector of zeros as placeholder outputs.
    """
    if _x is None:
        placeholder_y = np.zeros(x.shape[0])
        _x, _unused = pyxit.extract_subwindows(x, placeholder_y)
    transformed = pyxit.transform(x, _X=_x)
    return svm.predict(transformed)


def pyxit_decision_function(pyxit, svm, x, _x=None):
    """Evaluate the decision function of `svm` for `x`, on the data transformed by `pyxit`.

    When the subwindows `_x` are not given, they are extracted from `x` with
    `pyxit.extract_subwindows`, using a vector of zeros as placeholder outputs.
    """
    if _x is None:
        placeholder_y = np.zeros(x.shape[0])
        _x, _unused = pyxit.extract_subwindows(x, placeholder_y)
    transformed = pyxit.transform(x, _X=_x)
    return svm.decision_function(transformed)


def load_crop_img(path, load_size, crop_size, grayscale=False, random=True):
    """Load an image with PIL, resize it, then cut a crop out of the resized image.

    # Arguments
        path: Path to image file
        load_size: int or (int, int),
            If an integer, the size the smallest side is resized to, the other side being
            scaled by the same ratio (rounded down). If a tuple, the (height, width) the
            image is resized to.
        crop_size: int or (int, int),
            If an integer, the side of the square crop. If a tuple, the (height, width) of
            the rectangular crop. Should be less than or equal to the size of the resized image.
        grayscale: Boolean, whether to load the image as grayscale ('L' mode) instead of 'RGB'.
        random: Boolean, whether the position of the crop is drawn with `np.random` among the
            positions for which the crop lies in the resized image. Otherwise the crop is
            centered in the resized image.

    # Returns
        A PIL Image instance.

    # Raises
        ImportError: if PIL is not available.
    """
    if pil_image is None:
        raise ImportError('Could not import PIL.Image. '
                          'The use of `array_to_img` requires PIL.')
    wanted_mode = 'L' if grayscale else 'RGB'
    image = pil_image.open(path)
    if image.mode != wanted_mode:
        image = image.convert(wanted_mode)

    # size of the image after the load resize
    orig_w, orig_h = image.size
    if isinstance(load_size, tuple):
        target_h, target_w = load_size
    elif orig_h == orig_w:
        target_h = target_w = load_size
    elif orig_h < orig_w:
        # height is the smallest side
        scale = float(load_size) / orig_h
        target_h, target_w = load_size, int(floor(scale * orig_w))
    else:
        # width is the smallest side
        scale = float(load_size) / orig_w
        target_h, target_w = int(floor(scale * orig_h)), load_size
    image = image.resize((target_w, target_h))

    # size of the crop
    if isinstance(crop_size, tuple):
        crop_h, crop_w = crop_size
    else:
        crop_h = crop_w = crop_size

    # room left around the crop, vertically and horizontally
    resized_w, resized_h = image.size
    slack_h = resized_h - crop_h
    slack_w = resized_w - crop_w
    if random:
        top = np.random.randint(slack_h + 1)
        left = np.random.randint(slack_w + 1)
    else:
        top, left = int(slack_h / 2), int(slack_w / 2)
    box = (left, top, left + crop_w, top + crop_h)
    return image.crop(box)


class ConservativeDirectoryIterator(DirectoryIterator):
    """DirectoryIterator performing random or deterministic rescale and crop on images.
    This iterator can be used to avoid resizing images without taking care about the width/height ratio.
    Let's denote load_size_range and the crop size (target_size) resp. as (l0, l1), with 0 < l0 <= l1, and c.
    To maintain width/height ratio, the extraction of a crop is done as follows:
        - resize the input image so that the smallest dimension has a size taken randomly in the
          interval [l0, l1] (both bounds included)
        - extract the crop of size (c, c):
            * randomly, sampling among all the valid crops that lie in the resized image
            * deterministically, taking the crop at the center of the resized image
    Therefore for deterministic behavior, should set:
        - random_crop: False
        - load_size_range: (a, a) with a > 0

    # Arguments
        directory: forwarded to DirectoryIterator
        image_data_generator: forwarded to DirectoryIterator, also used to transform and
            standardize the crops
        random_crop: Boolean, whether the crops are taken randomly or at the center
        balance_classes: Boolean, whether the samples of an epoch are drawn by first picking a
            class at random, then a sample of this class (with replacement). Disables shuffling.
        load_size_range: (int, int) or None, the (l0, l1) interval described above. When None,
            the interval is reduced to the target size.
        kwargs: forwarded to DirectoryIterator. "target_size" is required and must be one integer.
    """
    def __init__(self, directory, image_data_generator, random_crop=True, balance_classes=False, load_size_range=None, **kwargs):
        if isinstance(kwargs["target_size"], tuple):
            raise ValueError("target size fo 'ConservativeDirectoryIterator' should be one integer.")
        if load_size_range is None:
            # without a range, always load at the crop size instead of failing at batch time
            load_size_range = (kwargs["target_size"], kwargs["target_size"])
        # "shuffle" may be left out by the caller; True mirrors the usual DirectoryIterator
        # default (NOTE(review): confirm against the installed Keras version)
        kwargs["shuffle"] = kwargs.get("shuffle", True) and not balance_classes
        kwargs["target_size"] = (kwargs["target_size"], kwargs["target_size"])
        super(ConservativeDirectoryIterator, self).__init__(directory, image_data_generator, **kwargs)
        self._load_size_range = load_size_range
        self._balance_classes = balance_classes
        self._random_crop = random_crop

    def _get_batches_of_transformed_samples(self, index_array):
        """Build the batch made of the samples at the given indexes.

        # Arguments
            index_array: indexes (in `self.filenames`) of the samples of the batch

        # Returns
            The batch of images alone, or a tuple (images, labels) when `class_mode` is one of
            'input', 'sparse', 'binary' and 'categorical'.
        """
        batch_x = np.zeros((len(index_array),) + self.image_shape, dtype=K.floatx())
        grayscale = self.color_mode == 'grayscale'
        # build batch of image data
        for i, j in enumerate(index_array):
            fname = self.filenames[j]
            # both bounds of the range are included
            load_size = np.random.randint(self._load_size_range[0], self._load_size_range[1] + 1)
            img = load_crop_img(os.path.join(self.directory, fname),
                                load_size=load_size,
                                crop_size=self.target_size,
                                random=self._random_crop,
                                grayscale=grayscale)
            x = img_to_array(img, data_format=self.data_format)
            x = self.image_data_generator.random_transform(x)
            x = self.image_data_generator.standardize(x)
            batch_x[i] = x
        # optionally save augmented images to disk for debugging purposes
        if self.save_to_dir:
            for i, j in enumerate(index_array):
                img = array_to_img(batch_x[i], self.data_format, scale=True)
                fname = '{prefix}_{index}_{hash}.{format}'.format(prefix=self.save_prefix,
                                                                  index=j,
                                                                  hash=np.random.randint(int(1e7)),
                                                                  format=self.save_format)
                img.save(os.path.join(self.save_to_dir, fname))
        # build batch of labels
        if self.class_mode == 'input':
            batch_y = batch_x.copy()
        elif self.class_mode == 'sparse':
            batch_y = self.classes[index_array]
        elif self.class_mode == 'binary':
            batch_y = self.classes[index_array].astype(K.floatx())
        elif self.class_mode == 'categorical':
            batch_y = np.zeros((len(batch_x), self.num_classes), dtype=K.floatx())
            for i, label in enumerate(self.classes[index_array]):
                batch_y[i, label] = 1.
        else:
            return batch_x
        return batch_x, batch_y

    def _flow_index(self):
        """Generate, endlessly, the arrays of sample indexes of the successive batches.

        At the beginning of each pass over the data the index array is rebuilt: with class
        balancing, a class is drawn for each of the `self.n` positions and a sample of this
        class is then drawn for the position; otherwise `self._set_index_array()` is used.
        """
        # Ensure self.batch_index is 0.
        self.reset()
        class_indexes = dict()
        classes = np.unique(self.classes)
        for cls in classes:
            class_indexes[cls] = np.where(self.classes == cls)[0]
        while 1:
            if self.seed is not None:
                np.random.seed(self.seed + self.total_batches_seen)
            if self.batch_index == 0:
                if not self.shuffle and self._balance_classes:
                    selected_classes = np.random.choice(classes, size=self.n)
                    # Fill a fresh array rather than self.index_array, which may not be
                    # allocated yet on the first pass. Every position is assigned since
                    # each of them belongs to exactly one of the drawn classes.
                    balanced_index = np.empty(self.n, dtype=int)
                    for cls in np.unique(selected_classes):
                        selected_idx = selected_classes == cls
                        balanced_index[selected_idx] = np.random.choice(
                            class_indexes[cls],
                            size=np.count_nonzero(selected_idx)
                        )
                    self.index_array = balanced_index
                else:
                    self._set_index_array()

            current_index = (self.batch_index * self.batch_size) % self.n
            if self.n > current_index + self.batch_size:
                self.batch_index += 1
            else:
                self.batch_index = 0
            self.total_batches_seen += 1
            yield self.index_array[current_index:
                                   current_index + self.batch_size]


def conservative_flow_from_directory(image_data_generator, directory, load_size_range=(224, 225), target_size=224,
                                     color_mode='rgb', classes=None, class_mode='categorical', batch_size=32, shuffle=False,
                                     seed=None, save_to_dir=None, save_prefix='', save_format='png', follow_links=False,
                                     balance_classes=False, random_crop=True):
    """Return a ConservativeDirectoryIterator iterating over a directory.

    All the arguments are forwarded unchanged to the iterator, whose data format is the one
    of `image_data_generator`. With the default values, samples are not shuffled and crops
    are taken randomly (`random_crop` is True).
    """
    iterator_options = dict(
        load_size_range=load_size_range,
        target_size=target_size,
        color_mode=color_mode,
        classes=classes,
        class_mode=class_mode,
        data_format=image_data_generator.data_format,
        batch_size=batch_size,
        shuffle=shuffle,
        seed=seed,
        save_to_dir=save_to_dir,
        save_prefix=save_prefix,
        save_format=save_format,
        follow_links=follow_links,
        balance_classes=balance_classes,
        random_crop=random_crop,
    )
    return ConservativeDirectoryIterator(directory, image_data_generator, **iterator_options)


if __name__ == "__main__":
    # Manual check: merge the features of the training folder of a dataset, then count
    # the merged feature columns per value of the first field that was not given.
    merged = load_merged_features(
        path="[...]/features/",
        layer=None,
        folder="train",
        reduction=None,
        dataset="glomeruli_no_aug"
    )
    _, _, merged_x, feature_info = merged
    print(merged_x.shape)

    column_counts = {}
    for origin, _ in feature_info:
        field, value = origin[0]
        key = (field, value)
        if key in column_counts:
            column_counts[key] += 1
        else:
            column_counts[key] = 1
    print(column_counts)