import numpy as np

def omega(base):
    """Return the complex exponentials exp(-2*pi*i*k/base) for k = 0 .. base-1.

    Parameters
    ----------
    base : int
        Number of points; also the divisor of the phase angle.

    Returns
    -------
    numpy.ndarray
        Complex array of length ``base``.
    """
    # linspace(0, base - 1, base) yields the exponents 0, 1, ..., base - 1.
    exponents = np.linspace(0, base - 1, num=base)
    phase = -2 * np.pi * 1j * exponents / base
    return np.exp(phase)

def _fft_pow2(samples):
    """Radix-2 decimation-in-time FFT of a 1-D complex array of power-of-two length."""
    n = samples.shape[0]
    if n == 1:
        # Only ever combined arithmetically by the caller, so a view is fine.
        return samples

    half = n // 2
    even = _fft_pow2(samples[::2])
    odd = _fft_pow2(samples[1::2])

    # Only the first N/2 twiddle factors are needed: w^(k + N/2) == -w^k,
    # so the upper half of the spectrum reuses the same product negated.
    twiddled = np.exp(-2j * np.pi * np.arange(half) / n) * odd
    return np.concatenate((even + twiddled, even - twiddled))


def fft(x):
    """Compute the discrete Fourier transform of ``x`` with a recursive radix-2 FFT.

    Parameters
    ----------
    x : array_like
        One-dimensional sequence of real or complex samples. Its length must
        be zero or a power of two.

    Returns
    -------
    numpy.ndarray
        New complex array with the same length as ``x``; the input is never
        returned or modified.

    Raises
    ------
    ValueError
        If ``x`` is not one-dimensional or its length is not a power of two.
    """
    samples = np.asarray(x, dtype=complex)
    if samples.ndim != 1:
        raise ValueError("fft expects a one-dimensional sequence")

    n = samples.shape[0]
    if n <= 1:
        # The transform of zero or one sample is the sample itself; copy so
        # the result never aliases the caller's array.
        return samples.copy()

    # A power of two has exactly one bit set, so n & (n - 1) is zero.
    if n & (n - 1):
        raise ValueError(
            f"fft requires a power-of-two number of samples, got {n}"
        )

    return _fft_pow2(samples)


def dft(inputs):
    """Compute the discrete Fourier transform of ``inputs`` by direct summation.

    Each output bin ``k`` is ``sum(inputs[n] * exp(-2j*pi*n*k/N))`` over all
    sample indices ``n``. The bins are evaluated one at a time, so the cost is
    quadratic in the number of samples while extra memory stays linear.

    Parameters
    ----------
    inputs : array_like
        Sequence of real or complex samples.

    Returns
    -------
    numpy.ndarray
        Complex array with one entry per input sample.
    """
    samples = np.asarray(inputs)
    N = len(samples)

    # np.complex_ was removed in NumPy 2.0; complex128 is the same dtype.
    res = np.zeros(N, dtype=np.complex128)

    n = np.arange(N)

    for k in range(N):
        # Named "twiddle" so the module-level omega() helper is not shadowed.
        twiddle = np.exp(-2j * np.pi * n * k / N)
        res[k] = np.sum(samples * twiddle)

    return res



if __name__ == '__main__':
    # Benchmark script: time fft() against dft() on one period of a cosine
    # sampled at 2**3 .. 2**19 points and print both durations.
    # perf_counter is a monotonic, high-resolution clock intended for
    # measuring short durations; time() is wall-clock and may be adjusted.
    from time import perf_counter

    print('================  TEST  ====================')
    for n in range(3, 20):
        size = 2**n
        x = np.cos(np.linspace(0, 2*np.pi, size))
        print(f'Number of samples = {size}')

        start = perf_counter()
        fft(x)
        t1 = perf_counter() - start

        # dft() evaluates every output bin with a full pass over the input,
        # so its run time grows much faster than fft()'s as size increases.
        start = perf_counter()
        dft(x)
        t2 = perf_counter() - start

        print(f"fft -> {t1} sec\ndft -> {t2} sec")

        print('--------------------------------------------\n')
