import time
import os
import cv2
import dlib
import numpy as np
from scipy.spatial import distance
from skimage.metrics import structural_similarity as compare_ssim
from sklearn.metrics import mean_absolute_error
from skimage.metrics import peak_signal_noise_ratio
from matplotlib import pyplot as plt
from matplotlib import gridspec

# Initialize dlib's face detector and then create the facial landmark predictor
#detector = dlib.get_frontal_face_detector()
# Enter the path of dlib's shape predictor model
#extractor = dlib.shape_predictor('./shape_predictor_68_face_landmarks_GTX.dat')

def tic():
    """Start (or restart) the module-wide stopwatch.

    Stores the current wall-clock time in the module global ``_start_time``,
    which ``tac`` reads to report the elapsed time.
    """
    globals()["_start_time"] = time.time()


def tac():
    """Print the wall-clock time elapsed since the last call to ``tic``.

    The elapsed time is rounded to 1/100 s and printed as total minutes plus
    the remaining seconds.  Minutes are deliberately *not* wrapped at 60, so
    a run longer than an hour is still reported correctly in the existing
    ``min:sec`` format.

    Raises:
        NameError: if ``tic`` has not been called yet.
    """
    try:
        started = _start_time
    except NameError:
        raise NameError("tac() called before tic(): no start time recorded") from None
    elapsed = round(time.time() - started, 2)
    minutes, seconds = divmod(elapsed, 60)
    # divmod on floats can leave representation noise (e.g. 1.1300000000000026)
    print('Time passed: {}min:{}sec'.format(minutes, round(seconds, 2)))


def detect_face(image):
    """Detect the first face in an image.

    Args:
        image: image array. A 3-dimensional array is converted with
            ``cv2.COLOR_BGR2GRAY`` before detection; a 2-dimensional
            (already single-channel) array is passed to the detector as is.

    Returns:
        The first rectangle reported by dlib's frontal face detector (an
        object exposing ``left()``, ``top()``, ``right()``, ``bottom()``), or
        ``None`` after printing "No face" when nothing is detected.
    """
    # HOG + SVM based face detector
    detector = dlib.get_frontal_face_detector()
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
    # Second argument is the number of image pyramid layers to apply when
    # upscaling the image prior to applying the detector (0 = no upscaling).
    faces = detector(gray, 0)
    if len(faces) == 0:
        # Previously this path indexed the empty result and raised IndexError.
        print("No face")
        return None
    return faces[0]


def crop_face_only(image):
    """Crop an image around the inner face.

    The first rectangle returned by dlib's frontal face detector is shrunk by
    10% of its width/height on every side and that region is returned.

    Args:
        image: image array indexed as ``image[row, column]``.

    Returns:
        The cropped view of ``image``.  When no face is detected a message is
        printed and an empty crop (``image[0:0, 0:0]``) is returned.
    """
    detector = dlib.get_frontal_face_detector()
    rects = detector(image)
    if len(rects) < 1:
        print("... No face detected")
        # All bounds are zero, i.e. an empty crop (the original tuple-unpacking
        # of a single 0 raised TypeError here).
        return image[0:0, 0:0]

    face = rects[0]
    left, top = int(face.left()), int(face.top())
    right, bottom = int(face.right()), int(face.bottom())
    margin_w = int(0.1 * (right - left))
    margin_h = int(0.1 * (bottom - top))
    # The detector may report coordinates outside the image; a negative slice
    # start would silently wrap around, so clamp to the image origin.
    w0 = max(left + margin_w, 0)
    w1 = max(right - margin_w, 0)
    h0 = max(top + margin_h, 0)
    h1 = max(bottom - margin_h, 0)

    return image[h0:h1, w0:w1]


