from model import UNet
from dataset import Dataset
from torch.utils.data import  DataLoader
from torch.nn import CrossEntropyLoss, MSELoss, NLLLoss, BCEWithLogitsLoss
import torch
import torch.optim as optim
import matplotlib.pyplot as plt


# --- Data, split, loaders and model -----------------------------------------
dicom_files = 'E:\\DP_data\\Prostate-MRI-US-Biopsy'
target_files = 'E:\\DP_data\\targets'

data = Dataset(dicom_files, target_files)
n = len(data)

# 80/20 train/test split with a fixed seed so the split is reproducible.
# The test size is the remainder, so the two lengths always sum to len(data)
# as random_split requires.
n_train = round(0.8 * n)
data_train, data_test = torch.utils.data.random_split(
    data, [n_train, n - n_train],
    generator=torch.Generator().manual_seed(42))

batch_size = 10
d_train = DataLoader(data_train, batch_size=batch_size, shuffle=True)
# Evaluation does not benefit from shuffling; keep batch order deterministic
# so the recorded per-batch validation losses are reproducible.
d_test = DataLoader(data_test, batch_size=batch_size, shuffle=False)

model = UNet()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)


# --- Training ---------------------------------------------------------------
print('TRAIN')
num_epochs = 5
log_every = 1  # print the loss every `log_every` steps (1 = every step)
ce_loss = CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
train_losses = []  # one entry per optimisation step, across all epochs

for epoch in range(num_epochs):
    model.train()
    for i, (x, y) in enumerate(d_train):
        # The model lives on `device`; its inputs must be on the same device,
        # otherwise the forward pass fails when CUDA is available.
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        out = model(x)
        loss = ce_loss(out, y)
        loss.backward()
        optimizer.step()

        # Read the scalar once (each .item() forces a device sync).
        loss_value = loss.item()
        train_losses.append(loss_value)
        if i % log_every == 0:
            print("Loss at epoch: {} step {}: {}".format(epoch, i, loss_value))

print('---DONE---')

# --- Evaluation (runs once, after training) ----------------------------------
print('TEST')
model.eval()
val_losses = []  # one entry per test batch
correct = 0      # number of correctly predicted target elements
total = 0        # number of target elements counted toward accuracy
with torch.no_grad():
    for x, y in d_test:
        # Inputs must be on the same device as the model.
        x, y = x.to(device), y.to(device)

        out = model(x)
        loss = ce_loss(out, y)
        val_losses.append(loss.item())

        # Accuracy is only defined for class-index targets. CrossEntropyLoss
        # takes input of shape (N, C, ...) with the class dimension at dim=1,
        # so the predicted class per element is argmax over dim=1.
        # NOTE(review): assumes y holds class indices when it is an integer
        # tensor; probability-style (float) targets are skipped.
        if not torch.is_floating_point(y):
            correct += (out.argmax(dim=1) == y).sum().item()
            total += y.numel()

if val_losses:
    print("Val loss: {}".format(sum(val_losses) / len(val_losses)))
else:
    print("Val loss: n/a (empty test set)")
if total:
    print("Val acc: {}".format(correct / total))


# --- Loss plot --------------------------------------------------------------
plt.figure(figsize=(10, 5))
plt.title('Training and Validation Loss')
plt.plot(train_losses, label='train')
if val_losses:
    # Validation ran once, after training, so its per-batch losses have no
    # training-iteration index. Plotting them on the iteration axis would be
    # misleading; show their mean as a horizontal reference line instead.
    plt.axhline(sum(val_losses) / len(val_losses), color='C1',
                linestyle='--', label='val (mean)')
plt.xlabel('iterations')
plt.ylabel('Loss')
plt.legend()
plt.show()
