import torch
from torch.nn import Conv2d, MaxPool2d, ConvTranspose2d
from torch.nn.functional import relu

class UNet(torch.nn.Module):
    """U-Net style encoder/decoder with skip connections and a sigmoid head.

    Layout (channel widths for the default configuration):

    * Encoder: four levels of two 3x3 convolutions (padding 1, so each
      convolution preserves H and W) followed by 2x2 max pooling.
      Widths are 64 -> 128 -> 256 -> 512.
    * Bottleneck: 512 -> 1024 -> 512.
    * Decoder: four levels. Each level upsamples 2x with a 2x2 stride-2
      transposed convolution, concatenates the matching encoder output
      (skip tensor first, along the channel axis), then applies two 3x3
      convolutions.
    * Head: a 1x1 convolution to ``out_channels`` followed by a sigmoid.

    Attribute names (``conv1`` ... ``conv19``, ``upconv1`` ... ``upconv4``,
    ``maxpool1`` ... ``maxpool4``) are the ``state_dict`` keys, so keep them
    stable for checkpoint compatibility.

    Args:
        in_channels: Number of channels of the input image (default 1).
        out_channels: Number of output maps produced by the 1x1 head
            (default 1).
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        def conv3x3(cin, cout):
            # Padding 1 keeps the spatial size unchanged.
            return Conv2d(in_channels=cin, out_channels=cout,
                          kernel_size=3, padding=1)

        def upconv2x(channels):
            # Doubles H and W while keeping the channel count.
            return ConvTranspose2d(in_channels=channels, out_channels=channels,
                                   kernel_size=2, stride=2)

        # NOTE: modules are created in the same order as before so that the
        # random parameter initialisation (and thus seeded runs) is unchanged.

        # Encoder level 1
        self.conv1 = conv3x3(in_channels, 64)
        self.conv2 = conv3x3(64, 64)

        # Encoder level 2
        self.maxpool1 = MaxPool2d(2)
        self.conv3 = conv3x3(64, 128)
        self.conv4 = conv3x3(128, 128)

        # Encoder level 3
        self.maxpool2 = MaxPool2d(2)
        self.conv5 = conv3x3(128, 256)
        self.conv6 = conv3x3(256, 256)

        # Encoder level 4
        self.maxpool3 = MaxPool2d(2)
        self.conv7 = conv3x3(256, 512)
        self.conv8 = conv3x3(512, 512)

        # Bottleneck
        self.maxpool4 = MaxPool2d(2)
        self.conv9 = conv3x3(512, 1024)
        self.conv10 = conv3x3(1024, 512)

        # Decoder: each first conv sees skip channels + upsampled channels.
        self.upconv1 = upconv2x(512)
        self.conv11 = conv3x3(1024, 512)   # 512 (skip x4) + 512 (upsampled)
        self.conv12 = conv3x3(512, 256)

        self.upconv2 = upconv2x(256)
        self.conv13 = conv3x3(512, 256)    # 256 (skip x3) + 256 (upsampled)
        self.conv14 = conv3x3(256, 128)

        self.upconv3 = upconv2x(128)
        self.conv15 = conv3x3(256, 128)    # 128 (skip x2) + 128 (upsampled)
        self.conv16 = conv3x3(128, 64)

        self.upconv4 = upconv2x(64)
        self.conv17 = conv3x3(128, 64)     # 64 (skip x1) + 64 (upsampled)
        self.conv18 = conv3x3(64, 64)

        # 1x1 projection to the output maps.
        self.conv19 = Conv2d(in_channels=64, out_channels=out_channels,
                             kernel_size=1, padding=0)

    @staticmethod
    def _double_conv(x, first, second):
        """Apply ``first`` then ``second``, each followed by a ReLU."""
        return relu(second(relu(first(x))))

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape (B, in_channels, H, W). H and W must be
                divisible by 16. The four 2x poolings floor odd sizes, and
                the upsampled maps would then not match the skip tensors
                they are concatenated with.

        Returns:
            Per-pixel sigmoid outputs in [0, 1]. Shape is (B, H, W) when the
            head has a single output channel (the original contract), and
            (B, out_channels, H, W) otherwise.
        """
        # Encoder. Keep each level's output for the skip connections.
        # No clone is needed because nothing below modifies these in place.
        x1 = self._double_conv(x, self.conv1, self.conv2)
        x2 = self._double_conv(self.maxpool1(x1), self.conv3, self.conv4)
        x3 = self._double_conv(self.maxpool2(x2), self.conv5, self.conv6)
        x4 = self._double_conv(self.maxpool3(x3), self.conv7, self.conv8)

        # Bottleneck
        x = self._double_conv(self.maxpool4(x4), self.conv9, self.conv10)

        # Decoder: upsample, prepend the skip tensor on the channel axis, convolve.
        x = self._double_conv(torch.cat((x4, self.upconv1(x)), dim=1),
                              self.conv11, self.conv12)
        x = self._double_conv(torch.cat((x3, self.upconv2(x)), dim=1),
                              self.conv13, self.conv14)
        x = self._double_conv(torch.cat((x2, self.upconv3(x)), dim=1),
                              self.conv15, self.conv16)
        x = self._double_conv(torch.cat((x1, self.upconv4(x)), dim=1),
                              self.conv17, self.conv18)

        x = torch.sigmoid(self.conv19(x))

        # Drop the singleton channel axis instead of reshaping to a fixed
        # 256x256. A fixed reshape fails for other image sizes, or silently
        # mixes batch items into one image when the element count happens
        # to be a multiple of 256*256.
        if self.conv19.out_channels == 1:
            x = x.squeeze(1)
        return x






