#include <basic_impl.hpp>

#include <algorithms/cyclic_connectivity.hpp>
#include <util.hpp>

#include <mrozek.hpp>

#include <algorithm>
#include <cassert>
#include <exception>
#include <set>
#include <vector>

using namespace ba_graph;
using namespace internal;

/// Tests renumberColouring on a few small colourings.
///
/// Each case pairs an input colouring with the colouring renumberColouring
/// is expected to return; in every expected result the labels start at 0.
void test_renumberColouring() {
  // A single colour used everywhere.
  Colouring v = {1, 1, 1, 1};
  Colouring res = {0, 0, 0, 0};
  assert(renumberColouring(v) == res);

  // Two colours in contiguous runs.
  v = {1, 1, 2, 2};
  res = {0, 0, 1, 1};
  assert(renumberColouring(v) == res);

  // Two colours interleaved rather than in runs, so the renumbering is
  // exercised on a pattern the previous case does not cover.
  v = {1, 2, 1, 2};
  res = {0, 1, 0, 1};
  assert(renumberColouring(v) == res);
}

void test_getColouring() {
  Colouring res = {0};
  assert(getColouring(0, 1) == res);

  res = {2};
  assert(getColouring(2, 1) == res);

  res = {0, 1};
  assert(getColouring(3, 2) == res);

  res = {0, 2};
  assert(getColouring(6, 2) == res);

  res = {0, 1, 2};
  assert(getColouring(21, 3) == res);

  res = {1, 1, 2, 2};
  assert(getColouring(76, 4) == res);
}

void test_colouringToDecimal() {
  Colouring input = {0, 0};
  assert(colouringToDecimal(input) == 0);

  input = {1, 0};
  assert(colouringToDecimal(input) == 1);

  input = {0, 1};
  assert(colouringToDecimal(input) == 3);

  input = {0, 2};
  assert(colouringToDecimal(input) == 6);

  input = {0, 1, 2};
  assert(colouringToDecimal(input) == 21);

  input = {1, 1, 2, 2};
  assert(colouringToDecimal(input) == 76);
}

void test_colouringBitArrayToColourSet() {
  ColourSet res = {};
  ColouringBitArray input = ColouringBitArray(power(3, 1), false);
  assert(colouringBitArrayToColourSet(input) == res);

  res = {{0, 0, 0}};
  input = ColouringBitArray(power(3, 3), false);
  input.set(ColouringBitArray::Index::to_index(0), true);
  assert(colouringBitArrayToColourSet(input) == res);

  res = {{0, 1, 2}};
  input = ColouringBitArray(power(3, 3), false);
  input.set(ColouringBitArray::Index::to_index(21), true);
  assert(colouringBitArrayToColourSet(input) == res);
}

void test_filterColourSet() {
  ColourSet input = {{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}};
  ColourSet res = {{0, 0, 0, 0}};
  assert(filterColourSet(input) == res);

  input = {{0, 0, 1, 0}, {0, 0, 0, 1}, {0, 0, 2, 0},
           {0, 0, 1, 0}, {0, 0, 1},    {1, 0, 0, 0}};
  res = {{0, 0, 0, 1}, {0, 0, 1}, {0, 0, 1, 0}, {0, 0, 2, 0}, {1, 0, 0, 0}};
  assert(filterColourSet(input) == res);
}

void test_colourSetToBitArray() {
  ColourSet input = {{0, 0, 0}};
  ColouringBitArray res = ColouringBitArray(power(3, 3), false);
  res.set(ColouringBitArray::Index::to_index(0), true);
  assert(colourSetToBitArray(input) == res);

  input = {{0, 0, 1, 1}};
  res = ColouringBitArray(power(3, 4), false);
  res.set(ColouringBitArray::Index::to_index(36), true);
  assert(colourSetToBitArray(input) == res);
}

void test_getTerminals() {
  Graph g(createG());
  addMultipleV(g, 4);
  std::vector<Vertex> res;
  for (int i = 1; i < 4; i++) {
    addE(g, Location(0, i));
    res.push_back(g[i].v());
  }
  assert(getTerminals(g) == res);
}

void test_addCycleToTerminals() {
  Graph g(createG());
  addMultipleV(g, 4);
  for (int i = 1; i < 4; i++) {
    addE(g, Location(0, i));
  }
  addCycleToTerminals(g);

  for (auto &rotation : g) {
    assert(rotation.degree() == 3);
  }
}

