import numpy as np
import keras
from keras.preprocessing.image import ImageDataGenerator


def preprocessing(x, y, input_shape, num_classes):
	"""
	Preprocess the data for easier training and compliance with the
	network's needs.

	Inputs :	x: the images to reshape and rescale; values are divided
				   by 255 (presumably 8-bit pixel intensities — TODO confirm)
				y: the integer labels to one-hot encode
				input_shape: tuple giving the shape of a single sample,
				   e.g. (height, width, channels); any length is accepted
				num_classes: number of classes to classify the data into

	Outputs :	x: float32 array of shape (num_samples, *input_shape),
				   divided by 255
				y: the labels one-hot encoded by keras.utils.to_categorical
	"""
	# Unpack input_shape rather than indexing [0], [1], [2]: the result is
	# identical for 3-D sample shapes, while 2-D (e.g. grayscale) shapes no
	# longer raise IndexError and longer shapes are no longer truncated.
	x = np.asarray(x).reshape(-1, *input_shape).astype('float32') / 255.

	# to_categorical converts the labels to integers itself, so the former
	# float32 cast was a needless round-trip (and rejected plain lists).
	y = keras.utils.to_categorical(y, num_classes=num_classes)

	return x, y

def train_generator(x, y, batch_size, shift_fraction=0.):
	"""
	Adapted from : [REFERENCE]

	Endlessly yield batches of augmented training data. Batches are drawn
	from ImageDataGenerator.flow over (x, y), with width_shift_range and
	height_shift_range both set to shift_fraction, so the images are
	randomly shifted horizontally and vertically.

	This is a common data augmentation technique.

	Inputs :	x: the training images
				y: the training labels (presumably one-hot, as produced by
				   preprocessing — TODO confirm against the caller)
				batch_size: number of samples per yielded batch
				shift_fraction: value passed to both width_shift_range and
				   height_shift_range

	Yields :	([x_batch, y_batch], [y_batch, x_batch]): the labels are fed
				as a second model input and the images as a second target.
				Presumably the model has a label-conditioned reconstruction
				output (e.g. a capsule-network decoder). TODO confirm
				against the model definition.
	"""
	train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
									   height_shift_range=shift_fraction)
	generator = train_datagen.flow(x, y, batch_size=batch_size)
	while True:
		# Built-in next() uses the Python 3 iterator protocol (__next__)
		# instead of relying on the legacy Python 2-style .next() method.
		x_batch, y_batch = next(generator)
		yield ([x_batch, y_batch], [y_batch, x_batch])