def crop_face_only_bigger(image):
    """Crop an image around the face with extra context on every side.

    The first detected face rectangle is shrunk by 10% per side (the same
    region ``crop_face_only`` uses) and then grown again by a quarter of that
    region's width/height on every side, clamped to the image borders.

    Args:
        image: image array indexed as ``image[row, column]``.

    Returns:
        The cropped view of ``image``.  When no face is detected a message is
        printed and an empty crop (``image[0:0, 0:0]``) is returned.
    """
    detector = dlib.get_frontal_face_detector()
    rects = detector(image)
    if len(rects) < 1:
        print("... No face detected")
        # Previously this path raised (bad tuple unpacking, then undefined
        # crop bounds); return an empty crop instead.
        return image[0:0, 0:0]

    face = rects[0]
    left, top = int(face.left()), int(face.top())
    right, bottom = int(face.right()), int(face.bottom())
    margin_w = int(0.1 * (right - left))
    margin_h = int(0.1 * (bottom - top))
    w0, w1 = left + margin_w, right - margin_w
    h0, h1 = top + margin_h, bottom - margin_h

    # Extend the inner-face box by a quarter of its size in each direction.
    w_plus = (w1 - w0) // 4
    h_plus = (h1 - h0) // 4

    # Clamp to the image so slices never wrap around or overshoot.
    hh0 = max(h0 - h_plus, 0)
    hh1 = max(min(h1 + h_plus, image.shape[0]), 0)
    ww0 = max(w0 - w_plus, 0)
    ww1 = max(min(w1 + w_plus, image.shape[1]), 0)
    return image[hh0:hh1, ww0:ww1]


def crop_icao(image, wratio=0.55, hratio=0.6):
    """Crop an image to make it ICAO compliant.

    The crop is driven by facial landmarks obtained from
    ``extract_landmarks(image, detect_face(image))``:

    1. Left, right and bottom borders are moved inwards until the face takes
       up at least ``wratio`` (horizontally, each side of landmark 8) and
       ``hratio`` (vertically, below landmark 28) of the available space.
    2. Landmarks are re-extracted on the cropped image and the top is cut so
       that the mean height of landmarks 39 and 42 is at most 45% of the
       image height.

    Args:
        image: image array indexed as ``image[row, column]``.
        wratio: minimum face-width to image-width ratio on each side.
        hratio: minimum lower-face-height to image-height ratio.

    Returns:
        The image after the side/bottom crop and the top cut (``cropped2``).
        NOTE(review): the additional width/height correction computed as
        ``cropped3`` is printed but not returned - confirm this is intended.
    """
    # Assumes the standard 68-point landmark ordering (8 = chin tip, 1/15 =
    # jaw sides, 28 = nose bridge, 39/42 = inner eye corners) - TODO confirm.
    landmarks = extract_landmarks(image, detect_face(image))

    # Original image height and width
    iho = image.shape[0]
    iwo = image.shape[1]

    # Fraction of the space on each side of landmark 8 that the face occupies.
    # "Left" is left as we see the image.
    width_left_condition = (landmarks[8][0] - landmarks[1][0])/landmarks[8][0]
    width_right_condition = (landmarks[15][0] - landmarks[8][0])/(iwo - landmarks[8][0])
    height_bottom_condition = (landmarks[8][1] - landmarks[28][1])/(iho - landmarks[28][1])

    # print('Left face to width ratio = ',width_left_condition,'. Right face to width ratio = ',width_right_condition,
    #       '. Bottom half face to height ratio = ',height_bottom_condition,'\n')
    # ---- LEFT ----
    #if 0.5 <= width_left_condition <= 0.75:
    if width_left_condition >= wratio:
        #print('Left side is OK')
        wcrop_from_left = 0
    else:
        # New left border x such that (x8 - x1)/(x8 - x) == wratio
        wcrop_from_left = int((landmarks[1][0] - (1 - wratio)*landmarks[8][0])/wratio)
        #print("Cropping %f pixels from left" %(wcrop_from_left))
    # ---- RIGHT ----
    #if 0.5 <= width_right_condition <= 0.75:
    if width_right_condition >= wratio:
        #print('Right side is OK')
        wcrop_from_right = iwo
    else:
        # Offset (<= 0) of the new right border relative to the right edge; it
        # is used below as a negative slice end.
        # NOTE(review): int() truncates towards zero, so an offset in (-1, 0)
        # becomes 0 and the slice `left:0` would be empty - confirm.
        wcrop_from_right = int((landmarks[15][0] - (1 - wratio)*landmarks[8][0] - wratio*iwo)/wratio)
        #print("Cropping %f pixels from right" %(-wcrop_from_right))
    # ---- BOTTOM ----
    #if 0.6 <= height_bottom_condition <= 0.9:
    if height_bottom_condition >= hratio:
        #print('Bottom side is OK')
        hcrop_from_bottom = iho
    else:
        # New bottom border y such that (y8 - y28)/(y - y28) == hratio
        hcrop_from_bottom = iho - int(((1 - hratio)*landmarks[28][1] - landmarks[8][1] + hratio*iho)/hratio)
        #print("Cropping %f pixels from bottom"%(iho-hcrop_from_bottom),'\n')
    # ---- apply side/bottom crop ----
    cropped = image[0:hcrop_from_bottom, wcrop_from_left:wcrop_from_right]

    #print("Current W/H ratio: %f"%(cropped.shape[1]/cropped.shape[0]),'\n')
    # try:
    #     cv2.imshow('initial crop',cropped)
    # except cv2.error:
    #     pass
    # Landmark coordinates changed with the crop, so detect them again.
    landmarks = extract_landmarks(cropped, detect_face(cropped))

    # Check if the line between the eyes over the image height is more than 45% (low in the image)
    check = 0.45

    if ((landmarks[39][1] + landmarks[42][1])/2)/cropped.shape[0] > check:
        # Rows to remove from the top so that the eye line sits at exactly `check`
        cut = (((landmarks[39][1] + landmarks[42][1])/2) - cropped.shape[0]*check)/(1-check)
    else:
        cut = 0
    cropped2 = cropped[int(cut):,:]

    # Reduce the crop if W/H ratio becomes more than 0.8
    # NOTE(review): this section only adjusts `cut` and prints diagnostics;
    # `cropped2` was already computed above and is not updated, and `check2`
    # is never read - looks like an unfinished loop, confirm.
    check2 = (cropped[int(cut):,:].shape[1]/cropped[int(cut):,:].shape[0])
    if (cropped[int(cut):,:].shape[1]/cropped[int(cut):,:].shape[0]) >= 0.7789:
        cut -= 0.1
        print("cut",cut)
        print((cropped[int(cut):,:].shape[1]/cropped[int(cut):,:].shape[0]))

        # if check2 < 0.7789:
        #     break


    #print("After initial top cut, W/H ratio is %f" %(cropped2.shape[1]/cropped2.shape[0]), "\n")

    # Ensure W/H ratio is OK: if the image is too tall, cut further rows from
    # the top until width/height reaches 0.778.
    if (cropped2.shape[1]/cropped2.shape[0]) >= 0.778:
        cropped3 = cropped2
    else:
        final_cut_top = cropped2.shape[0] - (cropped2.shape[1])/0.778
        cropped3 = cropped2[int(final_cut_top):,:]

    try:
        print('Final width to height ratio = %f' % (cropped3.shape[1]/cropped3.shape[0]), "\n")
    except ZeroDivisionError:
        # An empty crop has height 0; the ratio is simply not reported.
        pass

    return cropped2


