from io import BytesIO
from pathlib import Path

import javalang
from langchain_core.documents.base import Document
from langchain_core.runnables.graph import MermaidDrawMethod
from langgraph.graph.state import CompiledStateGraph
from PIL import Image

from BasicAgentState import BasicAgentState


def show_graph_from_app(app: CompiledStateGraph):
    """Render the compiled graph as a Mermaid PNG and display it with Pillow.

    The PNG bytes come from ``draw_mermaid_png`` using
    ``MermaidDrawMethod.API``. They are decoded in memory and opened with
    ``Image.show``.
    """
    png_bytes = app.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API)
    Image.open(BytesIO(png_bytes)).show()

def show_graph_from_ascii(app: CompiledStateGraph):
    """Print an ASCII-art rendering of the compiled graph via ``print_ascii``."""
    graph = app.get_graph()
    graph.print_ascii()

def validator_accepted(state: BasicAgentState) -> str:
    """Decide whether the validation loop should stop.

    Returns ``"true"`` when more than two validation steps have been taken.
    It also returns ``"true"`` when every entry in
    ``state['validator_suggestions']`` equals ``'valid'``, which includes the
    case of an empty suggestion list. Otherwise it returns ``"false"``.
    """
    # Hard stop after the step budget, regardless of what the validator said.
    if state['validations_steps_taken'] > 2:
        return "true"

    suggestions = state['validator_suggestions']
    every_item_valid = all(item == 'valid' for item in suggestions)
    return "true" if every_item_valid else "false"

def generated_for_all_classes(state: BasicAgentState) -> str:
    """Report whether ``state['class_names']`` has been exhausted.

    Returns:
        ``"true"`` when ``state['class_names']`` is empty, otherwise ``"false"``.
        Presumably callers remove names as each class is handled; verify
        against the graph nodes that update ``class_names``.
    """
    # Truthiness check: an empty sequence means no class names are left.
    return "false" if state['class_names'] else "true"

def write_log(file_name: str, content: str, log_dir: str = 'node_logs'):
    """Append ``content`` to the log file ``<log_dir>/<file_name>.txt``.

    Each entry is followed by a separator line surrounded by blank lines, so
    consecutive entries are easy to tell apart. The log directory, including
    any missing parents, is created on first use instead of raising
    ``FileNotFoundError``.

    Args:
        file_name: Base name of the log file, without the ``.txt`` extension.
        content: Text to append.
        log_dir: Directory holding the log files. Defaults to ``node_logs``,
            resolved relative to the current working directory.
    """
    directory = Path(log_dir)
    directory.mkdir(parents=True, exist_ok=True)
    with open(directory / f'{file_name}.txt', 'a', encoding='UTF-8') as file:
        file.write(content + '\n\n==============================\n\n')

def create_document_from_text(text: str, class_name: str, type: str) -> Document:
    """Wrap ``text`` in a LangChain ``Document`` carrying Java class metadata.

    Args:
        text: Stored as the document's ``page_content``.
        class_name: Stored under the ``"class_name"`` metadata key.
        type: Stored under the ``"type"`` metadata key. This parameter shadows
            the builtin ``type``; the name is kept so keyword callers still work.

    Returns:
        A ``Document`` whose metadata also has ``"language"`` set to ``"Java"``.
    """
    metadata = dict(language="Java", class_name=class_name, type=type)
    return Document(page_content=text, metadata=metadata)

def _java_type_name(type_node):
    """Return the full (dotted) name of a javalang type node, or None if absent.

    javalang parses a qualified reference such as ``java.util.List`` into a
    chain of ``ReferenceType`` nodes linked through ``sub_type``. Reading only
    ``.name`` would therefore yield just the first segment (``"java"``).
    Basic types such as ``int`` have no ``sub_type`` and return their name.
    Generic arguments and array dimensions are not included.
    """
    parts = []
    current = type_node
    while current is not None:
        name = getattr(current, "name", None)
        if name is None:
            break
        parts.append(name)
        current = getattr(current, "sub_type", None)
    return ".".join(parts) if parts else None

def _extract_java_fields(class_node) -> list:
    """Collect a ``{"name", "type"}`` dict for every field declarator of a class."""
    # One field declaration (e.g. ``int a, b;``) can declare several variables.
    return [
        {"name": declarator.name, "type": _java_type_name(field.type)}
        for field in class_node.fields
        for declarator in field.declarators
    ]

def _extract_java_method(method) -> dict:
    """Summarise one method: name, return type, parameters, modifiers and calls."""
    calls = []
    if method.body:
        # Best-effort: walking the method subtree also picks up invocations
        # made inside anonymous/local classes declared within this method.
        calls = [
            inner_node.member
            for _, inner_node in method
            if isinstance(inner_node, javalang.tree.MethodInvocation)
        ]
    return {
        "name": method.name,
        "return_type": _java_type_name(method.return_type),
        "parameters": [
            {"name": p.name, "type": _java_type_name(p.type)}
            for p in method.parameters
        ],
        # javalang keeps modifiers in an unordered set; sorting gives a
        # run-to-run stable order (set order varies with str hash seeding).
        "modifiers": sorted(method.modifiers),
        "calls": calls,
    }

def extract_ast_structure(java_source: str):
    """Parse Java source with javalang and summarise every class declaration.

    Returns a list with one dict per ``ClassDeclaration`` found anywhere in the
    parsed tree, nested classes included. Each dict has these keys:
      - ``name``: class name
      - ``extends``: superclass name (dotted if written qualified), or None
      - ``implements``: list of implemented interface names
      - ``methods``: list of dicts with ``name``, ``return_type`` (None when
        absent), ``parameters``, ``modifiers`` (sorted) and ``calls``
        (best-effort names of methods invoked in the body)
      - ``fields``: list of ``{"name", "type"}`` dicts

    Interface/enum declarations and constructors are not included. Any
    exception raised by ``javalang.parse.parse`` (e.g. on invalid Java)
    propagates to the caller.
    """
    tree = javalang.parse.parse(java_source)

    classes = []
    for _, node in tree:
        if not isinstance(node, javalang.tree.ClassDeclaration):
            continue
        classes.append({
            "name": node.name,
            "extends": _java_type_name(node.extends),
            "implements": [_java_type_name(impl) for impl in node.implements or []],
            "methods": [_extract_java_method(method) for method in node.methods],
            "fields": _extract_java_fields(node),
        })

    return classes