import numpy as np
import cv2

def phase(x):
    """Return the element-wise unit-magnitude phase factor ``x / |x|``.

    Entries whose magnitude is exactly zero have no defined phase; they are
    divided by 1 instead of 0, so they come out as 0 rather than NaN.

    Args:
        x: array-like or scalar (typically a complex FFT spectrum).

    Returns:
        ``x`` divided element-wise by its magnitude (zeros left at 0), with
        the same shape as ``x``.
    """
    # asarray lets scalars and plain lists work too; the original in-place
    # assignment only worked on ndarrays.
    x = np.asarray(x)
    magnitude = np.abs(x)
    # Replace zero magnitudes by 1 to avoid 0/0 -> NaN.
    magnitude = np.where(magnitude == 0, 1, magnitude)
    return x / magnitude


# Load the scene image and the template to search for, as 8-bit grayscale.
a0 = cv2.imread('dom2.png', cv2.IMREAD_GRAYSCALE)
b0 = cv2.imread('okno3.png', cv2.IMREAD_GRAYSCALE)
# cv2.imread reports a missing/unreadable file by returning None instead of
# raising; without these checks the failure surfaces later as a TypeError.
if a0 is None:
    raise FileNotFoundError("could not read image 'dom2.png'")
if b0 is None:
    raise FileNotFoundError("could not read image 'okno3.png'")
# The template is pasted into a scene-sized buffer below, so it must fit.
if b0.shape[0] > a0.shape[0] or b0.shape[1] > a0.shape[1]:
    raise ValueError(
        f"template {b0.shape} is larger than the scene {a0.shape}"
    )

# Invert intensities (presumably so dark structures on a light background
# become large values that dominate the correlation -- confirm intent).
a = 255 - a0
# Zero-pad the inverted template to the scene size so both share the same
# FFT grid; the template occupies the top-left corner.
b = np.zeros(a.shape)
b[0:b0.shape[0], 0:b0.shape[1]] = 255 - b0

# Circular cross-correlation of scene and template via the correlation
# theorem: corr = IFFT( FFT(a) * conj(FFT(b)) ).
fa, fb = np.fft.fft2(a), np.fft.fft2(b)
fd = fa * np.conj(fb)
d = np.fft.ifft2(fd)

# Candidate 1: (row, col) offset with the largest correlation value.
shift = np.unravel_index(d.real.argmax(), d.shape)
print(shift)

# Candidate 2: offset whose correlation is closest to the template's own
# energy sum(b^2), i.e. the value an exact copy of the template would give.
maximum = d.real.max()
expected = (b * b).sum()
print(maximum, expected)
error = abs(d.real - expected)
shift2 = np.unravel_index(error.argmin(), d.shape)
print(shift2)

print("phase correlation")
# Phase correlation: normalise each spectrum to unit magnitude with phase()
# so only phase differences remain in the cross-power spectrum; its inverse
# FFT is expected to peak at the template's offset in the scene.
fa3 = phase(fa)
fb3 = phase(np.conj(fb))
fd3 = fa3 * fb3
d3 = np.fft.ifft2(fd3)

# Candidate 3: (row, col) of the phase-correlation maximum.
shift3 = np.unravel_index(d3.real.argmax(), d3.shape)
print(shift3)

# Draw each candidate match on a colour copy of the scene, save and show it.
# OpenCV treats colour tuples as BGR, so convert with GRAY2BGR (for a
# single-channel image this replicates the channel exactly like GRAY2RGB).
res = cv2.cvtColor(a0, cv2.COLOR_GRAY2BGR)
tmpl_h, tmpl_w = b0.shape[:2]
matches = (
    (shift, (0, 0, 255)),   # red: plain cross-correlation maximum
    (shift2, (0, 255, 0)),  # green: correlation closest to template energy
    (shift3, (255, 0, 0)),  # blue: phase-correlation maximum
)
for (row, col), color in matches:
    # np.unravel_index yields NumPy integers; some OpenCV versions reject
    # them as point coordinates, so pass plain Python ints.
    top_left = (int(col), int(row))
    bottom_right = (int(col) + tmpl_w, int(row) + tmpl_h)
    cv2.rectangle(res, top_left, bottom_right, color, 2)

# imwrite returns False on failure instead of raising; report it.
if not cv2.imwrite("found2.png", res):
    print("warning: could not write found2.png")
cv2.imshow('found', res)
cv2.imshow('sought', b0)
cv2.waitKey(0)

cv2.destroyAllWindows()