void test_removeCycleFromTerminals() {
  Graph g(createG());
  addMultipleV(g, 4);
  std::vector<Edge> pathDecomposition;
  std::vector<Edge> createdEdges;
  for (int i = 1; i < 4; i++) {
    addE(g, Location(0, i));
    createdEdges.push_back(addE(g, Location(i, i % 3 + 1)).e());
  }
  pathDecomposition.push_back(createdEdges[0]);
  removeCycleFromTerminals(g, pathDecomposition, createdEdges);

  assert(pathDecomposition.size() == 0);
  for (int i = 1; i < 4; i++) {
    assert(!g.contains(Location(i, i % 3 + 1)));
  }
}

/// Tests addMultiedges on a single isolated vertex. Afterwards the graph is
/// expected to have two vertices, both of degree 2.
void test_addMultiedges() {
  Graph g(createG());
  addV(g, 0);

  std::vector<Vertex> chosen = {g[0].v()};
  addMultiedges(g, chosen);

  assert(g.order() == 2);
  for (int id = 0; id <= 1; ++id)
    assert(g[id].degree() == 2);
}

/// Tests removeMultiedges on the edge 0-1 plus a double edge between 1 and
/// 2. Afterwards only two vertices should remain, each of degree 1.
void test_removeMultiedges() {
  Graph g(createG());
  addMultipleV(g, 3);

  addE(g, Location(0, 1));
  for (int copy = 0; copy < 2; ++copy)
    addE(g, Location(1, 2));

  removeMultiedges(g);

  assert(g.order() == 2);
  for (auto &vertexRotation : g)
    assert(vertexRotation.degree() == 1);
}

void test_minimalColouringBitArray() {
  ColouringBitArray input(9, false);
  input.set(ColouringBitArray::Index::to_index(0), true);
  assert(input == minimalColouringBitArray(input));

  input = ColouringBitArray(9, true);
  ColouringBitArray res(9, false);
  res.set(ColouringBitArray::Index::to_index(0), true);
  res.set(ColouringBitArray::Index::to_index(3), true);
  assert(minimalColouringBitArray(input) == res);
}

void test_internal_getColourSet() {
  Graph g(createG());
  addMultipleV(g, 4);
  std::vector<Edge> pathDecomposition;
  for (int i = 1; i < 4; i++) {
    pathDecomposition.push_back(addE(g, Location(0, i)).e());
  }
  auto terminals = getTerminals(g);
  auto multiedges = addMultiedges(g, terminals);
  pathDecomposition.insert(pathDecomposition.end(), multiedges.begin(),
                           multiedges.end());

  ColourSet res = {{0, 1, 2}};
  assert(getColourSet(g, pathDecomposition) == res);
}

/// Tests throwForSmallWidthMultipoles.
///
/// A graph with nothing added to it must be rejected with a std::exception.
/// A claw (0 joined to 1, 2, 3) must be accepted without throwing.
void test_throwForSmallWidthMultipoles() {
  Graph g(createG());

  bool rejected = false;
  try {
    throwForSmallWidthMultipoles(g);
  } catch (const std::exception &) {
    rejected = true;
  }
  assert(rejected);

  addMultipleV(g, 4);
  for (int leaf = 1; leaf <= 3; ++leaf)
    addE(g, Location(0, leaf));

  bool rejectedClaw = false;
  try {
    throwForSmallWidthMultipoles(g);
  } catch (const std::exception &) {
    rejectedClaw = true;
  }
  assert(!rejectedClaw);
}

void test_getColourSet() {
  Graph g(createG());
  addMultipleV(g, 4);
  for (int i = 1; i < 4; i++) {
    addE(g, Location(0, i));
  }

  ColourSet res = {{0, 1, 2}};
  assert(getColourSet(g) == res);

  addV(g, 4);
  addV(g, 5);
  addE(g, Location(1, 2));
  addE(g, Location(2, 3));
  addE(g, Location(1, 4));
  addE(g, Location(3, 5));

  res = {{0, 0}};
  assert(getColourSet(g) == res);
}

void test_setup() {
  Graph g(createG());
  addMultipleV(g, 4);
  for (int i = 1; i < 4; i++) {
    addE(g, Location(0, i));
  }
  auto orderedEdges = setup(g);
  assert(orderedEdges.size() == 3);
  assert(g.order() == 4);
  assert(g[0].degree() == 3);
  for (int i = 1; i < 4; i++) {
    assert(g[i].degree() == 1);
  }
}

