import os
import cv2
import math
import numpy as np
from scipy import ndimage
import PyCVPR18AC.Utils.constants as const
import PyCVPR18AC.Utils.statistics as stats
import PyCVPR18AC.Semantics.tinynet as tinynet
import PyCVPR18AC.Utils.filterOperations as filters


class PlayerExtraction:
	"""Player segmentation with a pretrained ``tinynet`` network.

	On construction, the per-channel normalization statistics and the network
	weights are loaded from the ``params/`` directory that sits next to this
	module.
	"""

	def __init__(self):

		module_dir = os.path.dirname(os.path.abspath(__file__))
		self.PATHPARAMS = module_dir + "/params/"

		# Per-channel statistics applied to every frame before inference.
		self.mean = np.load(self.PATHPARAMS + "normalization_mean.npy")
		self.std = np.load(self.PATHPARAMS + "normalization_std.npy")

		self.path_weights = self.PATHPARAMS + "model_weights.h5"

		# The network is built for a fixed input resolution (height, width, channels).
		self.data_parameters = {
			"input_shape": (1080, 1920, 3),
			"num_classes": 2,
		}

		self.Network = tinynet.Network()
		self.Network.build_training_network(self.data_parameters)
		self.Network.train_model.load_weights(self.path_weights)

	def compute_player_mask(self, frame):
		"""Segment ``frame`` and return a 0/255 ``uint8`` mask.

		Each channel ``c`` is normalized as ``(value - mean[c]) / std[c]`` in
		float32. A batch of one is run through the network, and the arg-max class
		of every prediction is scaled by 255. The result is reshaped to the
		(height, width) of ``data_parameters["input_shape"]``. The caller's
		``frame`` is never modified.
		"""
		# astype() allocates a fresh array, so the in-place normalization below
		# cannot touch the caller's frame.
		batch = np.expand_dims(np.asarray(frame), axis=0).astype('float32')

		for channel in range(batch.shape[-1]):
			batch[:, :, :, channel] -= self.mean[channel]
			batch[:, :, :, channel] /= self.std[channel]

		scores = self.Network.train_model.predict(batch)
		labels = scores.argmax(axis=-1).astype('uint8')

		height, width = self.data_parameters["input_shape"][:2]
		mask = labels[0, :] * 255
		return mask.reshape(height, width)

def compute_mean_position(segmentation):
	"""Return the barycenter ``[x, y]`` of the non-zero pixels of ``segmentation``.

	Every non-zero pixel has equal weight, because the mask is cast to bool.
	When the mask has no foreground pixel, the rounded image center
	``[round(width / 2), round(height / 2)]`` is returned as ints.
	"""
	mask = segmentation.astype('bool')
	barycenter = [int(round(mask.shape[1] / 2)), int(round(mask.shape[0] / 2))]

	# Only compute the center of mass when there is something to weigh: on an
	# empty mask it divides 0 by 0 (RuntimeWarning, NaN result).
	if mask.any():
		measure = ndimage.center_of_mass(mask)
		# center_of_mass returns (row, column); swap to (x, y).
		barycenter = [measure[1], measure[0]]
	return barycenter

def compute_groupment(segmentation):
	"""Return the mean Euclidean distance (pixels) of foreground pixels to their barycenter.

	``segmentation`` is handed to ``cv2.findNonZero``. According to the OpenCV
	docs, that function expects a single-channel 8-bit image. When no pixel is
	non-zero, ``const.NO_VALUE`` is returned.
	"""
	nonzero = cv2.findNonZero(segmentation)
	if nonzero is None:
		return const.NO_VALUE

	# findNonZero yields an (N, 1, 2) array of (x, y) coordinates; drop the middle axis.
	xy = nonzero.reshape(-1, 2)

	center_x, center_y = compute_mean_position(segmentation.astype('bool'))

	dx = xy[:, 0] - center_x
	dy = xy[:, 1] - center_y
	radii = np.sqrt(dx ** 2 + dy ** 2)

	return np.sum(radii) / xy.shape[0]


