import numpy as np

# Demo: recover a circular 2-D shift from the peak of the cross-correlation
# between an array and its shifted copy, computed (1) by brute force with
# np.roll and (2) via the FFT correlation theorem.

a = np.array([
    [1, 2, 3, 4, 5],
    [1, 3, 5, 7, 9],
    [0, 2, 4, 6, 8],
    [5, 4, 3, 2, 1],
    [3, 3, 3, 2, 2],
])

# (shift in columns, shift in rows) -- note this is the reverse of NumPy's
# (row, col) index order, hence axis=(1, 0) below.
ashift = (2, 3)

# b is `a` circularly shifted by ashift[0] columns (axis 1) and ashift[1]
# rows (axis 0).
b = np.roll(a, ashift, axis=(1, 0))

print(b)

nrows, ncols = a.shape


def peak_to_shift(corr):
    """Convert the peak of a circular cross-correlation into the applied shift.

    ``corr[i, j]`` is the score for rolling the shifted array by ``i`` rows
    and ``j`` columns back onto the reference.  The best score is at the roll
    that *undoes* the shift, i.e. at ``(-row_shift % nrows,
    -col_shift % ncols)``.  This function negates that peak index and
    reorders it to match ``ashift``'s (columns, rows) convention.

    Parameters
    ----------
    corr : 2-D array of real scores (argmax is taken over it).

    Returns
    -------
    tuple[int, int]
        ``(shift in columns, shift in rows)``.
    """
    nr, nc = corr.shape
    row, col = np.unravel_index(np.argmax(corr), corr.shape)
    return ((-int(col)) % nc, (-int(row)) % nr)


c = np.zeros(a.shape)
for i in range(nrows):
    for j in range(ncols):
        # Score for rolling b by j columns and i rows.  It is maximal when
        # the roll exactly undoes ashift.  (Equivalently, the sum of squared
        # differences, np.sum((a - rolled) ** 2), is zero there.)
        c[i, j] = np.sum(a * np.roll(b, (j, i), axis=(1, 0)))
print(c)

# Raw peak index, in (row, col) order.  This is the *inverse* of ashift.
# It only coincides with ashift here because 2 + 3 == 5.
cshift = np.unravel_index(np.argmax(c), c.shape)
# Peak converted back into the applied shift; equals ashift for any shift.
cshift_recovered = peak_to_shift(c)
print(cshift_recovered, ashift)

# Correlation theorem: ifft2(F(a) * conj(F(b)))[i, j]
#   = sum_n a[n] * b[n - (i, j)]  (circular, since b is real),
# which is exactly c[i, j] above.
fa = np.fft.fft2(a)
fb = np.fft.fft2(b)
fd = np.multiply(fa, np.conj(fb))
d = np.fft.ifft2(fd)
# d is complex with a (near-)zero imaginary part.  Its real part equals c up
# to floating-point rounding.
print(d)

# Raw peak index of the FFT result (same caveat as cshift above).
dshift = np.unravel_index(np.argmax(np.real(d)), d.shape)
dshift_recovered = peak_to_shift(np.real(d))
print(dshift_recovered, ashift)
        