def exact_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
    """Resize an image keeping its aspect ratio, then force the exact height.

    The image is scaled so that its width equals ``width`` (or, when only
    ``height`` is given, so that its height equals ``height``).  If ``height``
    is given and the scaled image is taller, it is centre-cropped vertically;
    if it is shorter, it is padded top and bottom with zeros.

    Args:
        image: image array; ``image.shape[:2]`` is ``(height, width)``.
        width: target width in pixels, or ``None``.
        height: target height in pixels, or ``None``.
        inter: OpenCV interpolation flag passed to ``cv2.resize``.

    Returns:
        ``image`` itself when both ``width`` and ``height`` are ``None``,
        otherwise the resized (and, if needed, cropped/padded) image.
    """
    (h, w) = image.shape[:2]

    # Nothing requested: return the original image untouched.
    if width is None and height is None:
        return image

    if width is None:
        # Scale by the height ratio.
        r = height / float(h)
        dim = (int(w * r), height)
    else:
        # Scale by the width ratio (also used when both are given).
        r = width / float(w)
        dim = (width, int(h * r))

    resized = cv2.resize(image, dim, interpolation=inter)

    # Only a width was requested: there is no target height to enforce.
    if height is None:
        return resized

    h = resized.shape[0]
    if h > height:
        # Centre crop to exactly `height` rows.
        top = (h - height) // 2
        resized = resized[top:top + height, :]
    elif h < height:
        # Pad to exactly `height` rows; the extra row of an odd difference
        # goes to the bottom.
        missing = height - h
        top = missing // 2
        resized = cv2.copyMakeBorder(resized, top, missing - top, 0, 0, cv2.BORDER_CONSTANT, value=0)
    return resized


