import os
import cv2
import numpy as np
import PyCVPR18AC.Utils.statistics as stats
import PyCVPR18AC.Semantics.tinynet as tinynet
import PyCVPR18AC.Utils.filterOperations as filters


class LineExtraction:
	"""Segment line pixels in a frame with a pretrained tinynet model.

	Per-channel normalization statistics and the network weights are read
	from the ``params`` directory located next to this module.
	"""

	def __init__(self):
		"""Load normalization statistics, build the network and load weights.

		Reads ``normalization_mean.npy``, ``normalization_std.npy`` and
		``model_weights.h5`` from ``<module dir>/params/``.
		"""
		# Value kept byte-identical (including the trailing "/") because other
		# code may concatenate file names onto it.
		self.PATHPARAMS = os.path.dirname(os.path.abspath(__file__)) + "/params/"
		self.mean = np.load(os.path.join(self.PATHPARAMS, "normalization_mean.npy"))
		self.std = np.load(os.path.join(self.PATHPARAMS, "normalization_std.npy"))

		self.path_weights = os.path.join(self.PATHPARAMS, "model_weights.h5")

		# The network is built for this fixed input size; compute_line_mask
		# rejects frames of any other shape.
		self.data_parameters = {
			"input_shape": (1080, 1920, 3),
			"num_classes": 2,
		}

		self.Network = tinynet.Network()
		self.Network.build_training_network(self.data_parameters)
		self.Network.train_model.load_weights(self.path_weights)

	def compute_line_mask(self, frame):
		"""Return a 0/255 line mask for ``frame``.

		The frame is normalized per channel with ``self.mean`` / ``self.std``.
		It is then passed through the network as a batch of one. The arg-max
		class of every pixel is scaled by 255, so class 0 maps to 0 and
		class 1 maps to 255.

		Args:
			frame: Image array whose shape must equal
				``self.data_parameters["input_shape"]``
				(height, width, channels). It is not modified.
				NOTE(review): presumably uses the same channel order as the
				normalization statistics (OpenCV frames are BGR) -- confirm.

		Returns:
			``uint8`` array of shape (height, width). With the 2-class model
			built in ``__init__``, its values are 0 or 255.

		Raises:
			ValueError: if ``frame`` does not have the network input shape.
		"""
		frame = np.asarray(frame)
		height, width, channels = self.data_parameters["input_shape"]
		expected_shape = (height, width, channels)
		if frame.shape != expected_shape:
			# Fail early with a clear message instead of an IndexError or
			# shape error raised deep inside normalization or prediction.
			raise ValueError(
				"frame shape %s does not match network input shape %s"
				% (frame.shape, expected_shape)
			)

		# astype() already returns a new array, so no extra copy is needed and
		# the caller's frame is never mutated.
		batch_image = frame.astype(np.float32)[np.newaxis]

		# Vectorized per-channel normalization, broadcast over the last axis.
		# Assumes self.mean / self.std hold one scalar per channel, as the
		# per-channel indexing implies. The statistics are cast to float32 so
		# the arithmetic stays in float32.
		mean = np.asarray(self.mean[:channels], dtype=np.float32)
		std = np.asarray(self.std[:channels], dtype=np.float32)
		batch_image -= mean
		batch_image /= std

		predictions = self.Network.train_model.predict(batch_image)

		# Arg-max over the class axis gives 0/1 for the 2-class model; * 255
		# turns that into a displayable mask. With more than two classes the
		# uint8 product would wrap around.
		segmentation = predictions.argmax(-1).astype(np.uint8)
		segmentation = segmentation[0] * 255
		return np.reshape(segmentation, (height, width))