from OneCellNN import OneCellNN
from DataSetNeedleman import DataSetNeedleman
from torch.utils.data import  DataLoader
import torch
from torch.utils.data import random_split

# Load the dataset and split it 80/20 into training and evaluation subsets.
# The eval length is derived as the remainder so the two lengths always sum
# to len(dtst), which random_split requires; the fixed seed keeps the split
# reproducible between runs.
dtst = DataSetNeedleman('NeedlemaFixed', num_files=1000)
lngth = len(dtst)
num_train = round(0.8 * lngth)
num_eval = lngth - num_train
dataset_train, dataset_eval = random_split(
    dtst,
    (num_train, num_eval),
    generator=torch.Generator().manual_seed(42),
)

dataloader_train = DataLoader(dataset_train, batch_size=100, shuffle=True)

# Move the model to the target device before constructing the optimizer so
# the optimizer is built over the parameters in their final location.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = OneCellNN()
model.to(device)
loss_func = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for e in range(2):
    model.train()
    for i, batch in enumerate(dataloader_train):
        x, y = batch
        # The batch must live on the same device as the model; without this
        # the forward pass fails whenever CUDA is available.
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        output = model(x)
        loss = loss_func(output, y)
        loss.backward()
        optimizer.step()

        # Report the training loss every 100 steps.
        if i % 100 == 0:
            print(f"Loss at epoch: {e} step {i}: {loss.item()}")