def calc_roi_nose(image):
    """Cut two equally sized, vertically adjacent regions from a face image.

    Both regions span the horizontal gap between landmarks 21 and 22, shrunk
    by 15% of that gap on each side.  Their height is the vertical distance
    between landmarks 21 and 27, and they meet at the mean height of
    landmarks 20 and 23.

    Returns:
        ``(outroi, inroi)``: the region directly above the meeting line and
        the region directly below it.
    """
    face = detect_face(image)
    ld68 = extract_landmarks(image, face).astype(int)

    x21, x22 = ld68[21][0], ld68[22][0]
    roi_height = ld68[27][1] - ld68[21][1]
    midline = (ld68[20][1] + ld68[23][1]) // 2

    # Horizontal extent shared by both regions.
    inset = 0.15 * (x22 - x21)
    x_start = int(x21 + inset)
    x_stop = int(x22 - inset)

    # image[rows, columns]
    outroi = image[midline - roi_height:midline, x_start:x_stop]
    inroi = image[midline:midline + roi_height, x_start:x_stop]
    return outroi, inroi


def calc_roi_cheek(image):
    """Cut a region above the brow line and a same-sized region lower down.

    ``outroi`` ends at the mean height of landmarks 20 and 23, is 1.5 times
    the landmark 21-to-27 vertical distance tall, and extends 0.75 times that
    distance beyond landmarks 21 and 22 horizontally.  ``inroi`` is centred
    on (x of landmark 46, y of landmark 30); if it comes out smaller than
    ``outroi`` it is re-cut around the height of landmark 29 with one extra
    row or column.  Both shapes are printed.

    Returns:
        ``(outroi, inroi)``.
    """
    face = detect_face(image)
    ld68 = extract_landmarks(image, face).astype(int)

    height = ld68[27][1] - ld68[21][1]
    midline = (ld68[20][1] + ld68[23][1]) // 2
    side_margin = int(0.75 * height)

    # image[rows, columns]
    outroi = image[
        midline - int(1.5 * height):midline,
        int(ld68[21][0] - side_margin):int(ld68[22][0] + side_margin),
    ]

    half_h = outroi.shape[0] // 2
    half_w = outroi.shape[1] // 2
    x_lo = ld68[46][0] - half_w
    x_hi = ld68[46][0] + half_w
    y30 = ld68[30][1]
    y29 = ld68[29][1]

    inroi = image[y30 - half_h:y30 + half_h, x_lo:x_hi]

    # Ensure equal size of ROIs (halving an odd size loses one row/column).
    if inroi.shape[0] < outroi.shape[0]:
        inroi = image[y29 - half_h:y29 + half_h + 1, x_lo:x_hi]
    if inroi.shape[1] < outroi.shape[1]:
        inroi = image[y29 - half_h:y29 + half_h, x_lo:x_hi + 1]

    print(outroi.shape,inroi.shape)
    return outroi, inroi


