from Utils import *
from prompts import *
from singletons.TestMerger import TestMerger
from singletons.models import *
from singletons.Extractor import extractor_singleton
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
import faiss

import os
import subprocess
import re


# First Java type declaration (class, interface, enum or record) in a source
# file: the type keyword is group 1 and the type name is the "name" group.
# Compiled once at import time instead of on every call of the node.
_JAVA_TYPE_DECLARATION_PATTERN = re.compile(
    r'''
    ^\s*                                                    # optional leading whitespace
    (?:(?:public|protected|private)\s+)?                    # optional access modifier
    (?:(?:abstract|final|static|sealed|non-sealed|strictfp)\s+)*  # optional non-access modifiers
    (class|interface|enum|record)\s+                        # the type keyword
    (?P<name>[A-Z][A-Za-z0-9_]*)                            # capture: type name
    (?:\s*<[^{]*?>)?                                        # optional type parameters, e.g. <K, V extends Comparable<V>>
    (?:\s*\([^)]*\))?                                       # optional record header, e.g. (int x, int y)
    (?:\s+extends\s+[A-Za-z0-9_\.<>?,\s]+?)?                # optional "extends" clause (a list for interfaces)
    (?:\s+implements\s+[A-Za-z0-9_\.<>?,\s]+?)?             # optional "implements" clause
    (?:\s+permits\s+[A-Za-z0-9_\.<>?,\s]+?)?                # optional "permits" clause (sealed types)
    \s*\{                                                   # opening brace
    ''',
    re.VERBOSE | re.MULTILINE,
)


def file_processor_node(state: BasicAgentState) -> BasicAgentState:
    """Collect the names of the Java types declared in the extracted source files.

    Source files are pulled from ``extractor_singleton`` in batches until it
    returns an empty batch. For each file, the name of the first class,
    interface, enum or record declaration found is appended to
    ``state['class_names']``, which is reset at the start. Files without a
    recognizable declaration are skipped.

    Args:
        state: Graph state; only ``state['class_names']`` is written.

    Returns:
        The same ``state`` object.
    """
    state['class_names'] = []
    batch_size = 1

    while True:
        # read in the next batch of files containing source code
        batch = extractor_singleton.extract(batch_size)
        if not batch:
            break

        for file_content in batch:
            match = _JAVA_TYPE_DECLARATION_PATTERN.search(file_content)
            if match is not None:
                state['class_names'].append(match.group("name"))

    # NOTE: The summarization / RAG indexing pipeline below is disabled. It wrote
    # per-type summaries to ragGraph/input/ and saved the FAISS index
    # "vector_store_index" that stub_ideas_node loads, so that index must already
    # exist on disk while this code stays commented out. Kept for reference:
    #
    # system_prompt = SystemMessage(content=create_summarization_system_prompt())
    # vector_store = FAISS(
    #     embedding_function=embeddings,
    #     index=faiss.IndexFlatL2(len(embeddings.embed_query("hello world"))),
    #     docstore=InMemoryDocstore(),
    #     index_to_docstore_id={},
    # )
    # ... per file, after extracting class_name and object_type = match.group(1):
    #     # create summary of the source code
    #     human_prompt = HumanMessage(content=create_summarization_human_prompt(file_content))
    #     all_prompts = [system_prompt, human_prompt]
    #     result = summarization_model.invoke(all_prompts)
    #     ast = extract_ast_structure(file_content)
    #     # save generated summary to file
    #     # todo: is the path always unique?
    #     with open(f'ragGraph/input/{class_name}_{object_type}.txt', 'w', encoding='UTF-8') as file:
    #         file.write(f'{result.content}\n\n## Abstract syntax tree of {class_name} {object_type}\n{ast}')
    #
    #     # todo: Test if RAG is sufficient, if not, replace with Redis
    #     # create vector store for traditional RAG
    #     vector_store.add_documents([create_document_from_text(file_content, class_name, object_type)])
    #     vector_store.add_documents([create_document_from_text(result.content, class_name, object_type)])
    #
    #     write_log('file_processor', f'objet:{class_name}_{object_type}\n\tSummary:{result.content}\n\tAST:{ast}')
    #
    # # save vector store for later
    # vector_store.save_local("vector_store_index")

    print("=== Summary Done ===")
    return state

