from __future__ import print_function
import os
import numpy as np
from argparse import ArgumentParser
import re
import sys

import time
from dataset import ImageClassificationDataset
from keras import Model
from keras.applications.nasnet import NASNetLarge, preprocess_input as nasnet_preprocess
from keras.applications.nasnet import NASNetMobile
from keras.applications.densenet import DenseNet, preprocess_input as densenet_preprocess
from keras.applications.resnet50 import ResNet50, preprocess_input as resnet50_preprocess
from keras.applications.inception_v3 import InceptionV3, preprocess_input as inceptionv3_preprocess
from keras.applications.vgg16 import VGG16, preprocess_input as vgg16_preprocess
from keras.applications.vgg19 import VGG19, preprocess_input as vgg19_preprocess
from keras.applications.inception_resnet_v2 import InceptionResNetV2, preprocess_input as inceptionresnetv2_preprocess
from keras.applications.mobilenet import MobileNet, preprocess_input as mobilenet_preprocess
from keras.layers import GlobalAveragePooling2D, Lambda, Concatenate

from util import mk_features_filename, load_crop_img
import tensorflow as tf


def cytonet_preprocess_input(x, **kwargs):
    """Zero-center a batch of RGB images for the CytoNet model.

    The batch is cast to float64, its last axis is reversed (RGB -> BGR)
    and a fixed mean is subtracted from each resulting BGR channel.

    # Arguments
        x: input Numpy tensor, 4D, channels last.
        **kwargs: accepted for signature compatibility with the Keras
            ``preprocess_input`` functions; ignored.

    # Returns
        Preprocessed float64 tensor with the same shape as ``x``.
    """
    # Per-channel means, in BGR order (applied after the channel flip).
    bgr_means = (161.37053462, 110.10437749, 141.10102787)
    flipped = x.astype(np.float64)[:, :, :, ::-1]
    for channel, mean in enumerate(bgr_means):
        flipped[:, :, :, channel] -= mean
    return flipped


def crop_load_all(filenames, load_size_range=(225, 325), crop_size=224, crops_per_image=1, grayscale=False, random_crop=False):
    """Load ``crops_per_image`` square crops from every file in ``filenames``.

    For each crop, a load size is drawn uniformly from the inclusive range
    ``[load_size_range[0], load_size_range[1]]``. It is passed to
    ``load_crop_img`` together with the crop size, the grayscale flag and
    the random-crop flag.

    Returns:
        uint8 array of shape ``(len(filenames) * crops_per_image, crop_size,
        crop_size, 1 if grayscale else 3)``. The crops of the i-th file occupy
        rows ``i * crops_per_image`` up to ``(i + 1) * crops_per_image - 1``.
    """
    low, high = load_size_range[0], load_size_range[1]
    channels = 1 if grayscale else 3
    out = np.zeros((len(filenames) * crops_per_image, crop_size, crop_size, channels), dtype=np.uint8)
    row = 0
    for path in filenames:
        for _ in range(crops_per_image):
            out[row] = load_crop_img(
                path,
                load_size=np.random.randint(low, high + 1),
                crop_size=crop_size,
                grayscale=grayscale,
                random=random_crop
            )
            row += 1
    return out