def rotate_it(image, inter = cv2.INTER_CUBIC):
    """Rotate an image so that the two eye centres lie on a horizontal line.

    Each eye centre is the integer mean of four eyelid landmarks.  The image
    is rotated (scale 1, output size unchanged) around the midpoint between
    the two eye centres.  The rotation is always applied, however small the
    angle is.

    Args:
        image: image array; ``image.shape[:2]`` is ``(height, width)``.
        inter: OpenCV interpolation flag passed to ``cv2.warpAffine``.

    Returns:
        The rotated image, same size as the input.
    """
    (h, w) = image.shape[:2]

    landmarks_found = extract_landmarks(image,detect_face(image))
    leftEyeCenter = (landmarks_found[43] + landmarks_found[44] + landmarks_found[46] + landmarks_found[47])//4
    rightEyeCenter = (landmarks_found[37] + landmarks_found[38] + landmarks_found[40] + landmarks_found[41])//4

    # Recent OpenCV Python bindings reject NumPy integer scalars inside the
    # `center` tuple, so hand over plain Python floats.
    center = (
        float((leftEyeCenter[0] + rightEyeCenter[0]) // 2),
        float((leftEyeCenter[1] + rightEyeCenter[1]) // 2),
    )

    dY = (rightEyeCenter[1] - leftEyeCenter[1])
    dX = (rightEyeCenter[0] - leftEyeCenter[0])

    # The eye-to-eye vector is measured from leftEyeCenter to rightEyeCenter;
    # the 180 degree offset turns its direction into the tilt of the eye line.
    angle = float(np.degrees(np.arctan2(dY, dX)) - 180)

    scale = 1
    M = cv2.getRotationMatrix2D(center, angle, scale)
    rotated = cv2.warpAffine(image, M, (w, h), flags=inter)
    return rotated


def resize(image, factor, inter=cv2.INTER_CUBIC):
    """Resize an image equally in width and height by ``factor``.

    Args:
        image: image array; ``image.shape[:2]`` is ``(height, width)``.
        factor: scale factor; ``factor > 1`` means upscale.
        inter: OpenCV interpolation flag.

    Returns:
        The resized image.
    """
    height, width = image.shape[:2]
    # cv2.resize expects dsize as (width, height); the previous code used the
    # height for both, which distorted every non-square image.
    new_size = (int(width * factor), int(height * factor))
    return cv2.resize(image, dsize=new_size, interpolation=inter)


def sharpen(image, sigma=1, strength=1, kernel='default'):
    """Sharpen an image by subtracting a filtered copy of itself.

    Every channel is median-blurred, convolved with the selected 3x3 kernel,
    and ``strength`` times that response is subtracted from the channel.  The
    result is clipped to [0, 255] and stored in the input's dtype.

    Args:
        image: 3-dimensional image (the first three channels are processed,
            as before) or a 2-dimensional grayscale image.
        sigma: aperture size handed to ``cv2.medianBlur``.
        strength: weight of the filter response that is subtracted.
        kernel: ``'default'`` (4-neighbour Laplacian) or ``'extreme'``
            (8-neighbour Laplacian).

    Returns:
        The sharpened image, same shape and dtype as ``image``.

    Raises:
        ValueError: if ``kernel`` is not one of the supported names.
    """
    kernels = {
        'default': np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]]),
        'extreme': np.array([[1, 1, 1], [1, -8, 1], [1, 1, 1]]),
    }
    if kernel not in kernels:
        raise ValueError(
            "unknown kernel {!r}; expected 'default' or 'extreme'".format(kernel))
    selected_kernel = kernels[kernel]

    def _sharpen_plane(plane):
        blurred = cv2.medianBlur(plane, sigma)
        filtered = cv2.filter2D(src=blurred, kernel=selected_kernel, ddepth=cv2.CV_64F,
                                borderType=cv2.BORDER_REFLECT_101)
        # The filter response is float64, so the subtraction cannot wrap.
        return np.clip(plane - strength * filtered, 0, 255)

    sharp_full = np.zeros_like(image)
    if image.ndim == 2:
        sharp_full[:, :] = _sharpen_plane(image)
    else:
        for i in range(3):
            sharp_full[:, :, i] = _sharpen_plane(image[:, :, i])
    return sharp_full


def extract_landmarks(image, face):
    """Extract dlib facial landmarks for one face.

    The shape predictor model is loaded from
    ``./shape_predictor_68_face_landmarks_GTX.dat`` on the first call and
    cached on the function object, so later calls do not read the model file
    again.

    Args:
        image: image array the face was detected in.
        face: face rectangle, e.g. the return value of ``detect_face``.

    Returns:
        Integer array of shape ``(num_landmarks, 2)``; row ``i`` holds the
        ``(x, y)`` coordinates of landmark ``i`` (0-based).
    """
    extractor = getattr(extract_landmarks, "_predictor", None)
    if extractor is None:
        # Loading the model is expensive; do it once per process.
        extractor = dlib.shape_predictor('./shape_predictor_68_face_landmarks_GTX.dat')
        extract_landmarks._predictor = extractor

    return shape_to_np(extractor(image, face))


def shape_to_np(shape, dtype="int"):
    """Convert a landmark container into an ``(num_parts, 2)`` array.

    ``shape`` must expose ``num_parts`` and ``part(i)`` returning an object
    with ``x`` and ``y`` attributes.  Row ``i`` of the result is
    ``(part(i).x, part(i).y)`` stored as ``dtype``.
    """
    count = shape.num_parts
    points = [(shape.part(idx).x, shape.part(idx).y) for idx in range(count)]
    # reshape keeps the (0, 2) shape when there are no parts at all
    return np.array(points, dtype=dtype).reshape(count, 2)

########################################### METRICS ###################################################################
def d_hamming(img1, img2):
    """Hamming distance between two images, in percent.

    The Hamming distance is the fraction of positions whose values differ.
    For a 3-channel image (3-dimensional array whose last axis has length 3)
    the distance is computed per channel, each rounded to 4 decimals, and the
    channel mean is returned.  Any other array is compared as a whole.

    Args:
        img1, img2: arrays of identical shape.

    Returns:
        Distance rounded to 4 decimals and multiplied by 100.
    """
    # Require three dimensions so a grayscale image that happens to be three
    # pixels wide is not mistaken for a colour image.
    if img1.ndim == 3 and img1.shape[-1] == 3:
        per_channel = [
            round(distance.hamming(np.ravel(img1[:, :, c]), np.ravel(img2[:, :, c])), 4)
            for c in range(3)
        ]
        return round(np.mean(per_channel), 4) * 100

    return round(distance.hamming(np.ravel(img1), np.ravel(img2)), 4) * 100


