#include <impl/basic.hpp>
#include <graphs.hpp>

#include <algorithm>
#include <assert.h>
#include <iostream>
#include <set>
#include <vector>

#include "enumerating2bisections.hpp"

using namespace ba_graph;
using namespace std;

// Drains `maximal_independent_set` and asserts that it yields exactly the sets
// listed in `independent_sets`, each one exactly once (order irrelevant).
//
// Preconditions: the enumerator is already positioned on its first set (the
// callers invoke generate_maximal_independent_set() before calling this).
//
// Both arguments are taken by value: the expected list is consumed as matches
// are found, and advancing the enumerator does not affect the caller's copy.
void test_all_independent_sets(vector<set<Number>> independent_sets, MaximalIndependentSet maximal_independent_set) {
    bool has_current = true;
    while (has_current) {
        const set<Number> independent_set(maximal_independent_set.get_maximal_independent_set());

        // Every generated set must be one of the expected sets that has not
        // been matched yet. Unexpected sets and duplicates are failures; the
        // previous loop silently skipped them.
        const auto match = std::find(independent_sets.begin(), independent_sets.end(), independent_set);
        assert(match != independent_sets.end() && "unexpected or duplicate maximal independent set");
        // Guarded so that an elided assert (NDEBUG) cannot lead to erase(end()).
        if (match != independent_sets.end()) {
            independent_sets.erase(match);
        }

        has_current = maximal_independent_set.next_maximal_independent_set();
    }

    // Every expected set must have been generated.
    assert(independent_sets.empty());
}

void maximal_independent_set_tests() {
    Graph g1(empty_graph(3));
    MaximalIndependentSet maximal_independent_set1(g1);
    maximal_independent_set1.generate_maximal_independent_set();
    vector<set<Number>> independent_sets1;
    independent_sets1.push_back({ 0, 1, 2 });
    test_all_independent_sets(independent_sets1, maximal_independent_set1);

    addE(g1, Loc(0,1));
    addE(g1, Loc(1,2));
    MaximalIndependentSet maximal_independent_set2(g1);
    maximal_independent_set2.generate_maximal_independent_set();
    vector<set<Number>> independent_sets2;
    independent_sets2.push_back({ 0, 2 });
    independent_sets2.push_back({ 1 });
    test_all_independent_sets(independent_sets2, maximal_independent_set2);

    Graph g2(empty_graph(6));
    addE(g2, Loc(0,1));
    addE(g2, Loc(0,2));
    addE(g2, Loc(1,2));
    addE(g2, Loc(1,3));
    addE(g2, Loc(1,4));
    addE(g2, Loc(2,4));
    addE(g2, Loc(3,4));
    addE(g2, Loc(3,5));
    addE(g2, Loc(4,5));
    MaximalIndependentSet maximal_independent_set3(g2);
    maximal_independent_set3.generate_maximal_independent_set();
    vector<set<Number>> independent_sets3;
    independent_sets3.push_back({ 0, 3 });
    independent_sets3.push_back({ 0, 4 });
    independent_sets3.push_back({ 0, 5 });
    independent_sets3.push_back({ 1, 5 });
    independent_sets3.push_back({ 2, 3 });
    independent_sets3.push_back({ 2, 5 });
    test_all_independent_sets(independent_sets3, maximal_independent_set3);

    addE(g2, Loc(2,3));
    MaximalIndependentSet maximal_independent_set4(g2);
    maximal_independent_set4.generate_maximal_independent_set();
    vector<set<Number>> independent_sets4;
    independent_sets4.push_back({ 0, 3 });
    independent_sets4.push_back({ 0, 4 });
    independent_sets4.push_back({ 0, 5 });
    independent_sets4.push_back({ 1, 5 });
    independent_sets4.push_back({ 2, 5 });
    test_all_independent_sets(independent_sets4, maximal_independent_set4);

    cout << "SUCCESS - Maximal independent set tests" << endl;
}

bool check_colouring(const Graph &g, Graph3Colouring colours){
    for(auto it = g.begin(); it != g.end(); it++){
        Number colour = colours.get_colour(it->n());
        if(colour <= 0 || colour >= 4){
            return false;
        }
        
        for(auto neighbour : g[it->n()].neighbours()){
            if(colour == colours.get_colour(neighbour)){
                return false;
            }
        }
    }

    return true;
}

void graph_colouring_test() {
    Graph g1(empty_graph(3));
    Graph3Colouring colours1(g1);
    assert(colours1.colour_graph());
    assert(check_colouring(g1, colours1));

    addE(g1, Loc(0,1));
    addE(g1, Loc(1,2));
    Graph3Colouring colours2(g1);
    assert(colours2.colour_graph());
    assert(check_colouring(g1, colours2));

    Graph g2(complete_graph(4));
    Graph3Colouring colours3(g2);
    assert(!colours3.colour_graph());

    Graph g3(empty_graph(6));
    addE(g3, Loc(0,1));
    addE(g3, Loc(0,2));
    addE(g3, Loc(1,2));
    addE(g3, Loc(1,3));
    addE(g3, Loc(1,4));
    addE(g3, Loc(2,4));
    addE(g3, Loc(3,4));
    addE(g3, Loc(3,5));
    addE(g3, Loc(4,5));
    Graph3Colouring colours4(g3);
    assert(colours4.colour_graph());
    assert(check_colouring(g3, colours4));

    addE(g3, Loc(2,3));
    Graph3Colouring colours5(g3);
    assert(!colours5.colour_graph());

    cout << "SUCCESS - Proper 3-vertex colouring tests" << endl;
}

// Checks TwoBisections::enumerate() against hand-computed counts for small
// graphs.
//
// enumerate() is called outside assert() on purpose. With NDEBUG defined,
// an assert's expression is not evaluated, so the enumeration would never run.
void enumerating_bisections_tests() {
    Graph g1(complete_graph(2));
    TwoBisections bisections1(g1);
    const auto count1 = bisections1.enumerate();
    assert(count1 == 4);

    Graph g2(complete_graph(3));
    TwoBisections bisections2(g2);
    const auto count2 = bisections2.enumerate();
    assert(count2 == 6);

    Graph g3(complete_graph(4));
    TwoBisections bisections3(g3);
    const auto count3 = bisections3.enumerate();
    assert(count3 == 0);

    // Six-vertex graph with eight edges.
    Graph g4(empty_graph(6));
    addE(g4, Loc(0,1));
    addE(g4, Loc(0,2));
    addE(g4, Loc(1,3));
    addE(g4, Loc(1,4));
    addE(g4, Loc(2,3));
    addE(g4, Loc(2,4));
    addE(g4, Loc(3,5));
    addE(g4, Loc(4,5));
    TwoBisections bisections4(g4);
    const auto count4 = bisections4.enumerate();
    assert(count4 == 2);

    cout << "SUCCESS - Enumerating 2-bisections tests" << endl;
}

// Runs each test group in sequence. A failing check aborts through assert().
// Each group prints its own SUCCESS line when it finishes.
int main(){
    maximal_independent_set_tests();
    graph_colouring_test();
    enumerating_bisections_tests();
    return 0;
}