import cv2
import matplotlib.pyplot as plt
import numpy as np
import shutil
"""ela either in whole image or blocks. """
from misc import *



# block() and unblock() below are helpers that this script does not call.
def block(img, n_blocks=1):
    """Divide an image into an n_blocks x n_blocks grid of equally sized blocks.

    The image is converted to float32 and zero-padded on its top and left sides so
    that H and W become multiples of ``n_blocks`` (precise divisions). Blocks are
    ordered row-major: left to right, then top to bottom.

    Args:
        img: 2-D grayscale image (H, W) or channel-last image (H, W, C).
        n_blocks: number of blocks along each axis; 1 means "no blocking".

    Returns:
        * (n_blocks**2, h, w) float32 array for 2-D input,
        * (n_blocks**2, h, w, C) float32 array for (H, W, C) input (each channel is
          blocked independently and restacked on axis 3),
        * the float32 image itself when ``n_blocks == 1``.
    """
    img = np.asarray(img, dtype=np.float32)
    print("Image before processing shape and type", img.shape, img.dtype)

    if n_blocks == 1:  # No blocking requested.
        return img
    if img.ndim == 2:
        return _split_plane(img, n_blocks)
    # Channel-last image: decide on ndim (not on "last dim == 3") so 1- and
    # 4-channel images work and 2-D images that are 3 pixels wide are not
    # mistaken for colour images.
    planes = [_split_plane(img[..., c], n_blocks) for c in range(img.shape[-1])]
    return np.stack(planes, axis=3)


def _split_plane(plane, n_blocks):
    """Zero-pad a 2-D plane on the top/left and cut it into n_blocks x n_blocks blocks.

    Returns an array of shape (n_blocks*n_blocks, H'/n_blocks, W'/n_blocks), where
    H', W' are the padded height/width, with blocks in row-major order.
    """
    height, width = plane.shape
    # (-x) % n is the smallest padding that makes x a multiple of n (0 if it already is).
    pad_h = (-height) % n_blocks
    pad_w = (-width) % n_blocks
    if pad_h or pad_w:
        plane = np.pad(plane, ((pad_h, 0), (pad_w, 0)), mode='constant')
    block_h = plane.shape[0] // n_blocks
    block_w = plane.shape[1] // n_blocks
    # (n, bh, n, bw) -> (n, n, bh, bw) -> (n*n, bh, bw)
    grid = plane.reshape(n_blocks, block_h, n_blocks, block_w).swapaxes(1, 2)
    return grid.reshape(n_blocks * n_blocks, block_h, block_w)


def unblock(array, orig_shape=None):
    """Merge the n x n blocks produced by ``block()`` back into one image.

    Args:
        array: either a stack of blocks -- (n*n, h, w) for grayscale or
            (n*n, h, w, C) for channel-last images -- or an already whole image,
            (H, W) or (H, W, 3), which is returned unchanged (as a copy).
        orig_shape: shape of the image before ``block()`` padded it. The top/left
            padding is cropped so the result is ``orig_shape[:2]`` spatially.
            Defaults to the module-level ``root_img.shape`` (the original,
            global-dependent behaviour).

    Returns:
        The merged image, with the same dtype as ``array``.

    Raises:
        ValueError: if ``array`` has an unsupported rank or the block count is
            not a perfect square.
    """
    array = np.asarray(array)
    if array.ndim == 2 or (array.ndim == 3 and array.shape[-1] == 3):
        # Already a whole image (gray, or 3-channel colour): nothing to merge.
        return array.astype(array.dtype)
    if array.ndim not in (3, 4):
        raise ValueError("unblock expects blocks shaped (n*n, h, w[, C]), got %s"
                         % (array.shape,))

    n_blocks = int(round(np.sqrt(array.shape[0])))
    if n_blocks * n_blocks != array.shape[0]:
        raise ValueError("number of blocks (%d) is not a perfect square" % array.shape[0])
    if orig_shape is None:
        # Backward compatibility: the driver script keeps the image being
        # processed in the module-level ``root_img``.
        orig_shape = root_img.shape

    block_h, block_w = array.shape[1], array.shape[2]
    channels = array.shape[3:]  # () for gray blocks, (C,) for colour blocks
    # (n*n, bh, bw, *C) -> (n, bh, n, bw, *C) -> (n*bh, n*bw, *C)
    grid = array.reshape((n_blocks, n_blocks, block_h, block_w) + channels)
    merged = grid.swapaxes(1, 2).reshape((n_blocks * block_h, n_blocks * block_w) + channels)

    # block() padded the top/left sides, so keep the bottom-right original-size region.
    top = merged.shape[0] - orig_shape[0]
    left = merged.shape[1] - orig_shape[1]
    return merged[top:, left:].astype(array.dtype)


