import os
import sys
import keras
import numpy as np

from keras import backend as K
from keras import layers, models, optimizers


from keras.models import Model
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint, CSVLogger, TensorBoard
from keras.layers import Dense, Flatten, Input, Activation, BatchNormalization, Conv2D, Lambda, Reshape


def HitNet(input_shape, num_classes, dim_capsule=16):
	"""
	Assemble the HitNet classification model: a basic network built around the Hit or Miss (HoM) layer.

	The model chains three subnets around the HoM layer:
		- Feature extractor subnet : turns the raw input into features that feed the HoM layer.
		- Classification subnet : from the HoM capsules, computes one distance per class to the center of the target.
		- Decoder subnet : from the HoM capsules masked by the label input, reconstructs the input.

	Model inputs :	main_input: shape=[None] + input_shape (e.g. [None, image_height, image_width, image_channels])
					label_input: shape=[None, num_classes], used only to mask the capsules fed to the decoder
	Model outputs:	prediction: shape=[None, num_classes]
					reconstruction: shape=[None] + input_shape

	Parameters:	input_shape: tuple giving the shape of one input sample (without the batch dimension)
				num_classes: number of classes, i.e. number of capsules in the HoM layer
				dim_capsule: number of features per capsule (default 16)

	Returns: an uncompiled keras Model; the training loss is not defined in this function.

	Note :	The feature extractor subnet and the decoder subnet can be replaced to fit the input data
			or to use more complex layers.

	"""

	# Classification path: image -> features -> HoM capsules -> per-class distance.
	image_in = Input(shape=input_shape, dtype='float32', name='main_input')
	extracted = _Feature_Extractor_SubNet(image_in)
	capsules = HoM_Layer(extracted, num_classes, dim_capsule)
	class_scores = Classification_SubNet(capsules)

	# Reconstruction path: the label input selects which capsule(s) survive
	# the mask before the decoder rebuilds the input.
	label_in = Input(shape=(num_classes,), name="label_input")
	selected_capsules = Mask()([capsules, label_in])
	rebuilt = _Decoder_SubNet(selected_capsules, input_shape)

	return Model(inputs=[image_in, label_in], outputs=[class_scores, rebuilt])






def _Feature_Extractor_SubNet(input_layer):
	"""
	CAN BE MODIFIED BY THE USER (to fit a particular problem; for example, it can be replaced by a ResNet architecture)
	Feature extractor feeding the HoM layer. Whatever replaces it must output feature maps or feature vectors,
	since the HoM layer flattens its input before building the capsules.

	This default version stacks two 9x9 ReLU convolutions with 256 filters each and 'valid' padding;
	the first uses stride 1 and the second stride 2.

	Input : input_layer: tensor holding the input data

	Output: tensor produced by the second convolution

	"""

	first_conv = Conv2D(
		filters=256,
		kernel_size=(9, 9),
		strides=(1, 1),
		padding='valid',
		activation='relu',
		name='convolution_1',
	)
	# Same kernel and width as the first convolution, but stride 2.
	second_conv = Conv2D(
		filters=256,
		kernel_size=(9, 9),
		strides=(2, 2),
		padding='valid',
		activation='relu',
		name="convolution_2",
	)

	return second_conv(first_conv(input_layer))



def HoM_Layer(input_layer, num_classes, dim_capsule):
	"""
	Append a Hit or Miss (HoM) layer to the network.

	The incoming features are flattened, projected by a Dense layer to num_classes*dim_capsule values,
	reshaped into num_classes capsules of dim_capsule features, batch-normalized (default settings)
	and squashed by a sigmoid, so every capsule feature lies in (0, 1).

	Inputs : input_layer: tensor holding the features produced by the feature extractor subnet
			 num_classes: number of classes, i.e. number of capsules
			 dim_capsule: number of features per capsule
	Output: tensor of shape [None, num_classes, dim_capsule]; the final activation layer is named "HoM"

	"""

	# Layers are listed in the order in which they are applied.
	pipeline = (
		Flatten(),
		Dense(num_classes * dim_capsule),
		Reshape((num_classes, dim_capsule)),
		BatchNormalization(),
		Activation('sigmoid', name="HoM"),
	)

	capsules = input_layer
	for stage in pipeline:
		capsules = stage(capsules)

	return capsules



