import cv2
import os
from os import listdir
import numpy as np
import time


def tic():
    """Reset the module-level stopwatch that tac() reads.

    Stores the current ``time.time()`` value in the global ``_start_time``.
    """
    global _start_time
    started_at = time.time()
    _start_time = started_at


def tac():
    """Print the wall-clock time elapsed since the last call to tic().

    The elapsed time is rounded to whole seconds. Durations under an hour are
    printed as ``Time passed: <m>min:<s>sec``. Longer durations are prefixed
    with the hour count; the previous version computed the hours and then
    dropped them, so the minutes silently wrapped at 60.

    Raises:
        NameError: if tic() has never been called (``_start_time`` unset).
    """
    total_sec = round(time.time() - _start_time)
    t_min, t_sec = divmod(total_sec, 60)
    t_hour, t_min = divmod(t_min, 60)
    if t_hour:
        print('Time passed: {}h:{}min:{}sec'.format(t_hour, t_min, t_sec))
    else:
        print('Time passed: {}min:{}sec'.format(t_min, t_sec))

def readPoints(path):
    """Read integer landmark coordinates from a whitespace-separated text file.

    Each non-blank line must hold exactly two integers, ``x y``. Blank or
    whitespace-only lines (e.g. an extra empty line at the end of the file)
    are skipped instead of crashing the unpacking.

    Args:
        path: path of the landmark text file.

    Returns:
        list of ``(x, y)`` int tuples, in file order.

    Raises:
        ValueError: if a non-blank line does not contain exactly two integers.
    """
    points = []
    with open(path) as file:
        for line in file:
            fields = line.split()
            if not fields:
                continue
            x, y = fields
            points.append((int(x), int(y)))
    return points


# Check if a point is inside a rectangle
def rect_contains(rect, point):
    """Return True if ``point`` lies within ``rect``, borders included.

    ``rect`` is compared as ``(x_min, y_min, x_max, y_max)``: index 2 and 3
    are upper bounds, not a width and height. The four bound checks are
    evaluated lazily in the order x_min, y_min, x_max, y_max.
    """
    outside = (point[0] < rect[0]
               or point[1] < rect[1]
               or point[0] > rect[2]
               or point[1] > rect[3])
    return not outside


# Sharpening
def sharpen(image, sigma=1, strength=3, kernel='default'):
    """Sharpen an image by subtracting a scaled Laplacian of its median-blurred copy.

    For each colour channel: median-blur with aperture ``sigma``, convolve with
    a Laplacian kernel (float64 output), compute ``channel - strength * laplacian``
    and clip to [0, 255]. The result is stored in an array of the input's dtype.

    Args:
        image: image array as returned by ``cv2.imread`` (H x W x C or H x W).
        sigma: aperture size passed to ``cv2.medianBlur`` (must be odd). Despite
            the name it is a kernel size, not a Gaussian sigma.
        strength: weight of the Laplacian term.
        kernel: ``'default'`` (4-neighbour Laplacian) or ``'extreme'``
            (8-neighbour Laplacian).

    Returns:
        New array with the same shape and dtype as ``image``. At most the first
        three channels are sharpened; any further channel (e.g. alpha) is
        copied unchanged instead of being zeroed.

    Raises:
        ValueError: if ``kernel`` is not a known name, or ``image`` is None
            (typically a failed ``cv2.imread``).
    """
    kernels = {
        'default': np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]]),
        'extreme': np.array([[1, 1, 1], [1, -8, 1], [1, 1, 1]]),
    }
    if kernel not in kernels:
        raise ValueError('unknown kernel {!r}; expected one of {}'.format(kernel, sorted(kernels)))
    if image is None:
        raise ValueError('sharpen() received None instead of an image (did cv2.imread fail?)')
    selected_kernel = kernels[kernel]

    def _sharpen_plane(plane):
        # A channel slice is a strided view; hand cv2 a contiguous buffer.
        plane = np.ascontiguousarray(plane)
        blurred = cv2.medianBlur(plane, sigma)
        filtered = cv2.filter2D(src=blurred, kernel=selected_kernel, ddepth=cv2.CV_64F,
                                borderType=cv2.BORDER_REFLECT_101)
        # filtered is float64, so the subtraction is done in float and cannot
        # wrap around the way uint8 arithmetic would.
        return np.clip(plane - strength * filtered, 0, 255)

    sharp_full = np.zeros_like(image)
    if image.ndim == 2:
        sharp_full[:, :] = _sharpen_plane(image)
        return sharp_full

    n_colour = min(image.shape[2], 3)
    for i in range(n_colour):
        sharp_full[:, :, i] = _sharpen_plane(image[:, :, i])
    # Pass through any non-colour channels untouched.
    sharp_full[:, :, n_colour:] = image[:, :, n_colour:]
    return sharp_full


