import os
import sys
import keras
import numpy as np

from hitnet import HitNet
from data_processing import train_generator, preprocessing

from keras import optimizers
from keras import backend as K
from keras.utils import plot_model
from keras.datasets import mnist
from keras.callbacks import ModelCheckpoint, CSVLogger, TensorBoard

from centripetal_loss import Centripetal_Loss, argmin_metric




# Setting the parameters and arguments

K.set_image_data_format('channels_last')

dim_capsule = 16

loss = Centripetal_Loss(dim_capsule=dim_capsule, with_ghost=True)

# Hyper-parameters consumed by the training section of this script.
# "lambda_reconstruction" is the loss weight applied to the second ('mse')
# model output when the network is compiled.
training_parameters = {
    "batch_size": 128,
    "epochs": 250,
    "learning_rate": 0.001,
    "lambda_reconstruction": 0.392,
    "loss": loss.loss,
}

# CSVLogger opens its output file when training begins and does not create
# parent directories, so make sure the log directory exists up front.
os.makedirs('./Logs', exist_ok=True)

log = CSVLogger('./Logs/log.csv')
log2 = CSVLogger('./Logs/log.txt')  # same per-epoch log, also written as .txt
tensorboard = TensorBoard(
    log_dir="./TensorBoard/",
    batch_size=training_parameters["batch_size"],
    histogram_freq=0,
    write_graph=True,
    write_grads=False,
    write_images=False,
    embeddings_freq=0,
    embeddings_layer_names=None,
    embeddings_metadata=None,
)


# Dataset: MNIST, loaded through keras.datasets, with 10 target classes.
# input_shape is the per-image shape handed to preprocessing() and HitNet.

input_shape = (28, 28, 1)
num_classes = 10

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale the images and categorize the labels for both splits
# (the actual transformation lives in data_processing.preprocessing).

x_train, y_train = preprocessing(x_train, y_train, input_shape, num_classes)
x_test, y_test = preprocessing(x_test, y_test, input_shape, num_classes)

# Creating the HitNet model and printing its structure

Network = HitNet(input_shape, num_classes, dim_capsule=dim_capsule)

print("Summary of the network structure")
# Model.summary() prints the table itself and returns None; wrapping it in
# print() would emit a stray "None" line after the table.
Network.summary()

# plot_model writes straight to `to_file` without creating missing
# directories, so ensure the target directory exists first.
os.makedirs('./Images', exist_ok=True)
plot_model(Network, to_file='./Images/Hitnet_Structure.png', show_shapes=True)



# Training the network

# The model is compiled with two losses: the configured main loss (its
# 'prediction' output is also tracked with argmin_metric) and an 'mse' loss on
# the second output, weighted by lambda_reconstruction.
Network.compile(
    optimizer=optimizers.Adam(lr=training_parameters["learning_rate"]),
    loss=[training_parameters["loss"], 'mse'],
    loss_weights=[1., training_parameters["lambda_reconstruction"]],
    metrics={'prediction': argmin_metric},
)

Network.fit_generator(
    generator=train_generator(
        x_train,
        y_train,
        batch_size=training_parameters["batch_size"],
        shift_fraction=0.1,
    ),
    # Floor division: an epoch is this many full batches of the training set.
    steps_per_epoch=y_train.shape[0] // training_parameters["batch_size"],
    epochs=training_parameters["epochs"],
    verbose=1,
    # Keras documents validation_data as an (inputs, targets) tuple; a plain
    # list is not unpacked that way by every Keras / tf.keras version.
    # Inputs are [images, labels]; targets are [labels, images], mirroring the
    # two compiled losses above.
    validation_data=([x_test, y_test], [y_test, x_test]),
    callbacks=[log, log2, tensorboard],
)

Network.save_weights('trained_model.h5')