# Error level analysis
def ela(image):
    """Run error level analysis (ELA) on ``image`` and save a figure plus measures.

    For every JPEG quality in the module-level ``compress_factors`` the image is
    written to ``<path><stem>/<stem>_cf_<quality>.jpg``, read back, and compared
    with the original. Every pixel whose value changed is set to 255 in the
    residual ("ELA") image. The per-quality values returned by ``get_measures``
    are shown as subplot titles and saved as text.

    Relies on module-level names set by the driver loop: ``path``, ``im_path``,
    ``prefix`` and ``compress_factors``. ``get_measures``, ``gridspec`` and ``os``
    are not imported here; presumably they come from ``from misc import *``
    (NOTE(review): confirm).

    Files written into ``<path><stem>/``:
      * ``<stem>_cf_<quality>.jpg`` -- the re-compressed images (left on disk),
      * ``<prefix><stem>.png`` -- the figure,
      * ``measures_ela_<stem>.txt`` -- one row of measures per quality.

    Raises:
        OSError: if a re-compressed JPEG cannot be written or read back (instead
            of silently comparing against a stale file from an earlier run).
    """
    stem = os.path.splitext(im_path)[0]
    out_dir = path + stem + '/'
    # Only (H, W, 3) images are re-read in colour; everything else is re-read gray.
    is_color = image.ndim == 3 and image.shape[-1] == 3
    read_flag = cv2.IMREAD_COLOR if is_color else cv2.IMREAD_GRAYSCALE

    list_of_ela = []
    list_of_measures = []
    for cf in compress_factors:
        jpg_path = out_dir + stem + "_cf_%s.jpg" % cf
        if not cv2.imwrite(jpg_path, image, [int(cv2.IMWRITE_JPEG_QUALITY), cf]):
            raise OSError("could not write %s" % jpg_path)
        comp_image = cv2.imread(jpg_path, read_flag)
        if comp_image is None:
            raise OSError("could not read back %s" % jpg_path)

        # Residual image, i.e. the absolute compression error. Widen to a signed
        # type first: subtracting unsigned arrays wraps around (3 - 5 == 254).
        residual = np.abs(image.astype(np.int32) - comp_image.astype(np.int32))

        # Map every changed pixel to 255; tweak this mapping to change the visualization.
        residual[residual > 0] = 255
        list_of_ela.append(residual.astype(np.uint8))

        # Features extracted for classification; not needed for the visualization.
        list_of_measures.append(get_measures(image, comp_image, thres=2))

    fig = plt.figure()
    fig.suptitle(im_path, size=3)
    # One column per quality; the extra (+1) column is left empty, as before.
    gs = gridspec.GridSpec(1, len(compress_factors) + 1, figure=fig)

    title_fmt = ("Com. factor %s\nHamming=%.4f\nThres_diff=%.4f\nMean_abs_dif=%.4f"
                 "\npsnr=%.4f\nssim=%.4f")
    for i, (cf, residual, measures) in enumerate(
            zip(compress_factors, list_of_ela, list_of_measures)):
        ax = fig.add_subplot(gs[0, i])
        ax.set_title(title_fmt % ((cf,) + tuple(measures[:5])), size=3)
        if residual.ndim == 3:
            # OpenCV stores BGR, matplotlib expects RGB. 2-D (grayscale) residuals
            # must not go through cvtColor, which rejects single-channel input.
            residual = cv2.cvtColor(residual, cv2.COLOR_BGR2RGB)
        ax.imshow(residual, cmap="gray")
        ax.axis('off')

    fig.savefig(out_dir + prefix + stem + ".png", dpi=1500)
    np.savetxt(out_dir + "measures_ela_" + stem + ".txt", list_of_measures)
    plt.close(fig)

######################################################################################################################

######################################################################################################################

# Initialize dlib's face detector and then create the facial landmark predictor
#hog_detector = dlib.get_frontal_face_detector()
# Enter the path of dlib's shape predictor model
#extractor = dlib.shape_predictor('./shape_predictor_68_face_landmarks_GTX.dat')


# Input folders: every file directly inside each folder is analysed.
#["./temp/test/"]
paths = ["./temp3/ajax/"]

# JPEG qualities used for re-compression; ela() reads this module-level list.
compress_factors = [100,99,98,97,96,95,94,93]#,92,91,90,89,88,87,86,85,84,83,82,81,80]
#compress_factors = [40,39,38,37,36,35,34,33,32,31,30]

# File-name prefix of the ELA figure. ela() reads it as a module-level name, and
# the loop below uses it to skip images whose figure already exists.
prefix = "fig_ela_"

for path in paths:
    # sorted() makes the processing order reproducible across runs/filesystems.
    for im_path in sorted(os.listdir(path)):
        if not os.path.isfile(path + im_path):
            continue
        # Skip files that look like this script's own outputs.
        if any(z in im_path for z in ('fig', 'measures')):
            continue
        stem = os.path.splitext(im_path)[0]
        out_dir = path + stem + "/"
        if os.path.exists(out_dir + prefix + stem + ".png"):
            continue  # Already processed.

        # root_img is kept module-level on purpose: unblock() falls back to it.
        root_img = cv2.imread(path + im_path, cv2.IMREAD_UNCHANGED)
        if root_img is None:
            # cv2.imread returns None for unreadable/non-image files.
            print("Skipping %s: not a readable image" % im_path)
            continue

        print("Processing %s ..." % im_path)
        os.makedirs(out_dir, exist_ok=True)
        shutil.copy(path + im_path, out_dir + im_path)
        print(root_img.dtype)

        # Optional preprocessing experiments (disabled):
        #root_img = cv2.rotate(root_img,cv2.ROTATE_90_CLOCKWISE)
        #root_img = cv2.cvtColor(root_img,cv2.COLOR_BGR2YCrCb)
        #root_img = root_img[:,:,0]
        # root_img = cv2.medianBlur(root_img,5)
        # root_img = sharpen(root_img,1,0.9)
        #root_img = np.asarray(root_img,dtype="uint8")
        #root_img = np.asarray(root_img,dtype="int16")
        #ela(crop_face_only_bigger(root_img))
        #root_img = sharpen(root_img,3,0.7)
        ela(root_img)



