import numpy as np
import keras
from keras import backend as K


class Centripetal_Loss:
	"""
	Centripetal loss for capsule networks.

	True classes ("hits") are penalised when their capsule output rises above
	the hit threshold, and the remaining classes ("misses") are penalised when
	their output falls below the miss threshold. Each penalty is the integral
	of a staircase function: its slope grows by ``step_height`` every
	``step_width`` units past the threshold, so predictions that stray further
	are pushed back harder. Miss penalties are weighted by 0.5.

	Optionally, one "ghost" capsule (the miss with the smallest output) is
	exempted from the miss penalty.

	Usage: pass the bound method ``Centripetal_Loss(...).loss`` to Keras.
	"""

	def __init__(self, dim_capsule=16, with_ghost=False, hit_threshold=0.1,
			miss_threshold=0.9, step_width=0.1, step_height=0.2):
		"""
		Parameters :	dim_capsule : dimension of each capsule (number of features per capsule).
								Used to build the "reversed" miss coordinate
								sqrt(dim_capsule)/2 - y_pred. NOTE: that offset cancels
								algebraically in the miss term, so it only affects float
								rounding, not the mathematical value of the loss.
								with_ghost : set to True if the loss must incorporate one ghost
								capsule during training
								hit_threshold : hit zone threshold (``m`` in the paper), default 0.1
								miss_threshold : miss zone threshold (``m'`` in the paper), default 0.9
								step_width : horizontal width of one step of the staircase (``l``), default 0.1
								step_height : vertical height of one step of the staircase (``h``), default 0.2
		"""
		self.dim_capsule = dim_capsule
		self.with_ghost = with_ghost
		self.hit_threshold = hit_threshold
		self.miss_threshold = miss_threshold
		self.step_width = step_width
		self.step_height = step_height

	def loss(self, y_true, y_pred):
		"""
		Keras-compatible loss: dispatch to the ghost or the plain centripetal loss
		depending on ``self.with_ghost``.

		Parameters :	y_true : targets; presumably one-hot / multi-hot of shape
								(batch, num_classes) -- TODO confirm against callers
								y_pred : per-class capsule outputs, same shape as y_true
		Returns : scalar tensor, the per-sample loss summed over axis 1 and then averaged.
		"""
		if self.with_ghost:
			return self._ghost_centripetal_loss(y_true, y_pred)
		return self._centripetal_loss(y_true, y_pred)

	def _staircase_integral(self, d):
		"""
		Integral of the staircase function at distance ``d`` past a threshold.

		With f = floor(d / step_width), the value is
			f*(f+1)*step_height*step_width/2 + (f+1)*step_height*(d - f*step_width)
		where d > 0, and 0 elsewhere. The result is continuous and piecewise
		linear, and its slope increases by ``step_height`` every ``step_width``.
		"""
		width = self.step_width
		height = self.step_height
		# Floor via round(x - 0.5), since K.floor is not available on every backend.
		# Depending on the backend's tie rule, this may give floor-1 at exact integers.
		# The penalty is continuous there, so its value is unaffected.
		f = K.round(d / width - 0.5)
		# Heaviside step: 1 where d > 0, 0 otherwise.
		H = K.sign(K.maximum(0., d))
		return H * (f * (f + 1) * height * width * 0.5 + (f + 1) * height * (d - f * width))

	def _hit_and_miss_terms(self, y_pred):
		"""Return (L1, L2): the per-capsule hit and miss penalties for ``y_pred``."""
		# Hits: penalise outputs above the hit threshold.
		L1 = self._staircase_integral(y_pred - self.hit_threshold)

		# Misses: instead of an adapted (decreasing) loss with threshold m_prime, work in
		# the reversed coordinate sqrt(n)/2 - y_pred with threshold sqrt(n)/2 - m_prime.
		# The offset cancels, so the distance past the threshold equals m_prime - y_pred.
		reversed_factor = np.sqrt(self.dim_capsule) / 2
		m_prime_reversed = reversed_factor - self.miss_threshold
		y_pred_reversed = reversed_factor - y_pred
		L2 = self._staircase_integral(y_pred_reversed - m_prime_reversed)
		return L1, L2

	def _centripetal_loss(self, y_true, y_pred):
		"""Centripetal loss without a ghost capsule: every miss is penalised."""
		L1, L2 = self._hit_and_miss_terms(y_pred)
		L = y_true * L1 + 0.5 * (1 - y_true) * L2
		return K.mean(K.sum(L, 1))

	def _ghost_centripetal_loss(self, y_true, y_pred):
		"""
		Centripetal loss with one ghost capsule. The miss capsule with the smallest
		output (the one closest to being predicted as a hit) is excluded from the
		miss penalty.
		"""
		L1, L2 = self._hit_and_miss_terms(y_pred)

		num_classes = K.int_shape(y_pred)[1]
		# Give true classes a value strictly above every output so argmin can never
		# select them, even when outputs tie. A plain max would only tie, and argmin's
		# choice among ties is unspecified. Broadcasting replaces the former K.tile.
		z = (1 - y_true) * y_pred + y_true * (K.max(y_pred, 1, keepdims=True) + 1)
		ghost = K.one_hot(indices=K.argmin(z, 1), num_classes=num_classes)
		# Clip so no position exceeds 1. Otherwise (1 - y_true_tilde) would go negative
		# and make the miss term negative. Clipping only matters when every class is
		# true, leaving no miss available as a ghost.
		y_true_tilde = K.minimum(ghost + y_true, 1.)

		L = y_true * L1 + 0.5 * (1 - y_true_tilde) * L2
		return K.mean(K.sum(L, 1))


def argmin_metric(y_true, y_pred):
	"""
	Per-sample accuracy where the predicted class is the capsule with the
	*smallest* output. This matches the centripetal convention that true classes
	are pushed toward low values.

	Returns 1.0 where argmax(y_true) equals argmin(y_pred) along the last axis,
	and 0.0 otherwise, cast to the backend float type.
	"""
	target_index = K.argmax(y_true, axis=-1)
	predicted_index = K.argmin(y_pred, axis=-1)
	matches = K.equal(target_index, predicted_index)
	return K.cast(matches, K.floatx())