def mae(img1,img2):
    """Mean absolute pixel difference (error) between two images.

    For a 3-channel image (3-dimensional array whose last axis has length 3)
    the error is computed per channel, each rounded to 4 decimals, and the
    channel mean is returned.  Any other array is compared as a whole.

    Args:
        img1, img2: arrays of identical shape.

    Returns:
        The mean absolute error rounded to 4 decimals.
    """
    # Work in float64: the former int8 cast wrapped every value above 127
    # (255 became -1), which made bright pixels look almost equal to black.
    first = np.asarray(img1, dtype=np.float64)
    second = np.asarray(img2, dtype=np.float64)
    abs_diff = np.abs(first - second)

    if first.ndim == 3 and first.shape[-1] == 3:
        per_channel = [round(np.mean(abs_diff[:, :, c]), 4) for c in range(3)]
        return round(np.mean(per_channel), 4)

    return round(np.mean(abs_diff), 4)


def mstd(img1, img2):
    """Standard deviation of the pixel differences between two images.

    The statistic is taken over the signed differences ``img1 - img2``.  For
    a 3-channel image (3-dimensional array whose last axis has length 3) it
    is computed per channel, each rounded to 4 decimals, and the channel mean
    is returned.  Any other array is compared as a whole.

    Args:
        img1, img2: arrays of identical shape.

    Returns:
        The standard deviation rounded to 4 decimals.
    """
    # Work in float64: the former int8 cast wrapped values above 127 and the
    # int8 subtraction could overflow as well.
    first = np.asarray(img1, dtype=np.float64)
    second = np.asarray(img2, dtype=np.float64)
    diff = first - second

    if first.ndim == 3 and first.shape[-1] == 3:
        per_channel = [round(np.std(diff[:, :, c]), 4) for c in range(3)]
        return round(np.mean(per_channel), 4)

    return round(np.std(diff), 4)


def psnr(img1, img2):
    """Peak signal-to-noise ratio between two images.

    Both images are cast to ``uint8`` first.  For a 3-channel image
    (3-dimensional array whose last axis has length 3) the PSNR is computed
    per channel, each rounded to 4 decimals, and the channel mean is
    returned.  Any other array is compared as a whole.

    Args:
        img1, img2: arrays of identical shape.

    Returns:
        The PSNR rounded to 4 decimals.
    """
    img1 = img1.astype(np.uint8)
    img2 = img2.astype(np.uint8)

    # Require three dimensions so a grayscale image that happens to be three
    # pixels wide is not mistaken for a colour image.
    if img1.ndim == 3 and img1.shape[-1] == 3:
        per_channel = [
            round(peak_signal_noise_ratio(img1[:, :, c], img2[:, :, c]), 4)
            for c in range(3)
        ]
        return round(np.mean(per_channel), 4)

    return round(peak_signal_noise_ratio(img1, img2), 4)


def ssim(img1, img2):
    """Structural similarity index between two images, in percent.

    For a 3-dimensional image the index is computed separately for every
    channel along the last axis, each rounded to 4 decimals, and the channel
    mean is returned.  A 2-dimensional (grayscale) image is compared as a
    single plane; the former ``multichannel=True`` call interpreted its
    columns as channels.

    Args:
        img1, img2: arrays of identical shape.

    Returns:
        The SSIM score rounded to 4 decimals and multiplied by 100.
    """
    # Channels are iterated explicitly, so the result does not depend on the
    # `multichannel` / `channel_axis` keyword of the installed scikit-image.
    if img1.ndim == 3:
        per_channel = [
            round(compare_ssim(img1[:, :, c], img2[:, :, c]), 4)
            for c in range(img1.shape[-1])
        ]
        return round(np.mean(per_channel), 4) * 100

    return round(compare_ssim(img1, img2), 4) * 100


