import cv2
import numpy as np
from matplotlib import pyplot as plt

def find_shift(img, obj):
    """Locate ``obj`` inside ``img`` using FFT-based cross-correlation.

    Both inputs are rescaled with ``1 - x / 255``, so a value of 255 maps
    to 0 and a value of 0 maps to 1 (presumably 8-bit grayscale images
    with dark shapes on a light background -- confirm against callers).
    ``obj`` is zero-padded at the bottom and right up to the shape of
    ``img``; the two arrays are then cross-correlated through the
    frequency domain.

    Args:
        img: 2-D array to search in.
        obj: 2-D array to search for; it must not exceed ``img`` along
            either axis, otherwise ``np.pad`` rejects the negative pad
            width.

    Returns:
        The ``(row, column)`` index, as produced by ``np.unravel_index``,
        of the largest value of the real part of the correlation surface.
    """
    scene = 1 - img / 255
    template = 1 - obj / 255

    # Grow the template to the scene's shape so both spectra line up.
    extra_rows = img.shape[0] - obj.shape[0]
    extra_cols = img.shape[1] - obj.shape[1]
    template = np.pad(template, ((0, extra_rows), (0, extra_cols)),
                      mode='constant')

    # Multiplying one spectrum by the conjugate of the other yields the
    # (circular) cross-correlation once transformed back.
    cross_spectrum = np.fft.fft2(scene) * np.fft.fft2(template).conj()
    correlation = np.fft.ifft2(cross_spectrum).real

    return np.unravel_index(correlation.argmax(), correlation.shape)
    

# Raw strings keep the backslashes literal: '\d' and '\o' are not valid
# escape sequences and trigger a warning in ordinary string literals.
# Flag 0 is cv2.IMREAD_GRAYSCALE, i.e. each image is loaded as one channel.
dom = cv2.imread(r'bin_images\dom.png', 0)
okno = cv2.imread(r'bin_images\okno.png', 0)

# cv2.imread signals an unreadable file by returning None rather than
# raising, so fail here with a clear message instead of inside find_shift.
if dom is None or okno is None:
    raise FileNotFoundError(
        r"could not read 'bin_images\dom.png' and/or 'bin_images\okno.png'"
    )

# Print the (row, column) position at which okno best matches dom.
print(find_shift(dom, okno))


##f = np.fft.fft2(img)
##fshift = np.fft.fftshift(f)
##magnitude_spectrum = 20*np.log(np.abs(fshift))
##
##plt.subplot(121),plt.imshow(img, cmap = 'gray')
##plt.title('Input Image'), plt.xticks([]), plt.yticks([])
##plt.subplot(122),plt.imshow(magnitude_spectrum, cmap = 'gray')
##plt.title('Magnitude Spectrum'), plt.xticks([]), plt.yticks([])
##plt.show()
