import torch
import os
import json
import numpy as np
import cv2
from file_loader import DicomLoader

class Dataset(torch.utils.data.Dataset):
    """Slice-level dataset built from 3D DICOM volumes and ``.npz`` targets.

    For every ``target_<id>.npz`` file in ``target_rootpath`` the matching
    3D image is read via ``DicomLoader.read_one_3D_image(<id>)``. Each slice
    ``dicom[i].pixel_array`` is resized to ``image_size`` x ``image_size`` and
    stored in ``self.X`` as a ``(1, image_size, image_size)`` float32 tensor;
    ``target[i]`` (from the archive's ``'x'`` array) is stored in ``self.Y``
    as a float32 tensor.

    Slices whose target sums to zero are kept or dropped per patient: a
    single ``np.random.rand()`` draw per patient keeps all of that patient's
    empty slices with probability ``empty_slice_keep_prob``. Seed
    ``np.random`` beforehand for reproducible sampling.

    All data is loaded eagerly in ``__init__``.
    """

    def __init__(self, dicom_rootpath, target_rootpath, image_size=256,
                 empty_slice_keep_prob=0.7):
        """Load every patient's slices and targets into memory.

        Args:
            dicom_rootpath: root path handed to ``DicomLoader``.
            target_rootpath: directory holding ``target_<id>.npz`` files.
            image_size: side length each slice is resized to (default 256).
            empty_slice_keep_prob: probability that a patient's slices with
                an all-zero target are kept (default 0.7).
        """
        self.X = []
        self.Y = []

        self.dicom = DicomLoader(dicom_rootpath)
        self.target_rootpath = target_rootpath
        # Sorted for a reproducible loading order; only .npz files are
        # considered so stray files (e.g. hidden OS files) don't break the
        # id parsing below.
        target_files = sorted(
            f for f in os.listdir(target_rootpath) if f.endswith('.npz'))
        num_of_patients = len(target_files)

        for counter, filename in enumerate(target_files):
            print(f'loading patients dataset '
                  f'{counter/num_of_patients*100:.2f}%')
            idx = self._patient_id_from_filename(filename)

            dicom = self.dicom.read_one_3D_image(idx)
            # os.path.join rather than a hard-coded '\\' so this also works
            # outside Windows.
            target_path = os.path.join(self.target_rootpath,
                                       f'target_{idx}.npz')
            # np.load on an .npz returns an NpzFile holding an open file;
            # read 'x' into memory and close the archive.
            with np.load(target_path) as archive:
                target = archive['x']
            # One draw per patient decides whether its empty slices are kept.
            keep_empty = np.random.rand() < empty_slice_keep_prob

            # NOTE(review): assumes dicom.size is the slice count and target
            # is indexed per slice along its first axis — confirm against
            # DicomLoader.
            for i in range(dicom.size):
                if np.sum(target[i]) == 0 and not keep_empty:
                    continue
                self.X.append(
                    self._slice_to_tensor(dicom[i].pixel_array, image_size))
                self.Y.append(torch.from_numpy(
                    np.array(target[i], dtype='float32')))

    @staticmethod
    def _patient_id_from_filename(filename):
        """Return the integer after the first '_' of the part before the first '.'.

        e.g. ``'target_12.npz'`` -> ``12``.
        """
        stem = filename.split('.')[0]
        return int(stem[stem.index('_') + 1:])

    @staticmethod
    def _slice_to_tensor(pixels, image_size):
        """Resize a 2D slice and return it as a (1, image_size, image_size) float32 tensor."""
        resized = cv2.resize(pixels, (image_size, image_size))
        return torch.from_numpy(np.array(resized, dtype='float32')).reshape(
            1, image_size, image_size)

    def __len__(self):
        """Return the number of stored (image, target) slice pairs."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(X[idx], Y[idx])`` moved to CUDA if available, else CPU.

        NOTE(review): moving to CUDA here can fail inside DataLoader worker
        processes (num_workers > 0 with fork) — confirm how this is loaded.
        """
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return self.X[idx].to(device), self.Y[idx].to(device)