def thres_diff(img1, img2, thres = 2):
    """Percentage of pixels that differ by more than ``thres`` intensity values.

    For a 3-channel image (3-dimensional array whose last axis has length 3)
    the pixels are counted per channel and the three counts are averaged.
    Any other array is counted as a whole.  The count is divided by
    ``shape[0] * shape[1]``.

    Args:
        img1, img2: arrays of identical shape.
        thres: differences strictly greater than this value are counted.

    Returns:
        The (average) share of differing pixels in percent.
    """
    # int16 holds every uint8 difference exactly; the former int8 cast
    # wrapped values above 127, so e.g. 255 vs 0 was seen as a difference of 1.
    first = np.asarray(img1, dtype=np.int16)
    second = np.asarray(img2, dtype=np.int16)
    exceeds = np.abs(first - second) > thres
    n_pixels = first.shape[0] * first.shape[1]

    if first.ndim == 3 and first.shape[-1] == 3:
        counts = [np.count_nonzero(exceeds[:, :, c]) for c in range(3)]
        average_count = sum(counts) / 3
    else:
        average_count = np.count_nonzero(exceeds)

    return 100 * (average_count / n_pixels)

########################################## MEASURES #################################################################


def get_measures(img1,img2, thres):
    """Collect the image-difference measures for a pair of images.

    Returns:
        A list ``[hamming, threshold_diff, mae, psnr, ssim]`` holding the
        results of ``d_hamming``, ``thres_diff`` (called with ``thres``),
        ``mae``, ``psnr`` and ``ssim``, evaluated in that order.
    """
    return [
        d_hamming(img1, img2),
        thres_diff(img1, img2, thres),
        mae(img1, img2),
        psnr(img1, img2),
        ssim(img1, img2),
    ]

############################################# BLOCKS #################################################################
def block(img, n_blocks=1):
    """Divide an image into ``n_blocks`` x ``n_blocks`` equally sized tiles.

    The image is converted to ``float32``.  Zero padding is added on the left
    and on the top so that width and height become multiples of ``n_blocks``.
    Tiles are returned in row-major order.

    An input whose last axis has length 3 is handled as a colour image (each
    channel is tiled on its own); everything else is handled as a single
    plane.

    Returns:
        * ``n_blocks == 1``: the ``float32`` image, unchanged in shape.
        * single plane: array of shape ``(n_blocks**2, tile_h, tile_w)``.
        * colour: array of shape ``(n_blocks**2, tile_h, tile_w, 3)``.
    """
    pixels = np.asarray(img, dtype=np.float32)
    if n_blocks == 1:  # no tiling requested
        return pixels

    def _tile_plane(plane):
        # Left padding first, then top padding.
        pad_left = -plane.shape[1] % n_blocks
        if pad_left:
            plane = np.pad(plane, ((0, 0), (pad_left, 0)), mode = 'constant')
        pad_top = -plane.shape[0] % n_blocks
        if pad_top:
            plane = np.pad(plane, ((pad_top, 0), (0, 0)), mode = 'constant')
        tile_h = plane.shape[0] // n_blocks
        tile_w = plane.shape[1] // n_blocks
        # (block_row, y, block_col, x) -> (block_row, block_col, y, x)
        grid = plane.reshape(n_blocks, tile_h, n_blocks, tile_w).swapaxes(1, 2)
        return grid.reshape(n_blocks * n_blocks, tile_h, tile_w)

    if pixels.shape[-1] != 3:
        return _tile_plane(pixels)

    channel_tiles = [_tile_plane(pixels[:, :, c]) for c in range(3)]
    return np.stack(channel_tiles, axis=3)


def unblock(array, root_img):
    """Merge n x n tiles (as produced by ``block``) back into one image.

    ``array`` is treated as colour data when its last axis has length 3 and
    as single-plane data otherwise.  Tiles are expected in row-major order,
    i.e. shape ``(n*n, tile_h, tile_w)`` or ``(n*n, tile_h, tile_w, 3)``.
    After merging, rows and columns are removed from the top and the left so
    that the result has the height and width of ``root_img`` (this removes
    the padding that tiling added).

    An array with one dimension less (an untiled image) is returned as is;
    any other rank yields an empty array.

    Returns:
        The merged image with the dtype of ``array``.
    """
    tiled_rank = 4 if array.shape[-1] == 3 else 3

    if array.ndim == tiled_rank:
        side = int(np.sqrt(array.shape[0]))
        grid = np.reshape(array, (side, side) + tuple(array.shape[1:]))
        # Join the tiles of every grid row side by side, then stack the rows.
        strips = [np.concatenate(list(grid_row), axis=1) for grid_row in grid]
        merged = np.concatenate(strips, axis=0)
        # Padding removal
        top = merged.shape[0] - root_img.shape[0]
        left = merged.shape[1] - root_img.shape[1]
        merged = merged[top:, left:]
    elif array.ndim == tiled_rank - 1:
        merged = array
    else:
        merged = []

    return np.asarray(merged).astype(array.dtype)


