package datalog_ra.parser;

import datalog_ra.base.dataStructures.*;
import datalog_ra.evaluation.Rule;
import datalog_ra.evaluation.Term;

import java.io.*;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;

public class Parser {
    private final Lexer lexer;
    private final ArrayList<Rule> rules;
    private final Instance facts;

    /**
     * Creates a parser reading Datalog source text from {@code input}.
     *
     * @param input the stream to tokenize
     */
    public Parser(InputStream input) {
        lexer = new Lexer(input);
        rules = new ArrayList<>();
        facts = new Instance();
    }

    /**
     * Returns the ground facts collected by {@link #parse()}.
     * This is the parser's own instance, not a copy.
     *
     * @return the collected facts
     */
    public Instance getFacts() {
        return facts;
    }

    /**
     * Returns the rules collected by {@link #parse()}, in source order.
     * This is the parser's own list, not a copy.
     *
     * @return the collected rules
     */
    public List<Rule> getRules() {
        return rules;
    }

    /**
     * Parses the whole input as a Datalog program.
     *
     * <p>A clause whose head contains only constants and whose body is empty
     * (either {@code p(a).} or {@code p(a) :- .}) is stored as a fact, see
     * {@link #getFacts()}. Every other clause becomes a {@link Rule}, see
     * {@link #getRules()}.
     *
     * @throws ParseException on a syntax error, or when a rule uses a variable
     *     that does not occur in any positive (non-negated) body subgoal
     */
    public void parse() throws ParseException {
        // program = { rule };
        lexer.next();
        while (!is(Token.eof)) {
            parseRule();
        }
    }

    /**
     * Parses input consisting of exactly one query {@code name(t1, ..., tn).}
     * followed by end of input.
     *
     * @return the query as a {@link Rule} with no subgoals added
     * @throws ParseException if the input is not a single well-formed query
     */
    public Rule parseSingleQuery() throws ParseException {
        // program = query;
        lexer.next();
        Rule query = parseQuery();
        expect(Token.eof);
        return query;
    }

    /** Parses one clause and records it either as a fact or as a rule. */
    private void parseRule() throws ParseException {
        // rule = constant, term_list, [ colon_dash, body ], dot;   (body may be empty)
        RuleVariables vars = new RuleVariables();

        String ruleName = readConstant();
        ArrayList<Term> head = parseTermList();
        // Head variables must be bound by some positive body subgoal.
        addVarsFromTermList(vars.otherVariables, head);

        boolean onlyConstants = true;
        for (Term t : head) {
            onlyConstants &= t.isConstant();
        }

        Token afterHead = expect(Token.colon_dash, Token.dot);
        boolean hasBody = afterHead == Token.colon_dash;

        if (onlyConstants && (!hasBody || is(Token.dot))) {
            addFact(ruleName, head);
        } else {
            Rule rule = new Rule(ruleName, head);
            // Only parse a body after ":-". After "." the current token already
            // belongs to the next clause and must not be consumed as this body.
            if (hasBody) {
                parseRuleBody(rule, vars);
            }
            checkRule(rule, vars);
            rules.add(rule);
        }

        if (hasBody) {
            expect(Token.dot);
        }
    }

    /** Adds the ground clause {@code name(c1, ..., cn)} as a tuple of the matching relation. */
    private void addFact(String name, List<Term> head) {
        ArrayList<Attribute> values = new ArrayList<>(head.size());
        for (Term term : head) {
            values.add(new Attribute(term.getTerm()));
        }
        Relation rel = facts.get(name, head.size());
        if (rel == null) {
            rel = new Relation(name, head.size());
            facts.add(rel);
        }
        rel.add(new Tuple(values));
    }

    /**
     * Parses a possibly empty, comma-separated list of body literals into
     * {@code rule}. The terminating dot is left for the caller to consume.
     */
    private void parseRuleBody(Rule rule, RuleVariables vars) throws ParseException {
        // body = | literal, { comma, literal };
        if (!is(Token.constant) && !is(Token.variable) && !is(Token.not)) {
            return; // empty body
        }
        parseLiteral(rule, vars);
        while (!is(Token.dot)) {
            expect(Token.comma);
            parseLiteral(rule, vars);
        }
    }

