import cv2
import matplotlib.pyplot as plt
import numpy as np
from imutils.face_utils import FaceAligner
from misc2 import *


def gbp_filter(target, sigma1, sigma2, power=30, reverse=False):
    """Apply a steep radial Gaussian-like band-pass mask to a centred spectrum.

    The mask keeps values whose distance from the centre bin lies between the
    two cut-offs: it is the product of a low-pass edge at ``sigma1`` and a
    high-pass edge at ``sigma2``, so ``sigma1 > sigma2`` is required for a
    non-empty pass band. This is an approximation / custom filter, not a true
    Gaussian.

    Args:
        target: array of shape (H, W) or (H, W, ...); the same 2-D mask is
            applied to every trailing channel. Presumably an fftshift-ed
            spectrum -- the mask is centred on (H // 2, W // 2), which is where
            ``np.fft.fftshift`` places the DC bin for both odd and even sizes.
        sigma1: outer cut-off (low-pass edge) as a fraction of the image
            half-diagonal.
        sigma2: inner cut-off (high-pass edge) as a fraction of the image
            half-diagonal.
        power: exponent applied to the normalised squared radius; larger
            values make the pass/stop transition steeper.
        reverse: if True, invert the mask so the band between sigma2 and
            sigma1 is cut out instead of kept.

    Returns:
        Float array with the same shape as ``target``.
    """
    tic()
    rows, cols = target.shape[0], target.shape[1]
    half_diag = np.sqrt((cols // 2) ** 2 + (rows // 2) ** 2)
    sigma1 = sigma1 * half_diag
    sigma2 = sigma2 * half_diag

    # Offsets from the DC bin, in np.fft.fftshift order for odd and even sizes
    # alike, so the mask is symmetric about DC.
    dy = np.arange(rows) - rows // 2
    dx = np.arange(cols) - cols // 2
    r2 = dx[np.newaxis, :] ** 2 + dy[:, np.newaxis] ** 2

    # exp(-g**p) is exactly 1 where g == 0, so each edge already lies in [0, 1].
    low_pass = np.exp(-(r2 / sigma1 ** 2) ** power)
    high_pass = 1.0 - np.exp(-(r2 / sigma2 ** 2) ** power)
    g_filter = low_pass * high_pass
    if reverse:
        g_filter = 1.0 - g_filter

    # Broadcast the 2-D mask over any trailing channel axes.
    mask = g_filter.reshape(g_filter.shape + (1,) * (target.ndim - 2))
    filtered_magnitude = target * mask

    tac()
    return filtered_magnitude


def ghp_filter(target, sigma, power=30, hp=True):
    """Apply a steep radial Gaussian-like high-pass (or low-pass) mask.

    The mask is ``exp(-(r**2 / s**2) ** power)`` for low-pass, or one minus
    that for high-pass. Here ``r`` is the distance from the centre bin and
    ``s = sigma * half-diagonal``. This is an approximation / custom filter,
    not a true Gaussian.

    Args:
        target: array of shape (H, W) or (H, W, ...); the same 2-D mask is
            applied to every trailing channel. ``analyze`` passes the
            fftshift-ed magnitude from ``do_fft2``. The mask is centred on
            (H // 2, W // 2), which is where ``np.fft.fftshift`` places the DC
            bin for both odd and even sizes.
        sigma: cut-off radius as a fraction of the image half-diagonal.
        power: exponent applied to the normalised squared radius; larger
            values make the pass/stop transition steeper.
        hp: True for high-pass (DC removed), False for low-pass (DC kept).

    Returns:
        Float array with the same shape as ``target``.
    """
    print(sigma)
    rows, cols = target.shape[0], target.shape[1]
    sigma = sigma * np.sqrt((cols // 2) ** 2 + (rows // 2) ** 2)

    # Offsets from the DC bin, in np.fft.fftshift order for odd and even sizes
    # alike, so the mask is symmetric about DC.
    dy = np.arange(rows) - rows // 2
    dx = np.arange(cols) - cols // 2
    grid = (dx[np.newaxis, :] ** 2 + dy[:, np.newaxis] ** 2) / sigma ** 2

    # exp(-g**p) is exactly 1 at the centre (g == 0), so the mask lies in [0, 1].
    g_filter = np.exp(-grid ** power)
    if hp:
        g_filter = 1.0 - g_filter

    # Broadcast the 2-D mask over any trailing channel axes.
    mask = g_filter.reshape(g_filter.shape + (1,) * (target.ndim - 2))
    return target * mask


def block(img, n_blocks=1):
    """Split an image into an ``n_blocks`` x ``n_blocks`` grid of equal tiles.

    The image is converted to float32 first. If a side is not a multiple of
    ``n_blocks``, it is zero-padded on the top/left so the tiles divide evenly.

    Args:
        img: array-like of shape (H, W) (grey) or (H, W, C) (multi-channel).
        n_blocks: number of tiles per side; 1 means "no tiling".

    Returns:
        float32 array of shape (n_blocks**2, bh, bw) for grey input or
        (n_blocks**2, bh, bw, C) for multi-channel input. Tiles are ordered
        row-major: left to right, then top to bottom. With ``n_blocks == 1``
        the float32 image is returned as is.

    Raises:
        ValueError: if tiling is requested for an array that is not 2-D or 3-D.
    """
    img = np.asarray(img, dtype=np.float32)
    print("Image before processing shape and type", img.shape, img.dtype)

    def _tile_plane(plane, n):
        """Zero-pad a 2-D plane top/left to multiples of n and cut it into n*n tiles."""
        pad_rows = (-plane.shape[0]) % n
        pad_cols = (-plane.shape[1]) % n
        if pad_rows or pad_cols:
            plane = np.pad(plane, ((pad_rows, 0), (pad_cols, 0)), mode='constant')
        bh, bw = plane.shape[0] // n, plane.shape[1] // n
        # (n, bh, n, bw) -> (n, n, bh, bw): tile (i, j) is
        # plane[i*bh:(i+1)*bh, j*bw:(j+1)*bw].
        tiles = plane.reshape(n, bh, n, bw).swapaxes(1, 2)
        return tiles.reshape(n * n, bh, bw)

    if n_blocks == 1:
        return img
    if img.ndim == 2:
        return _tile_plane(img, n_blocks)
    if img.ndim == 3:
        # Tile each channel independently, then put channels back on the last axis.
        return np.stack([_tile_plane(img[:, :, c], n_blocks) for c in range(img.shape[2])], axis=3)
    raise ValueError("block expects a 2-D or 3-D image, got shape %s" % (img.shape,))


def unblock(array, orig_shape=None):
    """Merge k*k tiles (as produced by ``block``) back into one image.

    Accepted inputs:
      * (k*k, bh, bw) grey tiles   -> (H, W) image
      * (k*k, bh, bw, 3) colour tiles -> (H, W, 3) image
      * an already whole (H, W) or (H, W, 3) image, returned unchanged.
    A 3-D array whose last axis is 3 is treated as a whole colour image;
    any other 3-D array is treated as grey tiles.

    After tiling is undone, top/left padding is cropped so the spatial size
    equals ``orig_shape[:2]``.

    Args:
        array: tile stack or whole image (NumPy array).
        orig_shape: shape of the image before blocking. Defaults to the
            shape of the module-level ``root_img``, looked up only when tiles
            actually need to be merged.

    Returns:
        The merged image, with the same dtype as ``array``.

    Raises:
        ValueError: if ``array`` has a rank that matches none of the above.
    """
    is_colour = array.shape[-1] == 3
    tiled_ndim = 4 if is_colour else 3

    if array.ndim == tiled_ndim:
        if orig_shape is None:
            orig_shape = root_img.shape
        target_h, target_w = orig_shape[0], orig_shape[1]
        k = int(np.sqrt(array.shape[0]))
        bh, bw = array.shape[1], array.shape[2]
        channels = array.shape[3:]
        # (k, k, bh, bw, ...) -> (k, bh, k, bw, ...) so tile (i, j) lands at
        # rows i*bh.., cols j*bw.. of the merged image.
        grid = array.reshape((k, k, bh, bw) + channels).swapaxes(1, 2)
        orig_image = grid.reshape((k * bh, k * bw) + channels)
        # block() pads on the top/left, so drop the leading rows/columns.
        orig_image = orig_image[orig_image.shape[0] - target_h:, orig_image.shape[1] - target_w:]
    elif array.ndim == tiled_ndim - 1:
        orig_image = array  # not tiled: nothing to merge
    else:
        raise ValueError("unblock cannot interpret an array of shape %s" % (array.shape,))
    return np.asarray(orig_image).astype(array.dtype)


def undo_fft2(magnitude, phase, norm):
    """Rebuild a spatial-domain image from a centred magnitude/phase spectrum.

    This is the inverse of ``do_fft2``. Each plane is recombined as
    ``magnitude * exp(1j * phase)``, un-shifted with ``np.fft.ifftshift``,
    inverse-transformed with ``np.fft.ifft2`` and reduced to its absolute value.

    Args:
        magnitude: spectrum magnitudes. An array whose last axis has length 3
            is treated as three channels stacked on axis 2; anything else is
            treated as a single plane.
        phase: spectrum phases, same shape as ``magnitude``.
        norm: compared with ``== True``.
            Single plane: if equal, the result of ``normalize`` is returned;
            otherwise the values are rounded and cast to uint8.
            Three channels: if equal, each channel goes through ``normalize``;
            otherwise each channel is passed to ``normalize2`` and the float
            magnitudes are stacked.

    Returns:
        The reconstructed single plane, or ``np.dstack`` of the three channels.
    """
    normalise_output = norm == True  # noqa: E712 -- equality test kept on purpose

    def _back_to_space(mag_plane, phase_plane):
        # Recombine the polar components, undo the centring shift, inverse FFT.
        centred = mag_plane * np.exp(1j * phase_plane)
        return np.abs(np.fft.ifft2(np.fft.ifftshift(centred)))

    if magnitude.shape[-1] != 3:
        plane = _back_to_space(magnitude, phase)
        if normalise_output:
            return normalize(plane)  # this is for the noise image
        # Reconstructed image: nearest integer, stored as 8-bit.
        return np.asarray(np.round(plane), dtype=np.uint8)

    planes = []
    for ch in range(3):
        plane = _back_to_space(magnitude[:, :, ch], phase[:, :, ch])
        if normalise_output:
            plane = normalize(plane)
        else:
            # NOTE(review): normalize2's return value is discarded, so the
            # stacked channel is whatever normalize2 leaves in `plane`. Confirm
            # whether it works in place, or whether `plane = normalize2(plane)`
            # (or the uint8 rounding used for single planes) was intended.
            normalize2(plane)
        planes.append(plane)
    return np.dstack(planes)


def do_fft2(image):
    """Return the centred magnitude and phase spectra of an image.

    A 2-D FFT is taken over the first two (spatial) axes and shifted with
    ``np.fft.fftshift`` over those same axes, so the DC bin sits at
    (H // 2, W // 2). Any trailing axis (e.g. colour channels) is transformed
    independently per channel.

    Args:
        image: array of shape (H, W) or (H, W, C).

    Returns:
        (magnitude, phase): two float arrays with the same shape as ``image``,
        holding ``np.abs`` and ``np.angle`` of the shifted spectrum.
    """
    spatial_axes = (0, 1)
    # Restrict the shift to the spatial axes so the channel axis is left alone.
    spectrum = np.fft.fftshift(np.fft.fft2(image, axes=spatial_axes), axes=spatial_axes)
    magnitude = np.abs(spectrum)
    phase = np.angle(spectrum)
    return magnitude, phase


def analyze(image, filter_sizes=None, out_file=None):
    """High-pass filter a 3-channel image at several cut-offs and save a figure.

    For every cut-off ``f`` in ``filter_sizes``:
      1. the spectrum from ``do_fft2(image)`` is masked with
         ``ghp_filter(..., sigma=f, power=30)``;
      2. the result is transformed back with ``undo_fft2(..., norm=True)``.

    The saved figure is a 4-row grid with one column per cut-off:
      * row 0: log(1 + masked magnitude) of channel 0;
      * rows 1-3: the reconstructed channels 0, 1 and 2 (inverted grey colormap).
    One extra grid column is allocated and left empty.

    Args:
        image: (H, W, 3) array. The script at the bottom of this file passes
            the output of ``cv2.cvtColor(..., cv2.COLOR_BGR2YCrCb)``; its
            channel order is Y, Cr, Cb.
        filter_sizes: cut-off fractions passed to ``ghp_filter`` as ``sigma``.
            Defaults to the module-level ``filter_size_factor``.
        out_file: PNG path to write. Defaults to
            ``path + stem + '/' + prefix + stem + '.png'``, where ``stem`` is
            ``im`` without its extension. ``path``, ``im`` and ``prefix`` are
            the module-level names set by the script loop.
    """
    if filter_sizes is None:
        filter_sizes = filter_size_factor

    orig_magnitude, phase = do_fft2(image)

    noise_mags = []
    noise_recons = []
    for f in filter_sizes:
        noise_mag = ghp_filter(orig_magnitude, sigma=f, power=30)
        # norm=True: undo_fft2 passes each channel through normalize()
        # (it does NOT round to integers -- that is the norm=False path).
        noise_recon = undo_fft2(noise_mag, phase, norm=True)
        noise_mags.append(noise_mag)
        noise_recons.append(noise_recon)

    fig = plt.figure()
    gs = gridspec.GridSpec(4, len(filter_sizes) + 1, figure=fig)
    gs.update(wspace=0.1, hspace=0)  # layout does not change per column; set once

    for col, (noise_mag, noise_recon) in enumerate(zip(noise_mags, noise_recons)):
        # Row 0: log-scaled masked magnitude spectrum of channel 0.
        ax = fig.add_subplot(gs[0, col])
        ax.imshow(np.log(1 + noise_mag[:, :, 0]), cmap="gray")
        ax.axis('off')
        # Rows 1-3: reconstructed high-frequency content of channels 0, 1, 2
        # (Y, Cr, Cb for an OpenCV YCrCb input).
        for channel in range(3):
            ax = fig.add_subplot(gs[channel + 1, col])
            ax.imshow(noise_recon[:, :, channel], cmap="gray_r")
            ax.axis('off')

    if out_file is None:
        stem = os.path.splitext(im)[0]
        out_file = path + stem + '/' + prefix + stem + ".png"
    fig.savefig(out_file, dpi=1500)
    plt.close(fig)





# Directories whose image files are processed (each entry ends with '/').
paths = ["./ajax_morphs/dd/"]

# Cut-off fractions passed to ghp_filter as `sigma` (fraction of the image
# half-diagonal); read as a module-level global by analyze().
filter_size_factor = [0.7, 0.6, 0.4, 0.3]

# Output figure name prefix; an image whose figure already exists is skipped.
prefix = "fig_fft_p30_nocrop_nothres_"

for path in paths:
    for im in os.listdir(path):
        if not os.path.isfile(path + im):
            continue
        stem = os.path.splitext(im)[0]
        os.makedirs(path + stem, exist_ok=True)
        if prefix + stem + ".png" in os.listdir(path + stem + "/"):
            continue  # already processed
        print(im)
        bgr = cv2.imread(path + im)
        if bgr is None:
            # cv2.imread returns None for files it cannot decode; skip them
            # instead of letting cvtColor abort the whole batch.
            print("Skipping unreadable file: %s" % im)
            continue

        # `root_img`, `path`, `im` and `prefix` stay module-level globals:
        # analyze() and unblock() read them.
        root_img = cv2.cvtColor(bgr, cv2.COLOR_BGR2YCrCb)

        print("image shape %s x %s" % (root_img.shape[0], root_img.shape[1]))
        print("Processing: %s" % im)
        analyze(root_img)