def grapher_node(state: BasicAgentState) -> BasicAgentState:
    """Graph-building step of the pipeline.

    GraphRAG initialization/indexing is currently commented out, so this node
    only prints a progress banner and returns ``state`` unchanged.
    """
    # Disabled GraphRAG bootstrap + indexing, kept for reference:
    # if not os.path.isdir("ragGraph"):
    #     os.makedirs("ragGraph")
    #     args = "python -m graphrag init --root ./ragGraph".split()
    #     subprocess.run(args, shell=True)
    #
    # args = "python -m graphrag index --root ./ragGraph".split()
    # subprocess.run(args, shell=True)
    #
    # NOTE(review): if re-enabled, drop shell=True — on POSIX, shell=True with an
    # argument list executes only the first element ("python").
    # TODO: add check for successful graph indexing
    banner = "=== Graph created ==="
    print(banner)
    return state

# TODO: Relying on the LLM to wrap every idea in <test>...</test> tags is fragile;
#  consider generating ideas iteratively, passing previously generated tests in as context.
def stub_ideas_node(state: BasicAgentState) -> BasicAgentState:
    state['stub_ideas'] = []
    # get next class to generate tests for
    class_name = state['class_names'].pop(0)
    state['current_class_name'] = class_name

    # using GraphRag, get all entities related to the class under test
    args = f"python -m graphrag query --root ./ragGraph -m global -q \"Please tell me, which entities are related to the {class_name} entity. The {class_name} entity itself should be included.\"".split() # todo: improve query
    result = str(subprocess.run(args, shell=True, capture_output=True).stdout)

    vector_store = FAISS.load_local("vector_store_index", embeddings, allow_dangerous_deserialization=True)
    retriever = vector_store.as_retriever(search_type="similarity", k=3)
    examples = retriever.invoke(f"Please tell me about these entities: {result}") #todo: improve query
    context = "\n\n---\n\n".join([doc.page_content for doc in examples])

    state['ctx'] = context
    # print(f'Context generated in stub_ideas_node:\n {context}')

    all_prompts = [HumanMessage(create_stub_ideas_prompt(class_name, context))]
    result = stub_ideas_model.invoke(all_prompts)

    for test in result.content.split('</test>'):
        state['stub_ideas'].append(test[test.find('<test>'):])

    write_log('stub_ideas_log', result.content)
    print("=== Stubbing Ideas Done ===")

    return state


def test_writer_node(state: BasicAgentState) -> BasicAgentState:
    """Turn every stub idea into a full test via ``test_writer_model``.

    ``state['tests']`` is emptied in place and then refilled with one model
    response per entry of ``state['stub_ideas']``, in the same order. Each
    prompt combines the idea with ``state['ctx']``. All generated tests are
    written to the ``test_writer_log`` log afterwards.
    """
    generated = state['tests']
    generated.clear()

    writer_instructions = SystemMessage(content=create_test_writer_node_system_prompt())

    for idea in state['stub_ideas']:
        request = HumanMessage(content=create_test_writer_node_human_prompt(idea, state['ctx']))
        response = test_writer_model.invoke([writer_instructions, request])
        generated.append(response.content)

    write_log('test_writer_log', '\n'.join(generated))
    print("=== Writing tests Done ===")

    return state


def validator_node(state: BasicAgentState) -> BasicAgentState:
    """Ask ``validator_model`` for a review of every generated test.

    ``state['validator_suggestions']`` is always emptied in place first. The
    review itself only happens when ``validator_accepted(state)`` is truthy
    (that helper is defined elsewhere). In that case, one suggestion per test
    in ``state['tests']`` is appended in test order, each exchange is logged to
    ``validator_log``, and ``state['validations_steps_taken']`` is incremented
    once.
    """
    suggestions = state['validator_suggestions']
    suggestions.clear()

    if not validator_accepted(state):
        return state

    reviewer_instructions = SystemMessage(content=create_validator_node_system_prompt())

    for candidate in state['tests']:
        review_request = HumanMessage(content=create_validator_node_human_prompt(candidate))
        review = validator_model.invoke([reviewer_instructions, review_request])
        suggestions.append(review.content)
        write_log('validator_log', f'Test:\n{candidate}\nValidator:{review.content}')

    state['validations_steps_taken'] += 1
    print("=== Validating Tests Done ===")
    return state


