from target_data_loader import *
from file_loader import *
import trimesh
from threading import Thread
import traceback
#import multiprocessing

# Patient ids processed by the script run from 1 to N_PATIENTS (inclusive).
N_PATIENTS = 1151

# Input locations: the target-data CSV and the STL mesh root share one dataset folder.
_DATASET_ROOT = r'C:\Users\daska\Documents\AIN\MAGISTER\Diplomovka\Dataset'
fpath = _DATASET_ROOT + r'\Target Data_2019-12-05.csv'
stls_root_path = _DATASET_ROOT + r'\STLs\STLs'

# DICOM input root and output folders all live under one data drive folder.
# (An older copy of the DICOM data was at <dataset root>\Data_with_biopsy\Prostate-MRI-US-Biopsy.)
_DP_DATA = r'D:\DP_data'
fpath2 = _DP_DATA + r'\Prostate-MRI-US-Biopsy'
dest_targets = _DP_DATA + r'\targets'
dest_prostates = _DP_DATA + r'\prostate_targets'

class DataDoNotExist(Exception):
    """Signal that the target, prostate or image data for a patient is unavailable."""


class PatientLoader:
    """Rasterise a patient's STL meshes onto the slice grids of their DICOM series.

    For every image series (keyed by ``uid``) returned by
    ``DicomLoader.read_3D_images``, two float arrays of shape
    ``(n_slices, GRID_SIZE, GRID_SIZE)`` with values 0/1 are produced:

    * a target mask: union of all finding meshes whose path contains ``uid``;
    * a prostate mask: the single prostate mesh whose path contains ``uid``.
    """

    # Side length of the per-slice pixel grid the meshes are sampled on.
    # NOTE(review): assumes every slice is 256x256 pixels - confirm against ds.Rows/Columns.
    GRID_SIZE = 256

    def __init__(self, target_file_path, stl_path, input_data_path, patients):
        """Create the data loaders.

        Args:
            target_file_path: target data file given to ``TargetDataLoader``
                (the script passes a CSV).
            stl_path: STL root directory given to ``TargetDataLoader``.
            input_data_path: DICOM root directory given to ``DicomLoader``.
            patients: iterable of patient ids processed by ``write_data``.
        """
        self.tdl = TargetDataLoader(target_file_path, stl_path)
        self.dl = DicomLoader(input_data_path)
        self.patients = patients

    @staticmethod
    def _pixel_grid_to_patient(position, spacing, orientation, size=256):
        """Map every pixel of a ``size`` x ``size`` slice to patient coordinates.

        Uses the DICOM image-plane equation
        ``P = S + X * col_spacing * c + Y * row_spacing * r``, where ``S`` is
        Image Position (Patient), ``X``/``Y`` are the row/column direction
        cosines of Image Orientation (Patient), and Pixel Spacing is
        ``[row_spacing, col_spacing]``.

        Returns:
            ``(size * size, 3)`` float array in row-major order (index
            ``r * size + c``), so a per-point result reshaped to
            ``(size, size)`` is indexed ``[r, c]``.
        """
        origin = np.asarray(position, dtype=float)
        cosines = np.asarray(orientation, dtype=float)
        row_spacing, col_spacing = (float(v) for v in spacing)
        # Moving one column advances along the row direction by the column spacing,
        # moving one row advances along the column direction by the row spacing.
        col_step = cosines[:3] * col_spacing
        row_step = cosines[3:6] * row_spacing
        rows, cols = np.meshgrid(np.arange(size), np.arange(size), indexing='ij')
        return (origin
                + np.outer(cols.ravel(), col_step)
                + np.outer(rows.ravel(), row_step))

    @staticmethod
    def _load_mesh(path):
        """Load an STL mesh and fill holes before it is used for containment tests."""
        mesh = trimesh.load_mesh(path)
        trimesh.repair.fill_holes(mesh)
        return mesh

    def read_data_for_patient(self, patient_id):
        """Build target and prostate masks for every image series of a patient.

        Returns:
            dict mapping series uid -> ``(target, prostate)`` float arrays of
            shape ``(n_slices, GRID_SIZE, GRID_SIZE)`` with values 0/1.

        Raises:
            DataDoNotExist: no target data, no prostate STLs, or the patient's
                images are missing (``FileNotFoundError`` from the loader).
            ValueError: a series does not match exactly one prostate mesh.
        """
        results = self.tdl.get_for_patient(patient_id)
        if results is None:
            raise DataDoNotExist
        prostates = self.tdl.get_prostate_stl_for_patient(patient_id)
        if prostates is None:
            raise DataDoNotExist
        try:
            imgs_all = self.dl.read_3D_images(patient_id)
        except FileNotFoundError as err:
            raise DataDoNotExist from err

        size = self.GRID_SIZE
        data = {}
        for uid, imgs in imgs_all.items():
            # Meshes do not depend on the slice: load and repair each one once per series.
            target_meshes = [self._load_mesh(f) for f in results if uid in f]
            prostate_paths = [p for p in prostates if uid in p]
            if len(prostate_paths) != 1:
                raise ValueError(
                    f'expected exactly one prostate mesh for series {uid}, '
                    f'found {len(prostate_paths)}')
            prostate_mesh = self._load_mesh(prostate_paths[0])

            target = np.zeros((imgs.size, size, size))
            prostate = np.zeros((imgs.size, size, size))
            for i in range(imgs.size):
                ds = imgs[i]
                points = self._pixel_grid_to_patient(
                    ds[0x20, 0x32].value,  # Image Position (Patient)
                    ds[0x28, 0x30].value,  # Pixel Spacing
                    ds[0x20, 0x37].value,  # Image Orientation (Patient)
                    size,
                )

                # Target mask: a pixel is set if it lies inside any finding mesh.
                hit = np.zeros(size * size, dtype=bool)
                for mesh in target_meshes:
                    hit |= mesh.contains(points)
                target[i] = hit.reshape(size, size)

                prostate[i] = prostate_mesh.contains(points).reshape(size, size)
            data[uid] = (target, prostate)

        return data

    def write_data(self):
        """Export the masks of every patient in ``self.patients`` as ``.npz`` files.

        Target masks are written under ``dest_targets`` and prostate masks
        under ``dest_prostates`` (``np.savez`` appends the ``.npz`` suffix).
        Patients with missing data are collected in ``bad``; any other failure
        is printed with its traceback and collected in ``other``. Both lists
        are printed at the end.
        """
        bad = []
        other = []
        print('start')
        for patient in self.patients:
            try:
                dic = self.read_data_for_patient(patient)
                for k, (target, prostate) in dic.items():
                    np.savez(dest_targets + f'\\target_{patient}_{k}', x=target)
                    np.savez(dest_prostates + f'\\prostate_{patient}_{k}', x=prostate)
            except DataDoNotExist:
                bad.append(patient)
                print('data_do_not_Exist' + str(patient))
            except Exception:
                # Best effort: report and continue with the next patient, without
                # swallowing SystemExit/KeyboardInterrupt as a bare ``except:`` would.
                traceback.print_exc()
                other.append(patient)
                print('something went wrong' + str(patient))
            print(patient)
        print('pocet neulozenych:', len(bad))  # "number not saved"
        print(bad)
        print('ostatne:', len(other))  # "others"
        print(other)

                                    


# Split patient ids 1..N_PATIENTS into consecutive chunks and give each chunk its
# own PatientLoader running in a worker thread.
# NOTE(review): the per-patient work is CPU-bound, so threads are limited by the GIL;
# a process pool would parallelise better but requires an `if __name__ == "__main__"`
# guard on Windows.
PATIENTS_PER_WORKER = 120

# Stepping the range by the chunk size never yields an empty trailing chunk
# (the previous counter-based loop did when N_PATIENTS was a multiple of 120).
rngs = [
    list(range(start, min(start + PATIENTS_PER_WORKER, N_PATIENTS + 1)))
    for start in range(1, N_PATIENTS + 1, PATIENTS_PER_WORKER)
]

workers = []
for chunk in rngs:
    dl = PatientLoader(fpath, stls_root_path, fpath2, chunk)
    tr = Thread(target=dl.write_data)
    print('a')
    workers.append(tr)
for t in workers:
    t.start()
    print('b')
for t in workers:
    t.join()
print('koniec')  # "end"

    