def block_old(img, n_blocks):
    """Divide an image into an ``n_blocks`` x ``n_blocks`` grid of tiles.

    Zero padding is added on the left and on the top so that width and
    height become multiples of ``n_blocks``.

    Returns:
        ``float64`` array indexed as ``[block_row, block_col, y, x]``.
    """
    pad_left = -img.shape[1] % n_blocks
    if pad_left:
        img = np.pad(img, ((0, 0), (pad_left, 0)), mode = 'constant')
    pad_top = -img.shape[0] % n_blocks
    if pad_top:
        img = np.pad(img, ((pad_top, 0), (0, 0)), mode = 'constant')

    grid = []
    for strip in np.array_split(img, n_blocks):
        grid.append(np.array_split(strip, n_blocks, axis=1))
    return np.asarray(grid, dtype=np.float64)


def unblock_old(array):
    """Merge a grid of equally sized tiles back into one ``float64`` image.

    ``array`` is indexed as ``[block_row, block_col, y, x]``: the tiles of
    each grid row are joined along the last axis and the resulting strips
    are joined along the second-to-last axis.
    """
    strips = [np.concatenate(list(grid_row), axis=-1) for grid_row in array]
    merged = np.concatenate(strips, axis=-2)
    return merged.astype(np.float64)
############################################# NORMALIZATION ############################################################

def normalize(array):
    """Log-scale data and stretch it linearly to the ``uint8`` range [0, 255].

    The base-10 logarithm of every value is taken and the finite results are
    mapped so that the smallest becomes 0 and the largest 255 (fractions are
    truncated by the ``uint8`` cast).

    Edge cases:
        * zeros (log = -inf) are mapped to 0, the lowest output value;
        * negative values (log = NaN) are mapped to 0;
        * if all finite logarithms are equal, or there are none, the result
          is all zeros instead of a division by zero.

    Args:
        array: array-like of numbers.

    Returns:
        ``uint8`` array with the shape of ``array``.
    """
    # float64 first: np.log10 on small integer dtypes would otherwise compute
    # in reduced precision.
    values = np.asarray(array, dtype=np.float64)
    with np.errstate(divide="ignore", invalid="ignore"):
        logged = np.log10(values)

    finite = np.isfinite(logged)
    if not finite.any():
        return np.zeros(logged.shape, dtype=np.uint8)

    low = logged[finite].min()
    high = logged[finite].max()
    if high == low:
        return np.zeros(logged.shape, dtype=np.uint8)

    # clip pulls -inf/+inf onto the finite range; NaN is replaced afterwards
    scaled = 255 * (np.clip(logged, low, high) - low) / (high - low)
    scaled[np.isnan(scaled)] = 0
    return scaled.astype(np.uint8)

def normalize2(array):
    """Stretch data linearly to the ``uint8`` range [0, 255].

    The smallest value becomes 0 and the largest 255 (fractions are truncated
    by the ``uint8`` cast).  A constant or empty input yields all zeros
    instead of a division by zero.

    Args:
        array: array-like of numbers.

    Returns:
        ``uint8`` array with the shape of ``array``.
    """
    # float64 first: with an integer input such as uint8, `255 * (array - min)`
    # was evaluated in the input dtype and overflowed.
    values = np.asarray(array, dtype=np.float64)
    if values.size == 0:
        return np.zeros(values.shape, dtype=np.uint8)

    low = values.min()
    high = values.max()
    if high == low:
        return np.zeros(values.shape, dtype=np.uint8)

    return (255 * (values - low) / (high - low)).astype(np.uint8)



########################################### MAIN #############################################################

#
# img = cv2.imread("../Morphing/ours/luuk/smart_20220623_150247.jpg")
# img2 = img[1155:,:,:]
# cv2.imwrite("../Morphing/ours/luuk/smart_20220623_150247.png",img2)
# OpenCV interpolation flags and, at matching positions, their display names
# (intername[i] is the label for inter[i]).
inter = [cv2.INTER_NEAREST , cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]
intername = ['Nearest Neighbor', 'Linear', 'Cubic', 'Lanczos']

#path = "./temp_eurecom/"

# for im in os.listdir(path):
#     if os.path.isfile(path + im):#if not any(z in im for z in ('fig','rot','comp','dscl', 'uscl','duscl', 'udscl', 'sharp')):
#         image = cv2.imread(path + im,0)