def Classification_SubNet(input_layer):
	"""
	Classification subnet: for each capsule, computes the Euclidean distance between the capsule
	and the center of the target, i.e. the point whose coordinates all equal 0.5.

	Inputs : input_layer: tensor holding the capsules of the HoM layer, shape=[None, num_classes, dim_capsule]

	Outputs : tensor of distances, shape=[None, num_classes], produced by a layer named 'prediction'

	"""

	# Every coordinate of the target center has this value.
	TARGET_CENTER = 0.5

	# Offset of each capsule from the center; its length is the distance.
	offsets = Lambda(lambda capsules: TARGET_CENTER - capsules)(input_layer)

	return Length(name='prediction')(offsets)

def _Decoder_SubNet(input_layer, input_shape):
	"""
	Decoder subnet rebuilding the input from the masked capsules of the HoM layer.

	Two ReLU Dense layers (512 then 1024 units) are followed by a sigmoid Dense layer with one unit
	per element of the input, whose output is reshaped to input_shape.

	Inputs : 	input_layer: tensor holding the masked (flattened) capsules of the HoM layer
				input_shape: tuple giving the shape of one input sample
	Outputs : tensor of shape [None] + input_shape, produced by a layer named 'reconstruction'

	"""

	hidden = input_layer
	for width in (512, 1024):
		hidden = Dense(width, activation='relu')(hidden)

	# One sigmoid unit per element of the original input.
	flat_output = Dense(np.prod(input_shape), activation='sigmoid')(hidden)

	return Reshape(input_shape, name='reconstruction')(flat_output)



class Length(layers.Layer):
	"""
	Class taken from : [REFERENCE]

	Layer returning the Euclidean length of each vector along the last axis. The result has the same
	shape as one-hot labels, so when this layer is the model's output, labels can be read off the
	per-class values of `model.predict(x)`.

	Inputs: shape=[None, num_vectors, dim_vector]

	Output: shape=[None, num_vectors]
	"""

	def call(self, inputs, **kwargs):
		# NOTE(review): no epsilon is added under the square root, so the
		# gradient is not finite for an all-zero vector.
		squared_norm = K.sum(K.square(inputs), axis=-1)
		return K.sqrt(squared_norm)

	def compute_output_shape(self, input_shape):
		# The last (vector) axis is reduced away.
		return input_shape[:-1]


class Mask(layers.Layer):
	"""
	Class adapted from : [REFERENCE]

	Mask a Tensor with shape=[None, num_capsule, dim_vector] either by the capsule with max length or by an additional
	input mask, then flatten the masked Tensor to shape=[None, num_capsule * dim_vector].

	- Called on a single tensor: every capsule except the one with the largest length is set to zero.
	- Called on a pair [capsules, mask]: each capsule is multiplied by its entry of `mask`
	  (shape=[None, num_capsule]); with a one-hot mask only the selected capsule is kept.

	For example:
		```
		x = keras.layers.Input(shape=[3, 2])  # each sample contains 3 capsules with dim_vector=2
		y = keras.layers.Input(shape=[3])  # True labels. 3 classes, one-hot coding.
		out = Mask()(x)  # out.shape=[None, 6]
		# or
		out2 = Mask()([x, y])  # out2.shape=[None, 6]. Masked with true labels y. Of course y can also be manipulated.
		```
	"""

	def call(self, inputs, **kwargs):
		# Accept both list and tuple containers for the (capsules, mask) pair.
		if isinstance(inputs, (list, tuple)):
			if len(inputs) != 2:
				raise ValueError(
					"Mask expects either a single tensor or a pair [capsules, mask]; "
					"got %d inputs" % len(inputs)
				)
			capsules, mask = inputs
		else:
			capsules = inputs
			# No mask supplied: build a one-hot mask selecting the longest capsule.
			lengths = K.sqrt(K.sum(K.square(capsules), -1))
			mask = K.one_hot(indices=K.argmax(lengths, 1), num_classes=lengths.get_shape().as_list()[1])

		# Broadcast the per-capsule mask over the feature axis, then flatten
		# everything but the batch axis.
		return K.batch_flatten(capsules * K.expand_dims(mask, -1))

	def compute_output_shape(self, input_shape):
		# With two inputs, input_shape is a sequence of shapes and the capsule
		# shape comes first; with one input it is the capsule shape itself.
		if isinstance(input_shape[0], (tuple, list)):
			capsule_shape = input_shape[0]
		else:
			capsule_shape = input_shape
		return (None, capsule_shape[1] * capsule_shape[2])