def _field_edge_distances(field_row, field_width, x_barycenter, x_excluded=None):
	"""Return ``(distance_left, distance_right)`` from ``x_barycenter`` to the field edges of one row.

	The left/right edges are the first/last positions where ``field_row > 0``.
	An edge is not counted as a visible contour, and its distance stays 0, when:
	it touches the image border (left column <= 1, or fewer than 3 columns from
	the right border); or ``x_excluded`` is given and the edge lies within 5 px
	of it. A row with no field pixel yields ``(0, 0)``.
	"""
	columns = np.nonzero(field_row > 0)[0]
	if columns.size == 0:
		return 0, 0

	left_item, right_item = columns[0], columns[-1]
	distance_left, distance_right = 0, 0

	if left_item > 1 and (x_excluded is None or abs(left_item - x_excluded) > 5):
		distance_left = abs(x_barycenter - left_item)

	if field_width - right_item > 2 and (x_excluded is None or abs(right_item - x_excluded) > 5):
		distance_right = abs(x_barycenter - right_item)

	return distance_left, distance_right


def compute_distance(mean_position, field_map, main_circle):
	"""Measure horizontal distances from the players' barycenter to field landmarks.

	Args:
		mean_position: ``(x, y)`` barycenter; x indexes columns and y indexes rows of ``field_map``.
		field_map: mask indexed ``[row, column]``, positive on field pixels.
		main_circle: object exposing ``exists`` and ``center_x``.

	Returns:
		``[distance_ellipse, distance_field]``:
		- ``distance_ellipse`` is ``|x - main_circle.center_x|``, or
		  ``const.NO_DISTANCE`` when the circle does not exist.
		- ``distance_field`` is the larger horizontal distance to a visible
		  left/right field contour, or ``const.NO_DISTANCE`` when no contour
		  can be measured.
	"""
	distance_ellipse = const.NO_DISTANCE
	distance_field = const.NO_DISTANCE

	# mean_position[0] -> x ; mean_position[1] -> y
	x_barycenter = int(math.floor(mean_position[0]))
	y_barycenter = int(math.floor(mean_position[1]))
	field_width = field_map.shape[1]

	# Horizontal distance to the center of the main circle, when one was detected.
	if main_circle.exists:
		distance_ellipse = abs(x_barycenter - main_circle.center_x)

	# Independently of the circle, look for the left/right field contour on the
	# barycenter's row.
	distance_left, distance_right = _field_edge_distances(
		field_map[y_barycenter, :], field_width, x_barycenter)

	if distance_left == 0 and distance_right == 0:
		# No contour is visible on that row. Retry on the row of the top-most
		# field pixel in the barycenter's column. Edges within 5 px of the
		# barycenter's x are ignored there (presumably the top corner of the
		# field rather than a side contour).
		field_rows = np.nonzero(field_map[:, x_barycenter] > 0)[0]
		if field_rows.size > 0:
			up_y = field_rows[0]
			distance_left, distance_right = _field_edge_distances(
				field_map[up_y, :], field_width, x_barycenter, x_excluded=mean_position[0])

	# Either side alone is enough to measure a distance; keep the larger one.
	farthest = max(distance_left, distance_right)
	if farthest > 0:
		distance_field = farthest

	return [distance_ellipse, distance_field]

def compute_direction(distance, distance_memory):
	"""Classify the motion by comparing ``distance`` with ``distance_memory``.

	Both arguments are ``[distance_ellipse, distance_field]`` pairs. The
	main-circle distance (index 0) is checked first: growing means
	``const.TOWARD_GOAL`` and shrinking means ``const.TOWARD_CENTER``. If it is
	unavailable (``const.NO_DISTANCE``) or unchanged, the field-contour distance
	(index 1) is checked with the opposite mapping. ``const.STATIC`` is returned
	when neither comparison decides.
	"""
	# (index, direction when the distance grew, direction when it shrank)
	rules = (
		(0, const.TOWARD_GOAL, const.TOWARD_CENTER),
		(1, const.TOWARD_CENTER, const.TOWARD_GOAL),
	)

	for index, when_growing, when_shrinking in rules:
		current, previous = distance[index], distance_memory[index]
		if current == const.NO_DISTANCE or previous == const.NO_DISTANCE:
			continue
		delta = current - previous
		if delta > 0:
			return when_growing
		if delta < 0:
			return when_shrinking

	return const.STATIC