import os
import re

import cv2
import numpy as np
import pydicom
from pydicom.data import get_testdata_files


class DicomLoader:
    """Load DICOM slices and series for a patient from a fixed directory layout.

    Patient data is looked up under ``root_path`` in folders named
    ``Prostate-MRI-US-Biopsy-<patient_id:04d>``.  Inside a patient folder, the
    first sub-directory (in sorted order) whose *name* contains ``'WO'`` is
    taken as the series; chains of single-child directories below it are
    followed down to the directory that actually holds the slice files.
    Individual slices are named ``1-<img_id:02d>.dcm``.
    """

    # Target (rows, cols) that read_one_slice resizes every slice to.
    SLICE_SHAPE = (256, 256)

    def __init__(self, root_path):
        self.root_path = root_path

    @staticmethod
    def _natural_key(name):
        """Return a sort key that orders embedded numbers numerically ('1-2' < '1-10')."""
        # re.split with a capturing group alternates text (even indices) and
        # digit runs (odd indices), so str/int positions line up between any
        # two names and the resulting lists always compare without TypeError.
        return [int(tok) if i % 2 else tok
                for i, tok in enumerate(re.split(r'([0-9]+)', name))]

    def _find_series_dir(self, patient_id):
        """Return the directory holding the slice files of *patient_id*'s 'WO' series.

        Raises:
            FileNotFoundError: if the patient folder is missing or contains no
                sub-directory whose name includes 'WO'.
        """
        patient_dir = os.path.join(
            self.root_path, f'Prostate-MRI-US-Biopsy-{patient_id:04d}')
        # Sorted so the chosen folder does not depend on os.listdir order.
        for folder in sorted(os.listdir(patient_dir)):
            series_dir = os.path.join(patient_dir, folder)
            # Match on the folder name only: testing the joined path would
            # also match any 'WO' inside root_path (e.g. '/data/WORK').
            if 'WO' not in folder or not os.path.isdir(series_dir):
                continue
            # Descend through wrapper directories that have exactly one child,
            # but never into a file (a single-slice series stops here).
            while True:
                entries = os.listdir(series_dir)
                if len(entries) != 1:
                    break
                child = os.path.join(series_dir, entries[0])
                if not os.path.isdir(child):
                    break
                series_dir = child
            return series_dir
        raise FileNotFoundError(
            f"No sub-directory containing 'WO' found in {patient_dir}")

    def read_one_slice_from_path(self, path, img_id, resize=None):
        """Read slice ``1-<img_id:02d>.dcm`` from directory *path*.

        Returns a copy of the raw pixel array, without resizing or dtype
        conversion. *resize* is accepted for signature compatibility but is
        currently unused.
        """
        ds = pydicom.dcmread(os.path.join(path, f'1-{img_id:02d}.dcm'))
        return np.array(ds.pixel_array)

    def read_one_slice(self, patient_id, img_id, resize=None):
        """Read one slice of a patient's 'WO' series as a float32 array.

        The slice is resized with ``cv2.resize`` to ``SLICE_SHAPE`` when its
        shape differs. *resize* is accepted for signature compatibility but is
        currently unused.

        Raises:
            FileNotFoundError: if no 'WO' series folder exists for the patient
                (pydicom also raises it if the slice file is missing).
        """
        series_dir = self._find_series_dir(patient_id)
        ds = pydicom.dcmread(os.path.join(series_dir, f'1-{img_id:02d}.dcm'))
        pixels = ds.pixel_array
        if pixels.shape != self.SLICE_SHAPE:
            # cv2.resize takes dsize as (width, height), i.e. (cols, rows).
            pixels = cv2.resize(pixels, self.SLICE_SHAPE[::-1])
        return np.array(pixels, dtype='float32')

    def read_one_3D_image(self, patient_id, resize=None):
        """Read every file of a patient's 'WO' series, ordered by filename.

        Files are sorted with a natural (numeric-aware) key so that
        ``1-2.dcm`` precedes ``1-10.dcm``; this assumes the filename index
        reflects slice order, as in read_one_slice. *resize* is currently
        unused.

        Returns:
            ``np.array`` built from the list of ``pydicom`` Dataset objects.

        Raises:
            FileNotFoundError: if no 'WO' series folder exists for the patient.
        """
        series_dir = self._find_series_dir(patient_id)
        slice_names = sorted(os.listdir(series_dir), key=self._natural_key)
        datasets = [pydicom.dcmread(os.path.join(series_dir, name))
                    for name in slice_names]
        # NOTE(review): this wraps Dataset objects, not pixel data (an earlier
        # version appended ds.pixel_array). How np.array treats Dataset objects
        # depends on the pydicom/numpy versions -- confirm what callers expect.
        return np.array(datasets)
        