def swap(source_img, target_img, points):
    """Seamlessly clone the inner-face region of ``source_img`` onto ``target_img``.

    Both images are sharpened first: source with aperture 3 / strength 0.3,
    target with aperture 3 / strength 0.15. The clone mask is the convex hull
    of landmarks 0-10, 21-22, 11-12, 17-18 and 25-26 from ``points``. Boundary
    points are left out on purpose: only the inner face is transferred. The same
    landmark coordinates position the patch in the target, so both images are
    assumed to share the landmark coordinate frame.

    Args:
        source_img: path of the image that supplies the face region.
        target_img: path of the image that receives it.
        points: path of the landmark file read by ``readPoints``.

    Returns:
        The blended image from ``cv2.seamlessClone`` (target as destination).

    Raises:
        OSError: if either image cannot be read.
    """
    src = cv2.imread(source_img)
    if src is None:
        raise OSError('could not read source image: {}'.format(source_img))
    src = sharpen(src, 3, 0.3)

    dst = cv2.imread(target_img)
    if dst is None:
        raise OSError('could not read target image: {}'.format(target_img))
    dst = sharpen(dst, 3, 0.15)

    all_points = readPoints(points)
    subset = (all_points[0:11] + all_points[21:23] + all_points[11:13]
              + all_points[17:19] + all_points[25:27])
    landmarks = np.int32(subset)

    # returnPoints=False yields an (K, 1) index array; flatten it for fancy
    # indexing instead of converting 1-element arrays with int().
    hull_idx = cv2.convexHull(landmarks, returnPoints=False).ravel()
    hull = landmarks[hull_idx]

    # seamlessClone reads the patch out of *src* through the mask, so the mask
    # must have the source image's dimensions.
    mask = np.zeros(src.shape, dtype=src.dtype)
    cv2.fillConvexPoly(mask, hull, (255, 255, 255))

    x, y, w, h = cv2.boundingRect(np.float32([hull]))
    center = (x + w // 2, y + h // 2)

    # NORMAL_CLONE keeps the source texture; MIXED_CLONE would pull it towards
    # the target texture.
    return cv2.seamlessClone(src, dst, mask, center, cv2.NORMAL_CLONE)


def swap_nose(source_img, target_img, points):
    """Seamlessly clone the nose region of ``target_img`` into ``source_img``.

    Note the direction: ``target_img`` supplies the patch and ``source_img`` is
    the destination that receives it (the opposite of ``swap``). Only the patch
    image is sharpened (aperture 1, strength 0.15). The clone mask is the convex
    hull of landmarks 49-55 from ``points``.

    Args:
        source_img: path of the destination image.
        target_img: path of the image that supplies the nose patch.
        points: path of the landmark file read by ``readPoints``.

    Returns:
        The blended image from ``cv2.seamlessClone`` (``source_img`` as destination).

    Raises:
        OSError: if either image cannot be read.
    """
    src = cv2.imread(target_img)
    if src is None:
        raise OSError('could not read patch image: {}'.format(target_img))
    src = sharpen(src, 1, 0.15)

    dst = cv2.imread(source_img)
    if dst is None:
        raise OSError('could not read destination image: {}'.format(source_img))

    landmarks = np.int32(readPoints(points)[49:56])

    # Flatten the (K, 1) index array instead of int()-converting 1-element arrays.
    hull_idx = cv2.convexHull(landmarks, returnPoints=False).ravel()
    hull = landmarks[hull_idx]

    # The mask selects pixels of the patch image, so it takes the patch's shape.
    mask = np.zeros(src.shape, dtype=src.dtype)
    cv2.fillConvexPoly(mask, hull, (255, 255, 255))

    x, y, w, h = cv2.boundingRect(np.float32([hull]))
    center = (x + w // 2, y + h // 2)

    # NORMAL_CLONE keeps the patch texture; MIXED_CLONE would blend it.
    return cv2.seamlessClone(src, dst, mask, center, cv2.NORMAL_CLONE)


def swap_eyes1(source_img, target_img, points):
    """Seamlessly clone the first eye region of ``target_img`` into ``source_img``.

    ``target_img`` supplies the patch and ``source_img`` is the destination.
    Only the patch image is sharpened (aperture 1, strength 0.15). The clone
    mask is the convex hull of landmarks 13-15, 32-34, 29 and 21 from ``points``.

    Args:
        source_img: path of the destination image.
        target_img: path of the image that supplies the eye patch.
        points: path of the landmark file read by ``readPoints``.

    Returns:
        The blended image from ``cv2.seamlessClone`` (``source_img`` as destination).

    Raises:
        OSError: if either image cannot be read.
    """
    src = cv2.imread(target_img)
    if src is None:
        raise OSError('could not read patch image: {}'.format(target_img))
    src = sharpen(src, 1, 0.15)

    dst = cv2.imread(source_img)
    if dst is None:
        raise OSError('could not read destination image: {}'.format(source_img))

    all_points = readPoints(points)
    subset = all_points[13:16] + all_points[32:35] + all_points[29:30] + all_points[21:22]
    landmarks = np.int32(subset)

    # Flatten the (K, 1) index array instead of int()-converting 1-element arrays.
    hull_idx = cv2.convexHull(landmarks, returnPoints=False).ravel()
    hull = landmarks[hull_idx]

    # The mask selects pixels of the patch image, so it takes the patch's shape.
    mask = np.zeros(src.shape, dtype=src.dtype)
    cv2.fillConvexPoly(mask, hull, (255, 255, 255))

    x, y, w, h = cv2.boundingRect(np.float32([hull]))
    center = (x + w // 2, y + h // 2)

    # NORMAL_CLONE keeps the patch texture; MIXED_CLONE would blend it.
    return cv2.seamlessClone(src, dst, mask, center, cv2.NORMAL_CLONE)


def swap_eyes2(source_img, target_img, points):
    """Seamlessly clone the second eye region of ``target_img`` into ``source_img``.

    ``target_img`` supplies the patch and ``source_img`` is the destination.
    Only the patch image is sharpened (aperture 1, strength 0.15). The clone
    mask is the convex hull of landmarks 16, 19-20, 26, 38-40 and 35 from
    ``points``.

    Args:
        source_img: path of the destination image.
        target_img: path of the image that supplies the eye patch.
        points: path of the landmark file read by ``readPoints``.

    Returns:
        The blended image from ``cv2.seamlessClone`` (``source_img`` as destination).

    Raises:
        OSError: if either image cannot be read.
    """
    src = cv2.imread(target_img)
    if src is None:
        raise OSError('could not read patch image: {}'.format(target_img))
    src = sharpen(src, 1, 0.15)

    dst = cv2.imread(source_img)
    if dst is None:
        raise OSError('could not read destination image: {}'.format(source_img))

    all_points = readPoints(points)
    subset = (all_points[16:17] + all_points[19:21] + all_points[26:27]
              + all_points[38:41] + all_points[35:36])
    landmarks = np.int32(subset)

    # Flatten the (K, 1) index array instead of int()-converting 1-element arrays.
    hull_idx = cv2.convexHull(landmarks, returnPoints=False).ravel()
    hull = landmarks[hull_idx]

    # The mask selects pixels of the patch image, so it takes the patch's shape.
    mask = np.zeros(src.shape, dtype=src.dtype)
    cv2.fillConvexPoly(mask, hull, (255, 255, 255))

    x, y, w, h = cv2.boundingRect(np.float32([hull]))
    center = (x + w // 2, y + h // 2)

    # NORMAL_CLONE keeps the patch texture; MIXED_CLONE would blend it.
    return cv2.seamlessClone(src, dst, mask, center, cv2.NORMAL_CLONE)


def swap_lips(source_img, target_img, points):
    """Seamlessly clone the lips region of ``target_img`` into ``source_img``.

    ``target_img`` supplies the patch and ``source_img`` is the destination.
    Only the patch image is sharpened (aperture 1, strength 0.15). The clone
    mask is the convex hull of landmarks 56-67 from ``points``.

    Args:
        source_img: path of the destination image.
        target_img: path of the image that supplies the lips patch.
        points: path of the landmark file read by ``readPoints``.

    Returns:
        The blended image from ``cv2.seamlessClone`` (``source_img`` as destination).

    Raises:
        OSError: if either image cannot be read.
    """
    src = cv2.imread(target_img)
    if src is None:
        raise OSError('could not read patch image: {}'.format(target_img))
    src = sharpen(src, 1, 0.15)

    dst = cv2.imread(source_img)
    if dst is None:
        raise OSError('could not read destination image: {}'.format(source_img))

    landmarks = np.int32(readPoints(points)[56:68])

    # Flatten the (K, 1) index array instead of int()-converting 1-element arrays.
    hull_idx = cv2.convexHull(landmarks, returnPoints=False).ravel()
    hull = landmarks[hull_idx]

    # The mask selects pixels of the patch image, so it takes the patch's shape.
    mask = np.zeros(src.shape, dtype=src.dtype)
    cv2.fillConvexPoly(mask, hull, (255, 255, 255))

    x, y, w, h = cv2.boundingRect(np.float32([hull]))
    center = (x + w // 2, y + h // 2)

    # NORMAL_CLONE keeps the patch texture; MIXED_CLONE would blend it.
    return cv2.seamlessClone(src, dst, mask, center, cv2.NORMAL_CLONE)


def swap_all(source, target, warped, points):
    """Composite face regions into the image at ``source``, overwriting it in place.

    The steps run in order. Each writes its result back to ``source`` before
    the next step reads it:

      1. ``swap``: inner face of ``source`` cloned onto ``target``.
      2. ``swap_nose``, ``swap_eyes1``, ``swap_eyes2``, ``swap_lips``: the
         corresponding region of ``warped`` cloned into the current ``source``.

    The original content of ``source`` is lost.

    Args:
        source: path of the image to rewrite.
        target: path of the destination for the first (inner-face) step.
        warped: path of the image supplying the nose/eyes/lips patches.
        points: path of the landmark file shared by all steps.

    Raises:
        OSError: if an image cannot be read, or a result cannot be written.
            Failing loudly matters: a silently failed write would let the next
            step read a stale file.
    """
    steps = (
        (swap, target),
        (swap_nose, warped),
        (swap_eyes1, warped),
        (swap_eyes2, warped),
        (swap_lips, warped),
    )
    for step, other in steps:
        result = step(source, other, points)
        if not cv2.imwrite(source, result):
            raise OSError('could not write {} (after {})'.format(source, step.__name__))
    #os.remove(source)


if __name__ == '__main__':

    # File-name fragments that identify an original (contributor) image.
    extensions = ('.jpg', 'jpeg', 'tiff', 'png', 'JPG')
    # Fragments marking generated files (morphs 'M_*', 'warped*') to exclude.
    extensions2 = ('M_', 'warped')
    # Each sub-directory of these roots holds one pair of contributor images
    # plus their generated morphs/warps and a text/average.txt landmark file.
    for path in ["./pairs_imars/"]:
        for p in sorted(listdir(path)):
            pair_dir = os.path.join(path, p)
            if not os.path.isdir(pair_dir):
                # Stray files in the root (e.g. .DS_Store) are not pairs.
                continue
            tic()
            print('Processing : ', p)
            entries = sorted(listdir(pair_dir))
            # NOTE(review): folders with exactly 5 entries are skipped, presumably
            # because they are already processed; confirm against the producer.
            if len(entries) != 5:
                image_path = []
                for f in entries:
                    if any(ex in f for ex in extensions) and not any(exe in f for exe in extensions2):
                        print('Originals : ', f)
                        image_path.append(f)

                if len(image_path) < 2:
                    print('Skipping {}: expected 2 original images, found {}'.format(p, len(image_path)))
                else:
                    stem_a = os.path.splitext(image_path[0])[0]
                    stem_b = os.path.splitext(image_path[1])[0]
                    avgpoints = os.path.join(pair_dir, 'text', 'average.txt')

                    # Warped versions of each contributor at the 50% level.
                    warped_a_50 = os.path.join(pair_dir, 'warped_' + stem_a + '_50.png')
                    warped_b_50 = os.path.join(pair_dir, 'warped_' + stem_b + '_50.png')

                    # Morphs in both contributor orders; each is rewritten in place.
                    morph_ab_50 = os.path.join(pair_dir, 'M_' + stem_a + '_' + stem_b + '_UTW2022_B50_W50.png')
                    morph_ba_50 = os.path.join(pair_dir, 'M_' + stem_b + '_' + stem_a + '_UTW2022_B50_W50.png')

                    swap_all(morph_ab_50, warped_a_50, warped_b_50, avgpoints)
                    swap_all(morph_ba_50, warped_b_50, warped_a_50, avgpoints)
            tac()
    print('..................Finished.................')

