import numpy as np 
import cv2



class LogPolarTemplateMatching:
    """Estimate rotation, scale and translation between an image and a template.

    Both inputs are reduced to Canny edge maps. The magnitudes of their
    (high-pass filtered, fftshift-ed) Fourier spectra are resampled onto a
    log-polar grid, where a rotation becomes a row shift and a scale change
    becomes a column shift; phase correlation recovers those shifts. The
    template is then rotated/scaled back and a second phase correlation in
    image space recovers the translation.
    """

    def __init__(self):
        pass

    def find(self, im0_path, im1_path, *, test_angle=0, test_scale=2,
             canny_thresholds=(100, 200)):
        """Match the template at `im1_path` against the image at `im0_path`.

        Args:
            im0_path: path of the source image.
            im1_path: path of the template image.
            test_angle: rotation (degrees) artificially applied to the template
                before matching; defaults to the previously hard-coded value.
            test_scale: scale artificially applied to the template before
                matching; defaults to the previously hard-coded value.
            canny_thresholds: (low, high) hysteresis thresholds for cv2.Canny.

        Returns:
            (angle, scale, shift): angle in degrees wrapped to [-90, 90], the
            scale factor, and the (row, col) circular shift returned by
            `phaseCorr` between the source edge map and the de-rotated template.

        Raises:
            FileNotFoundError: if either image cannot be read.
        """
        canny_thresh1, canny_thresh2 = canny_thresholds

        im0GC = cv2.Canny(self._loadGray(im0_path), canny_thresh1, canny_thresh2)
        im0GC = im0GC.astype(np.float32) / 255

        tempC = cv2.Canny(self._loadGray(im1_path), canny_thresh1, canny_thresh2)
        # Apply a known transform to the template so the estimator has
        # something to recover.
        im1GCR = self.rotateScaleImage(tempC, test_angle, test_scale)
        im1GC = self._pasteTopLeft(im1GCR / 255, im0GC.shape)

        F0 = np.fft.fftshift(np.fft.fft2(im0GC))
        F1 = np.fft.fftshift(np.fft.fft2(im1GC))

        highpass = self.getHighpass(F0.shape)
        M0 = np.abs(F0) * highpass
        M1 = np.abs(F1) * highpass

        M0lp, log_base0 = self.logPolarTransform(M0)
        M1lp, _ = self.logPolarTransform(M1)

        t0, t1 = self.phaseCorr(M0lp, M1lp)
        # Log-polar rows cover a half turn, so one row is 180/rows degrees.
        angle = 180 / M0lp.shape[0] * t0
        # NOTE(review): t1 is an unwrapped index >= 0, so scale >= 1 here;
        # a shrink shows up as a wrapped (very large) value. Confirm intent.
        scale = log_base0 ** t1

        if angle < -90.0:
            angle += 180.0
        elif angle > 90.0:
            angle -= 180.0

        print(f'rows: {M0lp.shape[0]}\nlog base: {log_base0}')
        print(f'phase corr returns: {(t0,t1)} -> angle: {angle}, scale: {scale}')

        im2G = self.rotateScaleImage(im1GCR, angle, 1 / scale)
        # NOTE(review): im1GCR is already an edge map; running Canny on it
        # again yields doubled contours. Kept as-is — confirm intent.
        im2GC = self._pasteTopLeft(
            cv2.Canny(im2G, canny_thresh1, canny_thresh2) / 255, im0GC.shape)

        shift = self.phaseCorr(im0GC, im2GC)
        return angle, scale, shift

    @staticmethod
    def _loadGray(path):
        """Read the image at `path` and convert it from BGR to grayscale.

        Raises:
            FileNotFoundError: cv2.imread signals failure by returning None.
        """
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f'could not read image: {path}')
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    @staticmethod
    def _pasteTopLeft(img, shape):
        """Return a float32 zero canvas of `shape` with `img` in its top-left corner.

        `img` is cropped to the canvas when larger, instead of raising a
        broadcasting error.
        """
        canvas = np.zeros(shape, dtype=np.float32)
        rows = min(shape[0], img.shape[0])
        cols = min(shape[1], img.shape[1])
        canvas[:rows, :cols] = img[:rows, :cols]
        return canvas

    def rotateScaleImage(self, img, angle, scale):
        """Rotate `img` by `angle` degrees about its centre and scale by `scale`.

        The output canvas is sized to the bounding box of the transformed
        image and the content is re-centred in it, so nothing is clipped.

        Args:
            img: image array (rows, cols[, channels]).
            angle: rotation in degrees, as for cv2.getRotationMatrix2D.
            scale: isotropic scale factor.

        Returns:
            The warped image (same dtype as `img`).
        """
        rows, cols = img.shape[:2]
        # OpenCV expects the centre as (x, y) == (column, row).
        centre = (cols / 2, rows / 2)
        rotation_matrix = cv2.getRotationMatrix2D(centre, angle, scale)

        radians = np.radians(angle)
        abs_sin = np.abs(np.sin(radians))
        abs_cos = np.abs(np.cos(radians))

        new_width = max(1, int(scale * (rows * abs_sin + cols * abs_cos)))
        new_height = max(1, int(scale * (rows * abs_cos + cols * abs_sin)))

        # Move the rotation centre to the centre of the resized canvas;
        # otherwise the result stays anchored at the old centre and is clipped.
        rotation_matrix[0, 2] += new_width / 2 - centre[0]
        rotation_matrix[1, 2] += new_height / 2 - centre[1]

        return cv2.warpAffine(img, rotation_matrix, (new_width, new_height))

    def getHighpass(self, shape):
        """Return a high-pass weight array for an fftshift-ed spectrum of `shape`.

        x = outer(cos(u), cos(v)), with u and v sampled on [-pi/2, pi/2], is 1
        at the centre (DC after fftshift) and ~0 at the borders. Therefore
        (1 - x) * (2 - x) is 0 at DC and rises to 2 at the highest frequencies.
        """
        rows, cols = shape
        row_filter = np.cos(np.linspace(-np.pi / 2, np.pi / 2, rows))
        # cos on both axes: sin is odd and made the filter asymmetric and
        # non-zero at DC.
        col_filter = np.cos(np.linspace(-np.pi / 2, np.pi / 2, cols))
        x = np.outer(row_filter, col_filter)
        return (1.0 - x) * (2.0 - x)

    def logPolarTransform(self, src):
        """Resample 2-D `src` onto a log-polar grid centred at (rows/2, cols/2).

        Output row i samples angle -pi * i / rows (a half turn). Output column
        j samples radius log_base ** (j - 1), where log_base = d ** (1 / cols)
        and d is the centre-to-corner distance.

        Returns:
            (log-polar image with the same shape as `src`, log_base)
        """
        rows, cols = src.shape
        centre_row, centre_col = rows / 2, cols / 2
        angles, radii = rows, cols

        d = np.hypot(centre_row, centre_col)
        log_base = 10.0 ** (np.log10(d) / radii)

        theta = -np.linspace(0, np.pi, num=angles, endpoint=False,
                             dtype=np.float32)[:, np.newaxis]
        radius = (log_base ** (np.arange(radii, dtype=np.float32) - 1))[np.newaxis, :]

        row_coords = (np.sin(theta) * radius + centre_row).astype(np.float32)
        col_coords = (np.cos(theta) * radius + centre_col).astype(np.float32)

        # cv2.remap takes the x (column) map first and the y (row) map second.
        return cv2.remap(src, col_coords, row_coords, cv2.INTER_LINEAR), log_base

    def phaseCorr(self, a, b):
        """Return the circular shift (row, col) of the phase-correlation peak.

        The cross-power spectrum is normalised to unit magnitude, so the
        inverse FFT ideally is a single peak at the shift s for which
        a(x) == b(x - s). Indices are returned unwrapped in [0, shape).
        """
        def phase(x):
            n = np.abs(x)
            n[n == 0] = 1  # avoid 0/0; those bins stay 0
            return x / n

        cross_power = np.fft.fft2(a) * np.conj(np.fft.fft2(b))
        r = np.fft.ifft2(phase(cross_power))

        x0, x1 = np.unravel_index(np.argmax(np.real(r)), r.shape)
        return x0, x1

# Example image/template pairs for the demo below.
SRC1 = 'dom.png'
TEMP1 = 'lava_rimsa.png'

SRC2 = 'src.jpg'
TEMP2 = 'temp.jpg'

if __name__ == '__main__':
    # Run the demo only when executed as a script, not on import.
    x = LogPolarTemplateMatching().find(SRC1, TEMP1)
    print(x)