import numpy as np

import keras
from keras.models import Sequential
from keras.models import Model
from keras import backend as K
from keras.layers import Conv2D, MaxPooling2D, Lambda, Concatenate, Maximum, TimeDistributed, Reshape, Add, Subtract
from keras.layers import Dense, Dropout, Flatten, Input, Activation, BatchNormalization, AveragePooling2D

from keras.backend import spatial_2d_padding
from keras.layers import MaxPooling2D, UpSampling2D, Conv2DTranspose

import tensorflow as tf

class Network:
	"""Pyramid-pooling segmentation network built with the Keras functional API.

	The network reduces the input by a factor of 8 (one stride-2 convolution
	followed by two 2x2 max-poolings), applies a residual block and a four
	branch average-pooling pyramid, then upsamples back to the input
	resolution and emits a per-pixel softmax over `num_classes`.

	Attributes set by this class:
		train_model -- Keras `Model` created by `build_training_network`.
		test_model, model -- initialised to None; not assigned in this class.
	"""

	# Total spatial reduction of the encoder: a stride-2 convolution followed
	# by two 2x2 max-poolings. The decoder undoes it with a x4 and a x2
	# upsampling.
	_DOWNSAMPLING_FACTOR = 8

	# (rows, cols) of the pooled grid produced by each pyramid branch. On the
	# 1080x1920 reference input (135x240 feature map) these grids give the
	# pool sizes (135,240), (45,120), (27,80) and (15,30).
	_PYRAMID_GRIDS = ((1, 1), (3, 2), (5, 3), (9, 8))

	def __init__(self):
		"""Create an empty network wrapper and force channels-last tensors."""
		self.train_model = None
		self.test_model = None
		self.model = None
		K.set_image_data_format('channels_last')

	@classmethod
	def _pyramid_pool_sizes(cls, input_shape):
		"""Return the (pool_h, pool_w) of every pyramid branch for `input_shape`.

		`input_shape` is read as (height, width, ...). Raises ValueError when
		the height/width cannot be reduced by `_DOWNSAMPLING_FACTOR` and split
		evenly into every grid of `_PYRAMID_GRIDS`; without this check the
		branches would not line up again at the concatenation.
		"""
		height, width = input_shape[0], input_shape[1]
		factor = cls._DOWNSAMPLING_FACTOR
		if height <= 0 or width <= 0 or height % factor or width % factor:
			raise ValueError(
				"input_shape height and width must be positive multiples of %d, got %r"
				% (factor, (height, width)))

		feat_h, feat_w = height // factor, width // factor
		pool_sizes = []
		for rows, cols in cls._PYRAMID_GRIDS:
			if feat_h % rows or feat_w % cols:
				raise ValueError(
					"feature map %dx%d (input %dx%d) cannot be split into a %dx%d pyramid grid"
					% (feat_h, feat_w, height, width, rows, cols))
			pool_sizes.append((feat_h // rows, feat_w // cols))
		return pool_sizes

	@staticmethod
	def _conv_bn_relu(tensor, kernel_size, strides, name):
		"""Apply a 64-filter 'same' Conv2D named `name`, then BatchNorm and ReLU."""
		tensor = Conv2D(64, kernel_size=kernel_size, strides=strides, padding='same', name=name)(tensor)
		tensor = BatchNormalization()(tensor)
		return Activation('relu')(tensor)

	@staticmethod
	def _pyramid_branch(features, pool_size, conv_name=None):
		"""Average-pool `features`, optionally convolve, ReLU and upsample back.

		Pooling and upsampling use the same `pool_size`, so the branch output
		has the spatial size of `features`. The 2x2 convolution is only added
		when `conv_name` is given.
		"""
		branch = AveragePooling2D(pool_size=pool_size, strides=pool_size, padding='valid', data_format="channels_last")(features)
		if conv_name is not None:
			branch = Conv2D(64, kernel_size=(2, 2), strides=(1, 1), padding='same', name=conv_name)(branch)
		branch = Activation('relu')(branch)
		return UpSampling2D(size=pool_size, data_format="channels_last")(branch)

	@staticmethod
	def _flatten_map(data):
		"""Reshape data["map"] from (N, H, W) to (N, H*W)."""
		seg_map = data["map"]
		return np.reshape(seg_map, (seg_map.shape[0], seg_map.shape[1] * seg_map.shape[2]))

	def build_training_network(self, data_parameters):
		"""Build the segmentation model and store it in `self.train_model`.

		Parameters:
			data_parameters -- mapping providing
				"input_shape": (height, width, channels) of one input image,
				"num_classes": number of output channels of the softmax.

		Raises:
			ValueError -- if the input height/width is not compatible with the
				encoder reduction and the pyramid grids (see
				`_pyramid_pool_sizes`).

		The SHAPE comments below are given for the 1080x1920 reference input;
		other compatible inputs scale accordingly.
		"""
		input_shape = data_parameters["input_shape"]
		num_classes = data_parameters["num_classes"]

		# Fail early with a clear message instead of a shape mismatch deep
		# inside the graph construction.
		pool_sizes = self._pyramid_pool_sizes(input_shape)

		main_input = Input(shape=input_shape, dtype='float32', name='main_input')

		# --- Encoder: overall /8 spatial reduction ---
		subNet1 = self._conv_bn_relu(main_input, (3, 3), (2, 2), 'conv1_1')	# SHAPE : (?,540,960,64)
		subNet1 = self._conv_bn_relu(subNet1, (3, 3), (1, 1), 'conv1_2')	# SHAPE : (?,540,960,64)
		subNet1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same', data_format="channels_last")(subNet1)	# SHAPE : (?,270,480,64)

		subNet1 = self._conv_bn_relu(subNet1, (3, 3), (1, 1), 'conv1_4')	# SHAPE : (?,270,480,64)
		subNet1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same', data_format="channels_last")(subNet1)	# SHAPE : (?,135,240,64)

		# --- Residual block: three convolutions added back onto their input ---
		subNet2_1 = self._conv_bn_relu(subNet1, (3, 3), (1, 1), 'conv2_1_1')	# SHAPE : (?,135,240,64)
		subNet2_1 = self._conv_bn_relu(subNet2_1, (9, 9), (1, 1), 'conv2_1_2')	# SHAPE : (?,135,240,64)
		subNet2_1 = self._conv_bn_relu(subNet2_1, (3, 3), (1, 1), 'conv2_1_3')	# SHAPE : (?,135,240,64)

		subNet2 = Add()([subNet2_1, subNet1])	# SHAPE : (?,135,240,64)

		subNet3 = self._conv_bn_relu(subNet2, (3, 3), (1, 1), 'conv3_1')	# SHAPE : (?,135,240,64)

		# --- Pyramid pooling: pooled grids of 1x1, 3x2, 5x3 and 9x8 cells ---
		# The first (global) branch has no convolution.
		pyrBranch1 = self._pyramid_branch(subNet3, pool_sizes[0])	# SHAPE : (?,135,240,64)
		pyrBranch2 = self._pyramid_branch(subNet3, pool_sizes[1], "conv4_2_1")	# SHAPE : (?,135,240,64)
		pyrBranch3 = self._pyramid_branch(subNet3, pool_sizes[2], "conv4_3_1")	# SHAPE : (?,135,240,64)
		pyrBranch4 = self._pyramid_branch(subNet3, pool_sizes[3], "conv4_4_1")	# SHAPE : (?,135,240,64)

		subNet4 = Concatenate()([subNet3, pyrBranch1, pyrBranch2, pyrBranch3, pyrBranch4])	# SHAPE : (?,135,240,320)
		subNet4 = BatchNormalization()(subNet4)

		# --- Decoder: x4 then x2 upsampling back to the input resolution ---
		subNet5 = Conv2D(64, kernel_size=(2, 2), strides=(1, 1), padding='same', name='conv5_1')(subNet4)
		subNet5 = Activation('relu')(subNet5)	# SHAPE : (?,135,240,64)

		subNet5 = Dropout(0.25)(subNet5)

		subNet5 = UpSampling2D(size=(4, 4), data_format="channels_last")(subNet5)	# SHAPE : (?,540,960,64)

		subNet5 = Conv2D(num_classes, kernel_size=(2, 2), strides=(1, 1), padding='same', name='conv5_2')(subNet5)
		subNet5 = Activation('relu')(subNet5)	# SHAPE : (?,540,960,num_classes)

		subNet5 = UpSampling2D(size=(2, 2), data_format="channels_last")(subNet5)	# SHAPE : (?,1080,1920,num_classes)

		subNet5 = Conv2D(num_classes, kernel_size=(2, 2), strides=(1, 1), padding='same', name='conv5_3')(subNet5)	# SHAPE : (?,1080,1920,num_classes)

		# Flatten the spatial dimensions so the softmax is applied per pixel
		# over the class axis.
		subNet5 = Reshape((input_shape[0] * input_shape[1], num_classes))(subNet5)	# SHAPE : (?,1080*1920,num_classes)

		SegmentationMap = Activation('softmax')(subNet5)

		self.train_model = Model(inputs=main_input, outputs=SegmentationMap)

	def format_data(self, data, data_parameters = None):
		"""Return `[inputs, targets]` for training.

		`data["map"]` is flattened to (N, H*W) and stacked with its complement
		`1 - map` into a (N, H*W, 2) target: channel 0 is the complement,
		channel 1 is the map itself. `data["input"]` is passed through as is.
		`data_parameters` is accepted but not used.
		"""
		# NOTE(review): the complement is only a valid "background" mask if
		# the map holds values in {0, 1} -- confirm against the data loader.
		mask_semantic = self._flatten_map(data)
		mask_background = 1 - mask_semantic
		mask = np.stack((mask_background, mask_semantic), axis=2)
		return [data["input"], mask]

	def format_data_for_confusion_matrix(self, data, data_parameters = None):
		"""Return `data["map"]` flattened to (N, H*W); `data_parameters` is unused."""
		return self._flatten_map(data)
