from langgraph.graph import StateGraph

from Nodes import *
from Utils import *


if __name__ == '__main__':
    # Imported explicitly rather than relying on `os` leaking in through the
    # star imports of Nodes/Utils above.
    from pathlib import Path

    # Remove the input file left over from a previous run so this run starts
    # clean. unlink(missing_ok=True) replaces the racy exists()/remove() pair.
    # todo: move to grapher_node
    Path('ragGraph/input/input.txt').unlink(missing_ok=True)

    # Initialize the state graph. BasicAgentState comes from the star imports
    # above.
    # noinspection PyTypeChecker
    graph = StateGraph(BasicAgentState)

    # Execution starts at the summarization (file processing) node.
    graph.set_entry_point("summarization")

    # Register nodes.
    # graph.add_node("initializer", initializer_node)
    # noinspection PyTypeChecker
    graph.add_node("grapher", grapher_node)
    # noinspection PyTypeChecker
    graph.add_node("summarization", file_processor_node)
    # noinspection PyTypeChecker
    graph.add_node("stub_ideas", stub_ideas_node)
    # noinspection PyTypeChecker
    graph.add_node("test_writer", test_writer_node)
    # noinspection PyTypeChecker
    graph.add_node("validator", validator_node)
    # noinspection PyTypeChecker
    graph.add_node("test_saver", test_saver_node)
    # noinspection PyTypeChecker
    graph.add_node("save_all_tests", save_all_tests_node)
    # noinspection PyTypeChecker
    graph.add_node("test_rewriter", test_rewriter_node)

    # Unconditional edges: the linear pipeline up to validation. Rewritten
    # tests are sent back to the validator.
    graph.add_edge("summarization", "grapher")
    graph.add_edge("grapher", "stub_ideas")
    graph.add_edge("stub_ideas", "test_writer")
    graph.add_edge("test_writer", "validator")
    graph.add_edge("test_rewriter", "validator")

    # Validator loop: validator_accepted() returns "true"/"false", which
    # chooses between saving the test and rewriting it.
    graph.add_conditional_edges(
        "validator",
        validator_accepted,
        {
            "true": "test_saver",
            "false": "test_rewriter"
        }
    )

    # Per-class loop: return to stub_ideas until generated_for_all_classes()
    # reports "true", then persist everything.
    graph.add_conditional_edges(
        "test_saver",
        generated_for_all_classes,
        {
            "false": "stub_ideas",
            "true": "save_all_tests"
        }
    )

    # The run ends after save_all_tests.
    graph.set_finish_point("save_all_tests")

    app = graph.compile()

    # Debug helpers for visualizing the compiled graph:
    # show_graph_from_ascii(app)
    # show_graph_from_app(app)

    # Both loops above can take many super-steps. Raise LangGraph's step limit
    # well above its default so long runs are not aborted early.
    # noinspection PyTypeChecker
    results = app.invoke(
        dict(ctx="",
             stub_ideas=[],
             tests=[],
             all_tests=[],
             validator_suggestions=[],
             validator_accepted=False,
             validations_steps_taken=0),
        {"recursion_limit": 1000})
    # NOTE(review): 'result' is not in the initial state, so some node
    # (presumably save_all_tests) must set it. Confirm, otherwise this
    # raises KeyError.
    print(results['result'])