    /** Parses one body literal and records the variables it mentions. */
    private void parseLiteral(Rule rule, RuleVariables vars) throws ParseException {
        // literal   = [ not ], predicate;
        // predicate = constant, term_list
        //           | variable, (equal | not_equal), term;
        expectPeek(Token.constant, Token.variable, Token.not);
        boolean negated = is(Token.not);
        if (negated) {
            lexer.next();
        }
        Term first = readTerm();
        if (first.isConstant()) {
            ArrayList<Term> termList = parseTermList();
            rule.addSubgoal(first.getTerm(), termList, negated);
            // Only positive subgoals can bind variables.
            addVarsFromTermList(negated ? vars.otherVariables : vars.positiveLiteralContextVariables,
                    termList);
        } else {
            boolean neq = expect(Token.equal, Token.not_equal) == Token.not_equal;
            Term rhs = readTerm();
            // Negation flips the comparison: "not X = Y" is recorded like "X != Y".
            rule.addComparison(first, rhs, negated != neq);
            addVar(vars.otherVariables, first);
            addVar(vars.otherVariables, rhs);
        }
    }

    /** Parses a non-empty, parenthesized, comma-separated list of terms. */
    private ArrayList<Term> parseTermList() throws ParseException {
        // term_list = lparen, term, { comma, term }, rparen;
        ArrayList<Term> terms = new ArrayList<>();
        expect(Token.lparen);
        terms.add(readTerm());
        while (!is(Token.rparen)) {
            expect(Token.comma);
            terms.add(readTerm());
        }
        expect(Token.rparen);
        return terms;
    }

    /**
     * Checks rule safety: every variable in {@code vars.otherVariables} must
     * also occur in a positive body subgoal. The first offending variable,
     * in order of first occurrence, is reported.
     */
    void checkRule(Rule rule, RuleVariables vars) throws ParseException {
        for (String var : vars.otherVariables) {
            if (!vars.positiveLiteralContextVariables.contains(var)) {
                throw new ParseException(lexer.getLine() + ":" + lexer.getColumn() + ": variable \"" + var + "\" does not appear in any positive literal context");
            }
        }
    }

    /** Parses {@code constant, term_list, dot} into a rule with no subgoals. */
    private Rule parseQuery() throws ParseException {
        // query = constant, term_list, dot;
        String ruleName = readConstant();
        ArrayList<Term> head = parseTermList();
        expect(Token.dot);
        return new Rule(ruleName, head);
    }

    private boolean is(Token t) {
        return lexer.getToken() == t;
    }

    /** Consumes a constant token and returns its text. */
    private String readConstant() throws ParseException {
        String identStr = lexer.getIdentifierStr();
        expect(Token.constant);
        return identStr;
    }

    /**
     * Consumes a constant or variable token. The anonymous variable {@code _}
     * is returned as a variable term with {@code null} text.
     */
    private Term readTerm() throws ParseException {
        String identStr = lexer.getIdentifierStr();
        boolean isVariable = expect(Token.constant, Token.variable) == Token.variable;
        if (isVariable && "_".equals(identStr)) {
            identStr = null;
        }
        return new Term(identStr, isVariable);
    }

    /**
     * Returns the current token if it is one of {@code tokens}, without consuming it.
     *
     * @throws ParseException naming the position, the actual token and the expected ones
     */
    private Token expectPeek(Token... tokens) throws ParseException {
        for (Token token : tokens) {
            if (lexer.getToken() == token) {
                return token;
            }
        }
        StringBuilder errMsg = new StringBuilder();
        errMsg.append(lexer.getLine())
                .append(':')
                .append(lexer.getColumn())
                .append(": unexpected ")
                .append(lexer.getToken().toString())
                .append(", expected ");
        for (int i = 0; i < tokens.length; i++) {
            if (i > 0) {
                errMsg.append(i == tokens.length - 1 ? " or " : ", ");
            }
            errMsg.append(tokens[i].toString());
        }
        throw new ParseException(errMsg.toString());
    }

    /** Like {@link #expectPeek}, but also advances past the matched token. */
    private Token expect(Token... tokens) throws ParseException {
        Token t = expectPeek(tokens);
        lexer.next();
        return t;
    }

    /** Signals a syntax or rule-safety error; the message starts with "line:column: ". */
    public static class ParseException extends Exception {
        private static final long serialVersionUID = 1L;

        public ParseException(String message) {
            super(message);
        }
    }

    /** Variable bookkeeping for a single rule, used by the safety check. */
    private static final class RuleVariables {
        /** Variables occurring in at least one positive (non-negated) body subgoal. */
        final HashSet<String> positiveLiteralContextVariables = new HashSet<>();
        /** Variables that must be bound positively, kept in order of first occurrence. */
        final HashSet<String> otherVariables = new LinkedHashSet<>();
    }

    /** Adds {@code t} to {@code set} if it is a variable; the anonymous variable is stored as "_". */
    private static void addVar(HashSet<String> set, Term t) {
        if (t.isVariable()) {
            set.add(t.getTerm() != null ? t.getTerm() : "_");
        }
    }

    private static void addVarsFromTermList(HashSet<String> set, List<Term> terms) {
        for (Term t : terms) {
            addVar(set, t);
        }
    }
}