def batch_transform(loader, features, x, batch_size=96):
    """Turn a sequence of inputs into a feature matrix, one batch at a time.

    ``x`` is processed in consecutive slices of at most ``batch_size`` items.
    Each slice is materialised with ``loader.transform`` and then encoded
    with ``features.transform``. Progress ("done/total") is written to stdout.

    Args:
        loader: object whose ``transform(x_slice)`` returns the model input for
            that slice (e.g. ``ImageLoader``). It must yield one image per item.
        features: object whose ``transform(imgs)`` returns one feature row per
            image. Its ``_model.layers[-1].output_shape`` sizes the output.
        x: sliceable sequence (list or array) of inputs, e.g. image filenames.
        batch_size: maximum number of items handed to the loader at once.

    Returns:
        float32 array of shape ``(len(x), n_features)``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
    n_samples = len(x)
    # Feature width = product of the last layer's non-batch output dimensions.
    n_features = int(np.prod(features._model.layers[-1].output_shape[1:]))
    x_trans = np.zeros((n_samples, n_features), dtype=np.float32)
    for start in range(0, n_samples, batch_size):
        end = min(n_samples, start + batch_size)
        imgs = loader.transform(x[start:end])
        x_trans[start:end, :] = features.transform(imgs)
        del imgs  # drop the decoded batch before loading the next one
        # Report how many items are done, so the last line reads N/N.
        print("\r{}/{}".format(end, n_samples), end="")
        sys.stdout.flush()
    print()
    return x_trans


class ImageLoader(object):
    """Stateless transformer that maps image filenames to fixed-size uint8 crops.

    Follows the scikit-learn ``fit``/``transform`` convention. ``fit`` does
    nothing, and ``transform`` forwards to ``crop_load_all`` with the options
    given at construction time. Grayscale is never requested, so images are
    loaded with 3 channels.
    """

    def __init__(self, load_size_range=(225, 325), crop_size=224, n_crops_per_image=1, random_crop=False):
        # Options forwarded verbatim to ``crop_load_all``.
        self._random_crop = random_crop
        self._n_crops_per_image = n_crops_per_image
        self._crop_size = crop_size
        self._load_size_range = load_size_range

    def fit(self, X, y=None, **fit_params):
        """No-op; returns ``self`` so the loader composes in fit/transform chains."""
        return self

    def transform(self, X):
        """Load ``n_crops_per_image`` crops of ``crop_size`` pixels for every filename in ``X``."""
        options = dict(
            load_size_range=self._load_size_range,
            crop_size=self._crop_size,
            crops_per_image=self._n_crops_per_image,
            random_crop=self._random_crop,
        )
        return crop_load_all(X, **options)


# Backbone identifiers accepted by ``PretrainedModelFeatures(model=...)``.

# ResNet / VGG
MODEL_RESNET50 = "resnet50"
MODEL_VGG16 = "vgg16"
MODEL_VGG19 = "vgg19"

# Inception family
MODEL_INCEPTION_V3 = "inception_v3"
MODEL_INCEPTION_RESNET_V2 = "inception_resnet_v2"

# MobileNet, DenseNet and NASNet variants
MODEL_MOBILE = "mobile"
MODEL_DENSE_NET_201 = "dense_net_201"
MODEL_NASNET_LARGE = "nas_net_large"
MODEL_NASNET_MOBILE = "nas_net_mobile"


class PretrainedModelFeatures(object):
    """Scikit-learn-style feature extractor built on a pretrained Keras backbone.

    The backbone is instantiated without its classification head
    (``include_top=False``). Optionally:

    - the output is taken from an intermediate layer;
    - it is restricted to a subset of channels;
    - it is globally average-pooled.

    ``transform`` applies the backbone-specific preprocessing, runs
    ``predict`` and returns one flattened feature row per image. It records
    the wall-clock time of both steps.

    Instances are picklable. The Keras model is left out of the pickled
    state and rebuilt from the stored configuration on unpickling.
    """

    def __init__(self, model=MODEL_RESNET50, layer="last", reduction="avg", weights="imagenet", filters=None):
        """
        Args:
            model: one of the ``MODEL_*`` identifiers.
            layer: name of the layer whose output is used, or ``"last"`` for
                the headless backbone's final output.
            reduction: ``"avg"`` appends global average pooling; any other
                value keeps the raw layer output.
            weights: ``"imagenet"``, or a weights file loaded by layer name on
                top of the ImageNet weights. ``None`` also keeps the ImageNet
                weights.
            filters: optional iterable of indices along the last axis of the
                selected layer's output to keep (assumed channels-last).
        """
        self._model_name = model
        self._weights = weights
        self._layer = layer
        self._reduction = reduction
        self._filters = filters
        self._model = self._get_model(
            self._model_name,
            layer=self._layer,
            reduction=self._reduction,
            weights=self._weights,
            filters=self._filters
        )
        self._forward_times = list()
        self._preproc_times = list()

    @property
    def forward_times(self):
        """Seconds spent in ``model.predict``, one entry per ``transform`` call."""
        return self._forward_times

    @property
    def preproc_times(self):
        """Seconds spent in preprocessing, one entry per ``transform`` call."""
        return self._preproc_times

    def fit(self, X, y=None, **fit_params):
        """No-op; the backbone is pretrained. Returns ``self``."""
        return self

    def transform(self, X):
        """Compute features for a batch of images.

        Args:
            X: numpy array of images whose spatial size matches
                ``_get_input_shape``. Presumably uint8 ``(N, H, W, 3)`` as
                produced by ``ImageLoader`` (TODO confirm for other callers).

        Returns:
            2-D array with one flattened feature row per image.
        """
        start = time.time()
        # np.float64 rather than the ``np.float`` alias, which NumPy 1.24 removed
        # (``np.float`` was the builtin float, i.e. float64, so results are unchanged).
        X = self._get_preprocessing(self._model_name)(X.astype(np.float64))
        self._preproc_times.append(time.time() - start)
        start = time.time()
        features = self._model.predict(X)
        self._forward_times.append(time.time() - start)
        return features.reshape((X.shape[0], -1))

    def __setstate__(self, state):
        """Restore the configuration and rebuild the Keras model from it."""
        self.__dict__ = state
        self._model = self._get_model(
            self._model_name,
            layer=self._layer,
            reduction=self._reduction,
            weights=self._weights,
            # Pass the channel selection too, so the rebuilt model matches the
            # one built in __init__. getattr tolerates states lacking the field.
            filters=getattr(self, "_filters", None)
        )

    def __getstate__(self):
        """Return the instance state with the Keras model excluded.

        A copy is returned so that pickling does not clear ``_model`` on this
        live instance.
        """
        state = self.__dict__.copy()
        state["_model"] = None
        return state

    @classmethod
    def _get_model(cls, model_name=None, layer="last", reduction="avg", weights="imagenet", filters=None):
        """Build the Keras feature-extraction model.

        Args:
            model_name: one of the ``MODEL_*`` identifiers.
            layer: layer name to take the output from, or ``"last"``.
            reduction: ``"avg"`` appends ``GlobalAveragePooling2D``.
            weights: ``"imagenet"`` or ``None`` keep the ImageNet weights; any
                other value is a weights file loaded with ``by_name=True``.
            filters: optional iterable of last-axis indices to keep, in order.

        Returns:
            A ``keras.Model`` whose input shape is ``_get_input_shape(model_name)``.

        Raises:
            ValueError: unknown ``model_name`` or empty ``filters``.
            KeyError: ``layer`` is not the name of a backbone layer.
        """
        input_shape = cls._get_input_shape(model_name)
        # Every backbone starts headless from ImageNet weights. Custom weights
        # (if any) are applied at the end, matched by layer name.
        if model_name == MODEL_INCEPTION_V3:
            model = InceptionV3(input_shape=input_shape, include_top=False, weights="imagenet")
        elif model_name == MODEL_RESNET50:
            model = ResNet50(input_shape=input_shape, include_top=False, weights="imagenet")
        elif model_name == MODEL_INCEPTION_RESNET_V2:
            model = InceptionResNetV2(input_shape=input_shape, include_top=False, weights="imagenet")
        elif model_name == MODEL_MOBILE:
            model = MobileNet(input_shape=input_shape, include_top=False, weights="imagenet")
        elif model_name == MODEL_VGG16:
            model = VGG16(input_shape=input_shape, include_top=False, weights="imagenet")
        elif model_name == MODEL_VGG19:
            model = VGG19(input_shape=input_shape, include_top=False, weights="imagenet")
        elif model_name == MODEL_DENSE_NET_201:
            # This block configuration is what selects the DenseNet-201 variant.
            blcks = [6, 12, 48, 32]
            model = DenseNet(blocks=blcks, input_shape=input_shape, include_top=False, weights="imagenet")
        elif model_name == MODEL_NASNET_LARGE:
            model = NASNetLarge(input_shape=input_shape, include_top=False, weights="imagenet")
        elif model_name == MODEL_NASNET_MOBILE:
            model = NASNetMobile(input_shape=input_shape, include_top=False, weights="imagenet")
        else:
            raise ValueError("Error: no such model '{}'...".format(model_name))
        if layer != "last":
            layers_by_name = {lyr.name: lyr for lyr in model.layers}
            model = Model(inputs=model.inputs, outputs=[layers_by_name[layer].output])
        if filters is not None:
            model = Model(inputs=model.inputs, outputs=[cls._select_channels(model, filters)])
        if reduction == "avg":
            glb_avg = GlobalAveragePooling2D()(model.output)
            model = Model(inputs=model.inputs, outputs=[glb_avg])
        if weights is not None and weights != "imagenet":
            model.load_weights(weights, by_name=True)
        return model

    @staticmethod
    def _select_channels(model, filters):
        """Return a tensor holding only the ``filters`` channels of ``model.output``.

        Assumes a 4-D channels-last output ``(batch, H, W, C)``, as implied by
        the ``t[:, :, :, channel]`` slicing. Each channel is cut out by its own
        ``Lambda`` layer and the slices are concatenated along the last axis.
        """
        channels = list(filters)
        if not channels:
            raise ValueError("Error: 'filters' must contain at least one channel index")
        # Keras' Lambda ``output_shape`` excludes the batch axis: pass (H, W, 1).
        slice_shape = tuple(model.output_shape[1:3]) + (1,)

        def _slice(channel):
            # A separate function so each Lambda captures its own ``channel``.
            return Lambda(
                function=lambda t: tf.expand_dims(t[:, :, :, channel], axis=3),
                output_shape=slice_shape
            )(model.output)

        slices = [_slice(channel) for channel in channels]
        # Keras' Concatenate needs at least two inputs.
        if len(slices) == 1:
            return slices[0]
        return Concatenate(axis=-1)(slices)

    @staticmethod
    def _get_preprocessing(model_name=None):
        """Return the input-preprocessing function for ``model_name``.

        Raises:
            ValueError: unknown ``model_name``.
        """
        if model_name == MODEL_INCEPTION_V3:
            return inceptionv3_preprocess
        elif model_name == MODEL_RESNET50:
            return resnet50_preprocess
        elif model_name == MODEL_INCEPTION_RESNET_V2:
            return inceptionresnetv2_preprocess
        elif model_name == MODEL_MOBILE:
            return mobilenet_preprocess
        # NOTE(review): the VGG variants use mode="tf" instead of the Keras
        # default for these modules ("caffe"); confirm this matches the weights used.
        elif model_name == MODEL_VGG16:
            return lambda x: vgg16_preprocess(x, mode="tf")
        elif model_name == MODEL_VGG19:
            return lambda x: vgg19_preprocess(x, mode="tf")
        elif model_name == MODEL_DENSE_NET_201:
            return densenet_preprocess
        elif model_name == MODEL_NASNET_LARGE or model_name == MODEL_NASNET_MOBILE:
            return nasnet_preprocess
        else:
            raise ValueError("Error: no such model '{}'...".format(model_name))

    @staticmethod
    def _get_input_shape(model_name=None):
        """Return the ``(height, width, channels)`` input shape used for ``model_name``.

        Raises:
            ValueError: unknown ``model_name``.
        """
        if model_name == MODEL_INCEPTION_V3:
            return 299, 299, 3
        elif model_name == MODEL_RESNET50:
            return 224, 224, 3
        elif model_name == MODEL_INCEPTION_RESNET_V2:
            return 299, 299, 3
        elif model_name == MODEL_MOBILE:
            return 224, 224, 3
        elif model_name == MODEL_VGG16:
            return 224, 224, 3
        elif model_name == MODEL_VGG19:
            return 224, 224, 3
        elif model_name == MODEL_DENSE_NET_201:
            return 224, 224, 3
        elif model_name == MODEL_NASNET_LARGE:
            return 331, 331, 3
        elif model_name == MODEL_NASNET_MOBILE:
            return 224, 224, 3
        else:
            raise ValueError("Error: no such model '{}'...".format(model_name))


def transform_data(dataset, loader, features, folder):
    """Extract features for every sample of one dataset folder.

    Args:
        dataset: object exposing ``all(dirs=[...])``, which returns ``(x, y)``.
            ``x`` holds the samples' file paths (they are passed through
            ``os.path.basename``) and ``y`` the associated targets.
        loader: image loader handed to ``batch_transform``.
        features: feature extractor handed to ``batch_transform``.
        folder: name of the dataset folder to process.

    Returns:
        Tuple ``(x_names, y, x_trans)``:
        - ``x_names``: object array of file base names;
        - ``y``: the targets exactly as returned by the dataset;
        - ``x_trans``: the float32 feature matrix from ``batch_transform``.
    """
    x, y = dataset.all(dirs=[folder])
    x_trans = batch_transform(loader, features, x, batch_size=128)
    print("> folder '{}': {} -> {}".format(folder, len(x), x_trans.shape))
    # Builtin ``object`` instead of the ``np.object`` alias removed in NumPy 1.24
    # (same dtype, so the saved arrays are unchanged).
    x_names = np.array([os.path.basename(f) for f in x], dtype=object)
    return x_names, y, x_trans


def _print_timing(label, durations):
    """Print the total, mean and standard deviation of ``durations`` (seconds) under ``label``."""
    print("{}: t:{:3.4f}s m:{:3.4f}s +- {:3.4f}s".format(
        label, np.sum(durations), np.mean(durations), np.std(durations)))


def main(argv):
    """Extract pretrained-CNN features for dataset folders and save them to disk.

    For each folder in ``--folders`` (comma separated), images are loaded as
    single crops at the model's input size. They are passed through the
    pretrained model with average pooling and written with ``np.savez`` into
    ``--dest``, under the name built by ``mk_features_filename``. The archive
    holds three arrays:

    - ``x_names``: image base names;
    - ``y``: labels;
    - ``x_trans``: features.

    Timing statistics of the preprocessing and forward passes are printed at
    the end.

    Args:
        argv: command-line arguments, without the program name.
    """
    parser = ArgumentParser()

    # Model loading and saving
    parser.add_argument("--path", dest="path", required=True)
    parser.add_argument("--model", dest="model", required=True)
    parser.add_argument("--dataset", dest="dataset", required=True)
    parser.add_argument("--folders", dest="folders", type=str, required=True)
    parser.add_argument("--layer", dest="layer", default="last")
    parser.add_argument("--source", dest="source", default=None)
    parser.add_argument("--weights", dest="weights", default="imagenet")
    parser.add_argument("--reduction", dest="reduction", default=None)
    parser.add_argument("--random_crop", dest="random_crop", action="store_true")
    parser.add_argument("--dest", dest="dest", required=True)
    parser.set_defaults(random_crop=False)
    params, _ = parser.parse_known_args(argv)
    # Raw string: "\s" in a plain literal is an invalid escape sequence.
    params.folders = re.split(r"\s*,\s*", params.folders.strip(" "))
    print("Parameters: {}".format(params))

    if not os.path.exists(params.dest):
        os.makedirs(params.dest)

    dataset = ImageClassificationDataset(os.path.join(params.path, params.dataset), dirs=params.folders)
    # NOTE(review): features are always average-pooled here, while --reduction
    # (default None) only feeds the output filename. Confirm this is intended.
    features = PretrainedModelFeatures(model=params.model, layer=params.layer, reduction="avg", weights=params.weights)
    input_shape = features._get_input_shape(params.model)
    loader = ImageLoader(
        load_size_range=(input_shape[0], input_shape[0]),
        crop_size=input_shape[0],
        random_crop=params.random_crop  # ideally should be False
    )

    for folder in params.folders:
        x_names, y, x_trans = transform_data(dataset, loader, features, folder)
        filename = mk_features_filename(
            model=params.model,
            layer=params.layer,
            dataset=params.dataset,
            folder=folder,
            reduction=params.reduction,
            source=params.source
        )
        filepath = os.path.join(params.dest, filename)
        with open(filepath, "wb+") as fh:
            np.savez(fh, x_names=x_names, y=y, x_trans=x_trans)

    print("Times:")
    forward = features.forward_times
    preproc = features.preproc_times
    both = np.array(forward) + np.array(preproc)
    _print_timing("Forward", forward)
    _print_timing("Preproc", preproc)
    _print_timing("Both   ", both)


if __name__ == "__main__":
    # main() expects only the arguments, so drop the program name.
    main(argv=sys.argv[1:])
