import os.path
import re

class TestMerger:
    """Merge the members of Java test classes into one ``GeneratedTests`` class.

    The input file is read once (blank lines dropped) into ``self.lines`` and
    then consumed by a line-based state machine: each state method pops the
    lines it recognises, files them into one of the member buckets below and
    returns the next state method, or ``None`` when there is nothing left to
    process. ``save_to_file`` renders the buckets as a single class.

    This is a heuristic regex/brace-counting parser, not a Java parser: it
    expects conventional one-declaration-per-line formatting and counts
    braces even when they appear inside string literals.

    Typical driver::

        merger = TestMerger(path)
        while merger.state is not None and merger.lines:
            merger.state = merger.state()
        merger.save_to_file(out_path)
    """

    def __init__(self, path, should_log: bool = False):
        # Member buckets, each holding source text in encounter order.
        self.imports: list[str] = []
        self.tests: list[str] = []
        self.mocks: list[str] = []
        self.beforeEach: list[str] = []
        self.afterEach: list[str] = []
        self.classVar: list[str] = []
        self.methods: list[str] = []
        self.classAnnotations: list[str] = []
        # Current state method; the driver calls it and stores the result.
        self.state = self.initial_state
        # Remaining input lines, consumed from the front.
        self.lines: list[str] = []
        # Comment lines waiting to be attached to the next member.
        self.previous_comments: list[str] = []
        # Trace of visited states (only filled when should_log is true).
        self.stack_trace: list[str] = []
        self.should_log = should_log

        self.mock_annotation_matcher = re.compile(r'\s*@.*Mock')
        self.afterEach_annotation_matcher = re.compile(r'\s*@.*After(Each)?')
        self.beforeEach_annotation_matcher = re.compile(r'\s*@.*Before(Each)?')
        self.test_annotation_matcher = re.compile(r'\s*@.*Test?')
        self.method_pattern_matcher = re.compile(r'\s*(private|public)?.*\(.*\).*\{')
        self.class_variable_patter_matcher = re.compile(r'\s*(private|public)?.+(;|\{|\n\s*\.)')
        self.single_line_comment_patter_matcher = re.compile(r'(\s*//.*)+')
        # '/*' (not just '/**') so plain block comments are recognised too;
        # an unrecognised comment line falls through to initial_state, which
        # skips the rest of the class.
        self.multi_line_comment_start_patter_matcher = re.compile(r'\s*/\*')
        self.multi_line_comment_end_patter_matcher = re.compile(r'(\s*\*/.*)+')

        with open(path, 'r', encoding='utf-8') as file:
            for line in file:
                if line.strip():
                    # rstrip instead of line[:-1]: a last line without a
                    # trailing newline would otherwise lose a real character.
                    self.lines.append(line.rstrip('\n'))

    def save_to_file(self, path):
        """Render the collected members as one Java class, write it to *path*
        and return the rendered text.

        Imports are de-duplicated keeping first-seen order, so the output is
        reproducible between runs. Runs of three or more newlines are
        collapsed into a single blank line.
        """
        # dict.fromkeys de-duplicates while preserving insertion order.
        imports = '\n'.join(dict.fromkeys(self.imports))
        class_annotations = '\n'.join(self.classAnnotations)
        sections = [
            '\n'.join(self.classVar),
            '\n'.join(self.mocks),
            '\n\n'.join(self.beforeEach),
            '\n\n'.join(self.afterEach),
            '\n\n'.join(self.tests),
            '\n\n'.join(self.methods),
        ]
        # Each non-empty section is preceded by a blank line; an empty one
        # contributes only a newline, which the collapse below absorbs.
        body = '\n'.join(f'\n\n{section}' if section else '' for section in sections)
        template = (
            f"{imports}\n\n{class_annotations}\n"
            "public class GeneratedTests {\n"
            f"{body}\n}}\n"
        )
        template = re.sub(r'\n{3,}', '\n\n', template)
        with open(path, 'w', encoding='utf-8') as file:
            file.write(template)
        return template

    def initial_state(self):
        """Skip input up to the next ``import`` line.

        Returns ``import_state``, or ``None`` if no import line remains.
        Anything skipped (package statement, a previous class's closing
        brace, ...) is discarded. Presumably the input is several test files
        concatenated, each starting with its imports — confirm with callers.
        """
        if not self.lines:
            return None
        if self.should_log:
            self.stack_trace.append(f"initial_state: {self.lines[0]}")
        while self.lines and 'import ' not in self.lines[0]:
            self.lines.pop(0)
        return self.import_state if self.lines else None

    def import_state(self):
        """Collect consecutive import lines and decide how the class starts.

        Comments between the imports and the class are discarded.

        Raises:
            IllegalStateMachineException: if the line after the imports is
                neither an annotation nor a class declaration.
        """
        if self.should_log:
            self.stack_trace.append(f"import_state: {self.lines[0]}")
        while self.lines and 'import ' in self.lines[0]:
            self.imports.append(self.lines.pop(0))

        self.save_comments()
        self.previous_comments.clear()

        if not self.lines:
            return None
        if '@' in self.lines[0]:
            return self.class_annotation_state
        if 'class' in self.lines[0]:
            return self.class_declaration_state
        raise IllegalStateMachineException(f"Could not proceed from import state with token {self.lines[0]}")

    def class_annotation_state(self):
        """Collect the annotation lines placed above the class declaration."""
        if self.should_log:
            self.stack_trace.append(f"class_annotation_state: {self.lines[0]}")
        while self.lines and '@' in self.lines[0]:
            self.classAnnotations.append(self.lines.pop(0))
        return self.class_declaration_state if self.lines else None

    def class_declaration_state(self):
        """Drop the class declaration itself; its members are merged instead."""
        if self.should_log:
            self.stack_trace.append(f"class_declaration_state: {self.lines[0]}")
        # The declaration can wrap ("extends"/"implements" on the next line,
        # or the brace on a line of its own); consume through the '{'.
        while self.lines:
            if '{' in self.lines.pop(0):
                break
        return self.determine_inner_class_state()

    def class_variable_state(self):
        """Move one field declaration into ``classVar``.

        A line that opens a brace without a ';' (e.g. an anonymous-class
        initialiser) is taken as a whole brace block; otherwise the single
        line is taken.
        """
        if self.should_log:
            self.stack_trace.append(f"class_variable_state: {self.lines[0]}")
        self._flush_comments_into(self.classVar)
        line = self.lines[0]
        if ';' not in line and '{' in line:
            self.classVar.append('\n'.join(self.extract_block()))
        else:
            # Only the current line decides; a following method must not be
            # folded into this field.
            self.classVar.append(self.lines.pop(0))
        return self.determine_inner_class_state()

    def method_state(self):
        """Move a non-annotated method (and its pending comments) into ``methods``."""
        if self.should_log:
            self.stack_trace.append(f"method_state: {self.lines[0]}")
        # Attach pending comments here like the other member states do, so a
        # method's javadoc is not carried over to the next member.
        self._flush_comments_into(self.methods)
        self.methods.append('\n'.join(self.extract_block()))
        return self.determine_inner_class_state()

    # ======== Annotation state methods ========

    def mock_state(self):
        """Move a mock field (annotation plus declaration) into ``mocks``."""
        if self.should_log:
            self.stack_trace.append(f"mock_state: {self.lines[0]}")
        self._flush_comments_into(self.mocks)
        if not (';' in self.lines[0] or '{' in self.lines[0]):
            # Annotation on a line of its own (e.g. "@Mock"); the field follows.
            self.mocks.append(self.lines.pop(0))
        if self.lines and ';' in self.lines[0]:
            self.mocks.append(self.lines.pop(0))
        elif self.lines and '{' in self.lines[0]:
            self.mocks.append('\n'.join(self.extract_block()))
        return self.determine_inner_class_state()

    def test_state(self):
        """Move an annotated test method into ``tests``."""
        if self.should_log:
            self.stack_trace.append(f"test_state: {self.lines[0]}")
        self._take_annotated_member(self.tests)
        return self.determine_inner_class_state()

    def afterEach_state(self):
        """Move an annotated teardown method into ``afterEach``."""
        if self.should_log:
            self.stack_trace.append(f"afterEach_state: {self.lines[0]}")
        self._take_annotated_member(self.afterEach)
        return self.determine_inner_class_state()

    def beforeEach_state(self):
        """Move an annotated setup method into ``beforeEach``."""
        if self.should_log:
            self.stack_trace.append(f"beforeEach_state: {self.lines[0]}")
        self._take_annotated_member(self.beforeEach)
        return self.determine_inner_class_state()

    # ======== Util methods ========

    def extract_block(self) -> list[str]:
        """Pop and return lines up to the one that closes the first brace block.

        Every '{' and '}' on a line is counted, so "}}" closes two levels.
        Lines before the first '{' (e.g. a signature wrapped before its brace)
        are included. Returns what was collected if the input runs out first.
        """
        depth = 0
        opened = False
        block: list[str] = []
        while self.lines:
            current_line = self.lines.pop(0)
            block.append(current_line)
            depth += current_line.count('{') - current_line.count('}')
            opened = opened or '{' in current_line
            if opened and depth <= 0:
                break
        return block

    def determine_inner_class_state(self):
        """Choose the state for the next member of the class body.

        Pending comments are collected first. Returns ``None`` when the input
        is exhausted, and ``initial_state`` for an unrecognised line (such as
        the class's closing brace), which skips ahead to the next import.
        """
        self.save_comments()
        if not self.lines:
            return None
        current = self.lines[0]
        # Order matters: annotation checks come before the generic method and
        # field patterns, which would also match annotated lines.
        candidates = (
            (self.mock_annotation_matcher, self.mock_state),
            (self.afterEach_annotation_matcher, self.afterEach_state),
            (self.beforeEach_annotation_matcher, self.beforeEach_state),
            (self.test_annotation_matcher, self.test_state),
            (self.method_pattern_matcher, self.method_state),
            (self.class_variable_patter_matcher, self.class_variable_state),
        )
        for matcher, state in candidates:
            if matcher.match(current):
                return state
        return self.initial_state

    def save_comments(self):
        """Move consecutive leading comment lines into ``previous_comments``.

        Handles '//' lines and '/* ... */' blocks (including javadoc); an
        unterminated block comment consumes the rest of the input.
        """
        while self.lines:
            line = self.lines[0]
            if self.single_line_comment_patter_matcher.match(line):
                self.previous_comments.append(self.lines.pop(0))
            elif self.multi_line_comment_start_patter_matcher.match(line):
                while self.lines:
                    closes = '*/' in self.lines[0]
                    self.previous_comments.append(self.lines.pop(0))
                    if closes:
                        break
            else:
                break

    def _flush_comments_into(self, bucket: list[str]) -> None:
        """Append pending comments to *bucket* as one entry and reset them."""
        bucket.append('\n'.join(self.previous_comments))
        self.previous_comments.clear()

    @staticmethod
    def _opens_block(line: str) -> bool:
        """Return True when *line* opens more braces than it closes."""
        return line.count('{') > line.count('}')

    def _take_annotated_member(self, bucket: list[str]) -> None:
        """Move comments, stacked annotations and the member body into *bucket*."""
        self._flush_comments_into(bucket)
        annotations = []
        # Take the triggering annotation plus any further stacked ones (e.g.
        # @Test then @DisplayName) so they stay with the body, but not a line
        # that already opens the body such as "@Test void foo() {".
        while (self.lines
               and self.lines[0].lstrip().startswith('@')
               and not self._opens_block(self.lines[0])):
            annotations.append(self.lines.pop(0))
        if annotations:
            bucket.append('\n'.join(annotations))
        # '\n' (not ' ') keeps any '//' comment in the body from swallowing
        # the code and closing brace that follow it.
        bucket.append('\n'.join(self.extract_block()))



class IllegalStateMachineException(Exception):
    """Signals that the TestMerger state machine could not proceed from its current state."""

    def __init__(self, message: str) -> None:
        # Hand the message to Exception so str(exc) and exc.args carry it.
        super().__init__(message)


if __name__ == '__main__':
    should_log = True
    tm = TestMerger('../out_for_JdbcConnection.txt', should_log)
    try:
        # Run states until one reports completion (None) or the input is
        # exhausted; save in both cases so collected members are not lost.
        while tm.state is not None and len(tm.lines) > 0:
            tm.state = tm.state()
    except IllegalStateMachineException as error:
        print(f"Illegal state machine exception: {error}")
        if should_log:
            # Show the most recent state transitions to help locate the
            # offending input line.
            print('\n'.join(tm.stack_trace[-10:]))
        raise SystemExit(1)
    else:
        tm.save_to_file('converted_output.txt')