void test_getColourSetSelectedVertices() {
  Graph g(createG());
  addMultipleV(g, 4);
  for (int i = 1; i < 4; i++) {
    addE(g, Location(0, i));
  }
  std::vector<Vertex> selectedVertices = {g[1].v(), g[3].v()};

  ColourSet res = {{0, 1}};
  assert(getColourSet(g, selectedVertices) == res);

  addV(g, 4);
  addV(g, 5);
  addE(g, Location(1, 2));
  addE(g, Location(2, 3));
  addE(g, Location(1, 4));
  addE(g, Location(3, 5));

  res = {{0}};
  assert(getColourSet(g, {g[4].v()}) == res);
}

void test_getAllColourings() {
  Graph g(createG());
  addMultipleV(g, 4);
  for (int i = 1; i < 4; i++) {
    addE(g, Location(0, i));
  }

  ColouringBitArray res(power(3, 3), false);
  res.set(ColouringBitArray::Index::to_index(21), true);
  assert(getAllColourings(g) == res);
}

/// Tests permutations() on vertex-id lists {0}, {0, 1}, ..., {0, ..., 7}.
///
/// For each list of length n, the first component of the result must:
///   * contain exactly n! entries,
///   * contain no duplicates, and
///   * consist only of reorderings of the input ids.
/// The last check was previously missing: count and distinctness alone would
/// also accept n! distinct vectors that are not permutations of the input.
void test_permutations() {
  std::vector<int> vertexIds;
  for (int i = 0; i < 8; i++) {
    vertexIds.push_back(i);
    ColouringBitArray cba(power(3, vertexIds.size()), false);
    auto perm = permutations(vertexIds, cba);

    assert(perm.first.size() == factorial(vertexIds.size()));

    std::set<std::vector<int>> permSet(perm.first.begin(), perm.first.end());
    assert(permSet.size() == perm.first.size());

    for (const auto &ordering : perm.first)
      assert(std::is_permutation(ordering.begin(), ordering.end(),
                                 vertexIds.begin(), vertexIds.end()));
  }
}

void test_getCanonicalColourSet() {
  Graph g(createG());
  addMultipleV(g, 10);
  for (int i = 1; i < 4; i++) {
    addE(g, Location(0, i));
  }
  addE(g, Location(1, 2));
  addE(g, Location(2, 3));
  addE(g, Location(1, 4));
  addE(g, Location(3, 5));
  addE(g, Location(4, 6));
  addE(g, Location(4, 7));
  addE(g, Location(5, 8));
  addE(g, Location(5, 9));

  auto terminalsAndColourSet = getCanonicalColourSet(g);
  auto terminals = getTerminals(g);
  std::vector<int> terminalIds;
  for (auto &vertex : terminals) {
    terminalIds.push_back(vertex.to_int());
  }
  auto temp = terminalIds[0];
  terminalIds.erase(terminalIds.begin());
  terminalIds.push_back(temp);

  assert(terminalIds == terminalsAndColourSet.first);
  ColourSet res = {{0, 0, 1, 1}, {0, 1, 0, 1}};
  assert(terminalsAndColourSet.second == res);
}

/// Tests canCreateSnark by pairing a 10-vertex graph that has four degree-1
/// vertices (6, 7, 8, 9) with itself. The second component of the result is
/// expected to be false.
void test_canCreateSnark() {
  Graph g(createG());
  addMultipleV(g, 10);
  for (int leaf = 1; leaf <= 3; ++leaf)
    addE(g, Location(0, leaf));
  const int edgeEnds[][2] = {{1, 2}, {2, 3}, {1, 4}, {3, 5},
                             {4, 6}, {4, 7}, {5, 8}, {5, 9}};
  for (const auto &ends : edgeEnds)
    addE(g, Location(ends[0], ends[1]));

  const bool snarkCreated = canCreateSnark(g, g).second;
  assert(!snarkCreated);
}

int main() {
  test_renumberColouring();

  test_getColouring();

  test_colouringToDecimal();

  test_filterColourSet();

  test_colouringBitArrayToColourSet();

  test_colourSetToBitArray();

  test_getTerminals();

  test_addCycleToTerminals();

  test_removeCycleFromTerminals();

  test_addMultiedges();

  test_removeMultiedges();

  test_minimalColouringBitArray();

  test_getAllColourings();

  test_throwForSmallWidthMultipoles();

  test_setup();

  test_getColourSet();

  test_getColourSetSelectedVertices();

  test_getAllColourings();

  test_permutations();

  test_getCanonicalColourSet();

  test_canCreateSnark();
}