def test_rewriter_node(state: BasicAgentState) -> BasicAgentState:
    """Rewrite every generated test using the validator's suggestion for it.

    ``state['tests'][i]`` is paired with ``state['validator_suggestions'][i]``
    (validator_node appends one suggestion per test, in test order). The
    rewritten tests replace the old ones in place, keeping their order, and
    the consumed suggestions are cleared. A test that has no matching
    suggestion is kept unchanged. Each rewrite is logged to
    ``test_rewriter_log``.
    """
    system_prompt = SystemMessage(content=create_test_rewriter_node_system_prompt())
    tests = state['tests']
    suggestions = state['validator_suggestions']

    rewritten_tests = []
    for index, test_to_rewrite in enumerate(tests):
        if index >= len(suggestions):
            # Nothing to act on for this test: keep it rather than crash.
            rewritten_tests.append(test_to_rewrite)
            continue

        human_prompt = HumanMessage(
            content=create_test_rewriter_node_human_prompt(test_to_rewrite, suggestions[index])
        )
        # NOTE(review): this uses test_writer_model although a
        # test_rewriter_model is also in scope — confirm which one is intended.
        result = test_writer_model.invoke([system_prompt, human_prompt])

        write_log('test_rewriter_log', f'Original test:\n{test_to_rewrite}\nRewritten test:{result.content}')
        rewritten_tests.append(result.content)

    # Build a new list instead of pop()/append() on the same list: that pattern
    # re-popped the freshly rewritten test, so only the last test was rewritten
    # (repeatedly, with other tests' suggestions). Assign in place so existing
    # references to these lists stay valid.
    tests[:] = rewritten_tests
    suggestions.clear()

    print("=== Rewriting tests Done ===")

    return state


def test_saver_node(state: BasicAgentState) -> BasicAgentState:
    """Archive the current class's tests and reset the per-class working state.

    Appends every entry of ``state['tests']`` to ``state['all_tests']``, then
    empties ``stub_ideas``, ``tests`` and ``validator_suggestions`` in place.
    It also sets ``validator_accepted`` to False and
    ``validations_steps_taken`` to 0.
    """
    state['all_tests'].extend(state['tests'])

    for working_key in ('stub_ideas', 'tests', 'validator_suggestions'):
        state[working_key].clear()

    state['validator_accepted'] = False
    state['validations_steps_taken'] = 0
    return state


def save_all_tests_node(state: BasicAgentState) -> BasicAgentState:
    """Write all collected tests to disk and convert them with ``TestMerger``.

    The tests in ``state['all_tests']``, separated by blank lines, are written
    to ``out_for_JdbcConnection.txt``. A ``TestMerger`` is then built from that
    path and stepped through its states until its ``lines`` are exhausted.
    The result is saved to ``outputs/converted_output.txt``, and the value
    returned by ``save_to_file`` is printed.

    Returns:
        The unchanged ``state``.
    """
    raw_output_path = 'out_for_JdbcConnection.txt'
    converted_output_path = 'outputs/converted_output.txt'

    all_tests = "\n\n".join(state['all_tests'])
    with open(raw_output_path, 'w', encoding='utf-8') as file:
        try:
            file.write(all_tests)
            print("Tests have been saved")
        except UnicodeEncodeError:
            # Only reachable for text UTF-8 cannot encode (e.g. lone surrogates).
            # The message is a prefix; it was previously used as the join separator.
            print(f"Error while writing to file with content:\n\n{all_tests}")

    tm = TestMerger(raw_output_path)
    # Each call presumably consumes lines and returns the next state callable.
    while len(tm.lines) > 0:
        tm.state = tm.state()

    # Make sure the target directory exists before the merger writes into it.
    os.makedirs(os.path.dirname(converted_output_path), exist_ok=True)
    print(tm.save_to_file(converted_output_path))
    return state