#ifndef BA_GRAPH_MULTIPOLE_MROZEK_HPP
#define BA_GRAPH_MULTIPOLE_MROZEK_HPP

#include <algorithm>
#include <cstdint>
#include <set>
#include <stdexcept>
#include <utility>
#include <vector>

#include <snarks/colouring_pd.hpp>

#include <algorithms/path_decomposition/shortest_path_heuristic.hpp>
#include <invariants/connectivity.hpp>
#include <io/print_nice.hpp>
#include <util/math.hpp>

// A colouring of a multipole's dangling edges: one int per dangling edge,
// each a colour 0, 1 or 2 (a ternary digit, as produced by getColouring).
// NOTE(review): entries presumably follow the order of internal::getTerminals.
using Colouring = std::vector<int>;
// A list of colourings; the helpers in this header return it sorted and
// without duplicates (see filterColourSet).
using ColourSet = std::vector<Colouring>;

// Returns x! computed recursively. Every x <= 1 (including 0) ends the
// recursion with 1, so factorial(0) == 1 instead of recursing forever.
// The result overflows a 32-bit int for x > 12.
inline constexpr int factorial(int x) {
  return x <= 1 ? 1 : x * factorial(x - 1);
}

namespace ba_graph {

namespace internal {

// takes a colouring - returns the equivalent lexicographically minimal
// colouring: colours are renamed in order of first appearance, so the first
// colour met becomes 0, the second distinct one 1 and the third 2.
// Values after the third distinct colour are left unchanged.
inline Colouring renumberColouring(Colouring colouring) {
  // seen[j] is the original colour that gets renamed to j (-1 = slot unused);
  // a plain initializer avoids depending on <algorithm> for std::fill_n
  int seen[3] = {-1, -1, -1};
  for (int &colour : colouring) {
    for (int j = 0; j < 3; j++) {
      if (seen[j] == -1)
        seen[j] = colour;
      if (seen[j] == colour) {
        colour = j;
        break;
      }
    }
  }
  return colouring;
}

// takes a colouring code 'colour' (decimal) and the number of entries
// 'length'; returns the colouring whose entry i is the ternary digit of
// 'colour' with weight 3^i (the last entry is the most significant).
// The result is NOT renumbered - renumberColouring gives the minimal form.
// Expects 0 <= colour < 3^length; otherwise the last entry absorbs the
// excess and may exceed 2.
inline Colouring getColouring(int colour, int length) {
  Colouring res(length);
  for (int i = length - 1; i >= 0; i--) {
    const auto weight = power(3, i);
    res[i] = colour / weight;
    colour = colour % weight;
  }
  return res;
}

// takes colouring, returns in decimal
// converts 'colouring' from ternary (last digit has highest value) to decimal,
int colouringToDecimal(Colouring &colouring) {
  uint_fast64_t num = 0;
  for (int i = colouring.size() - 1; i >= 0; i--) {
    num += colouring[i] * power(3, i);
  }
  return num;
}

// takes a vector of colourings
// returns them sorted lexicographically with duplicates removed
inline ColourSet filterColourSet(const ColourSet &colourings) {
  const std::set<Colouring> unique(colourings.begin(), colourings.end());
  return ColourSet(unique.begin(), unique.end());
}

} // namespace internal

// takes ColouringBitArray, returns a ColourSet
// every set bit with index i is decoded (getColouring) into the colouring
// whose code is i and renumbered to its minimal form (renumberColouring); the
// result is sorted and free of duplicates (filterColourSet).
// cba.size() is assumed to be 3^length, length = number of dangling edges.
// NOTE(review): codes are passed to getColouring as int, so this is limited
// to arrays of at most INT_MAX bits.
inline ColourSet colouringBitArrayToColourSet(ColouringBitArray &cba) {
  // recover length = log_3(cba.size()) without truncating the size to int
  int length = 0;
  for (auto size = cba.size().to_int64(); size > 1; size /= 3)
    length++;

  ColourSet res;
  int colour = 0;
  for (auto i = ColouringBitArray::Index(0, 0); i < cba.size(); i++) {
    if (cba.get(i)) {
      Colouring colouring = internal::getColouring(colour, length);
      res.push_back(internal::renumberColouring(colouring));
    }
    colour++;
  }

  return internal::filterColourSet(res);
}

// takes ColourSet and returns an equivalent ColouringBitArray: an array of
// 3^k bits (k = length of the colourings) in which bit
// colouringToDecimal(c) is set for every colouring c of colourSet.
// throws std::runtime_error if colourSet is empty or if its colourings do
// not all have the same length (a longer one would index past the array)
inline ColouringBitArray colourSetToBitArray(ColourSet &colourSet) {
  if (colourSet.empty())
    throw std::runtime_error("Empty ColourSet");
  const auto length = colourSet[0].size();
  auto cba = ColouringBitArray(power(3, length), false);
  for (auto &colouring : colourSet) {
    if (colouring.size() != length)
      throw std::runtime_error("Colourings in ColourSet differ in length");
    uint_fast64_t num = internal::colouringToDecimal(colouring);
    auto i = ColouringBitArray::Index::to_index(num);
    cba.set(i, true);
  }
  // plain return: std::move on a local would only inhibit copy elision
  return cba;
}

namespace internal {

// takes multipole, returns the vertices of degree 1 in the order in which
// g is iterated; in our multipoles they are the ends of the dangling edges
inline std::vector<Vertex> getTerminals(Graph &g) {
  std::vector<Vertex> terminals;
  for (auto &rotation : g) {
    if (rotation.degree() == 1)
      terminals.push_back(rotation.v());
  }
  return terminals;
}

// takes multipole and joins its terminals t0 .. t(k-1) with new edges
// t0-t(k-1), t0-t1, ..., t(k-2)-t(k-1), i.e. into a cycle (a single loop for
// k == 1, two parallel edges for k == 2) - the path decomposition algorithm
// needs this because it expects a cubic graph.
// returns the created edges so removeCycleFromTerminals can delete them;
// if there are no terminals nothing is added and the result is empty
inline std::vector<Edge> addCycleToTerminals(Graph &g) {
  auto terminals = getTerminals(g);
  std::vector<Edge> createdEdges;
  if (terminals.empty())
    return createdEdges; // terminals[0] below would be out of bounds

  const std::size_t length = terminals.size();
  createdEdges.push_back(addE(g, terminals[0], terminals[length - 1]).e());
  if (length == 1)
    return createdEdges;
  for (std::size_t i = 0; i + 1 < length; i++)
    createdEdges.push_back(addE(g, terminals[i], terminals[i + 1]).e());

  return createdEdges;
}

// takes multipole, path decomposition and edges created in addCycleToTerminals
// removes the cycle after the path decomposition algorithm has run: deletes
// the created edges from g and drops the first occurrence of each of them
// from pathDecomposition
inline void removeCycleFromTerminals(Graph &g,
                                     std::vector<Edge> &pathDecomposition,
                                     std::vector<Edge> &createdEdges) {
  // delete cycle from multipole
  for (auto &e : createdEdges)
    deleteE(g, e);
  // delete cycle from ordered edges
  for (auto &e : createdEdges) {
    auto it = std::find(pathDecomposition.begin(), pathDecomposition.end(), e);
    if (it != pathDecomposition.end())
      pathDecomposition.erase(it);
  }
}

// takes multipole and vertices (terminals); for each such vertex v it adds a
// new vertex joined to v by two parallel edges, so a terminal v ends up with
// degree 3 and the new vertex has degree 2 (removeMultiedges deletes them).
// New vertices are numbered g.order(), g.order() + 1, ... - this assumes the
// existing vertices are numbered 0 .. order() - 1, otherwise numbers clash.
// returns the created edges
inline std::vector<Edge> addMultiedges(Graph &g,
                                       std::vector<Vertex> selectedVertices) {
  int j = g.order();
  std::vector<Edge> res;
  res.reserve(2 * selectedVertices.size());
  for (auto v : selectedVertices) {
    auto &r = addV(g, j++);
    res.push_back(addE(g, r.v(), v).e());
    res.push_back(addE(g, r.v(), v).e());
  }
  return res;
}

// takes multipole
// deletes every vertex of degree 2, i.e. the helper vertices created by
// addMultiedges
// NOTE: any other degree-2 vertex of g would be deleted as well
inline void removeMultiedges(Graph &g) {
  std::vector<Vertex> vr;
  for (auto &r : g) {
    if (r.degree() == 2)
      vr.push_back(r.v());
  }
  // collected first so that g is not modified while it is being iterated
  for (auto &v : vr)
    deleteV(g, v);
}

// takes ColouringBitArray, returns ColouringBitArray
// colouringBitArrayToColourSet yields unique renumbered (minimal) colourings,
// so converting them straight back gives the minimal ColouringBitArray.
// throws std::runtime_error (via colourSetToBitArray) if no bit of cba is set
inline ColouringBitArray minimalColouringBitArray(ColouringBitArray &cba) {
  auto colourSet = colouringBitArrayToColourSet(cba);
  return colourSetToBitArray(colourSet);
}

// takes multipole and pathDecomposition
// runs PDColorizer over the edges in pathDecomposition order, then decodes
// every non-empty array of its final state into colourings
// returns the colourSet for the multipole (sorted, without duplicates)
inline ColourSet getColourSet(Graph &g, std::vector<Edge> &pathDecomposition) {
  PDColorizer pathDecompositionColorizer;
  pathDecompositionColorizer.initialize(g);
  for (auto &edge : pathDecomposition)
    pathDecompositionColorizer.process_state(edge.v1(), edge.v2());

  // decode the state arrays in place instead of copying each of them first
  auto &state = pathDecompositionColorizer.state;
  ColourSet res;
  for (std::size_t i = 0; i < state.size(); i++) {
    if (state[i].all_false())
      continue;
    auto x = colouringBitArrayToColourSet(state[i]);
    res.insert(res.end(), x.begin(), x.end());
  }

  return internal::filterColourSet(res);
}

// rejects multipoles the colouring pipeline cannot handle:
// throws std::runtime_error if g has no vertex of degree 1 (no dangling
// edge) or exactly one (a 1-pole, which has no colouring)
inline void throwForSmallWidthMultipoles(Graph &g) {
  auto terminals = g.list(RP::degree(1), RT::n());
  if (terminals.size() == 0)
    throw std::runtime_error("Graph has no semiedge");
  if (terminals.size() == 1)
    throw std::runtime_error("1-poles have no colouring");
}

// needed to use path decomposition algorithms
// temporarily closes the terminals into a cycle (path decomposition right now
// only works with cubic graphs), computes the ordered edges with
// shortest_path_heuristic, removes the cycle again and returns the ordered
// edges without the cycle edges.
// throws std::runtime_error if the multipole has a bridge; whenever an
// exception leaves this function the cycle edges have been deleted again, so
// the caller's multipole is not left with extra edges
inline std::vector<Edge> setup(Graph &g) {
  auto createdEdges = internal::addCycleToTerminals(g);

  auto pathDecomposition = [&] {
    try {
      // works only for connected cubic bridgeless graphs
      if (has_cut_edge(g)) {
        throw std::runtime_error(
            "Current implementation does not work with bridges.");
      }
      // path decomposition is used to get ordered edges that contain all
      // vertices and are part of one path
      // NOTE(review): starts from the first edge of g[0]; assumes it exists
      return shortest_path_heuristic(g, g[0][0].e());
    } catch (...) {
      // restore the multipole before propagating the error
      for (auto &e : createdEdges)
        deleteE(g, e);
      throw;
    }
  }();

  internal::removeCycleFromTerminals(g, pathDecomposition.ordered_edges,
                                     createdEdges);
  // moving a member out of a local is safe and avoids copying the vector
  return std::move(pathDecomposition.ordered_edges);
}

} // namespace internal

// takes multipole and subset of terminals
// returns ColourSet, ignoring not selectedTerminals
// g is modified temporarily (helper vertices with double edges are attached
// to the selected terminals) and the helpers are removed again before
// returning - also when computing the colourings throws
inline ColourSet getColourSet(Graph &g, std::vector<Vertex> selectedTerminals) {
  internal::throwForSmallWidthMultipoles(g);

  auto edges = internal::setup(g);
  // need to add multiedges, because PDColourizer works better
  // with vertices of order 2 and 3
  auto multiedges = internal::addMultiedges(g, std::move(selectedTerminals));
  edges.insert(edges.end(), multiedges.begin(), multiedges.end());
  // end of setup

  ColourSet res;
  try {
    res = internal::getColourSet(g, edges);
  } catch (...) {
    internal::removeMultiedges(g); // cleanup before propagating
    throw;
  }

  // cleanup
  internal::removeMultiedges(g);
  return res;
}

// takes multipole
// returns ColourSet for the multipole, i.e. with all terminals selected
inline ColourSet getColourSet(Graph &g) {
  return getColourSet(g, internal::getTerminals(g));
}

// takes multipole
// returns the minimal ColouringBitArray of all its colourings
// throws std::runtime_error if the multipole has no colouring (empty
// ColourSet) or is rejected by getColourSet
inline ColouringBitArray getAllColourings(Graph &g) {
  auto colourSet = getColourSet(g);
  auto cba = colourSetToBitArray(colourSet);
  return internal::minimalColouringBitArray(cba);
}

// #################### permutations
namespace internal {

// takes ColouringBitArray and the index of the edge to permute
// moves that edge from index to the end (the ColouringBitArray counterpart
// of nextPermutation on terminal ids); cba is updated in place and a copy of
// the new array is returned
inline ColouringBitArray nextPermutation(ColouringBitArray &cba, int index) {
  auto [a, b, c] = cba.split3(index);
  a.concatenate_to_special(b);
  a.concatenate_to_special(c);
  cba = a;
  return cba;
}

// takes terminalIds and the position id to permute
// moves terminalIds[id] to the end, shifting the following ids one place to
// the left (a left rotation of the suffix starting at id); terminalIds is
// updated in place and a copy of the new order is returned
// expects 0 <= id < terminalIds.size()
inline std::vector<int> nextPermutation(std::vector<int> &terminalIds,
                                        int id) {
  std::rotate(terminalIds.begin() + id, terminalIds.begin() + id + 1,
              terminalIds.end());
  return terminalIds;
}

// takes terminalIds, ColouringBitArray, id of element to permute,
// and terminalPermutations and cbaPermutations to store return values
// calculates all permutations for terminals and ColouringBitArrays
// so that at index i, terminalPermutations[i] would have
// ColouringBitArray cbaPermutations[i]
// Positions before id are fixed; at level id the suffix starting at id is
// rotated once per element, so every element visits position id and the
// suffix is back in its original order afterwards.
// An empty terminalIds produces no permutation at all.
inline void
permutationsRecursive(std::vector<int> &terminalIds, ColouringBitArray &cba,
                      int id, std::vector<std::vector<int>> &terminalPermutations,
                      std::vector<ColouringBitArray> &cbaPermutations) {
  // signed size avoids mixing int and size_t (and size() - 1 wrapping on 0)
  const int n = static_cast<int>(terminalIds.size());
  if (id == n - 1) {
    // all positions fixed: record this permutation
    terminalPermutations.emplace_back(terminalIds);
    nextPermutation(terminalIds, id);
    cbaPermutations.emplace_back(cba);
    nextPermutation(cba, id);
  } else {
    for (int i = 0; i < n - id; i++) {
      permutationsRecursive(terminalIds, cba, id + 1, terminalPermutations,
                            cbaPermutations);
      nextPermutation(terminalIds, id);
      nextPermutation(cba, id);
    }
  }
}

// takes terminalIds and ColouringBitArray
// returns all pair of all terminalIds and cba permutations,
// so that at index i, terminalPermutations[i] would have
// ColouringBitArray cbaPermutations[i]
std::pair<std::vector<std::vector<int>>, std::vector<ColouringBitArray>>
permutations(std::vector<int> terminalIds, ColouringBitArray cba) {
  std::vector<ColouringBitArray> cbaPermutations;
  std::vector<std::vector<int>> terminalPermutations;
  permutationsRecursive(terminalIds, cba, 0, terminalPermutations,
                        cbaPermutations);
  return std::pair(terminalPermutations, cbaPermutations);
}

} // namespace internal

// takes terminalIds and ColouringBitArray
// returns a pair - minimal ColourSet from all permutations and
// terminalIds in corresponding order
std::pair<std::vector<int>, ColourSet>
getCanonicalColourSet(std::vector<int> &terminalIds, ColouringBitArray cba) {
  cba = internal::minimalColouringBitArray(cba);
  auto permutations = internal::permutations(terminalIds, cba);

  std::vector<ColouringBitArray> bitArrayVector;
  for (auto &c : permutations.second) {
    bitArrayVector.emplace_back(c);
  }
  sort(bitArrayVector.begin(), bitArrayVector.end());

  if (bitArrayVector.size() == 0) {
    throw std::runtime_error("No colouring.");
  }
  auto &res = bitArrayVector[0];
  auto it =
      std::find(permutations.second.begin(), permutations.second.end(), res);
  int id = it - permutations.second.begin();

  ColourSet canonicalColourSet = colouringBitArrayToColourSet(res);
  return std::pair(permutations.first[id], canonicalColourSet);
}

// takes terminalIds and multipole
// returns a pair - minimal ColourSet from all permutations and
// terminalIds in corresponding order
// NOTE(review): terminalIds presumably has to list the terminals in the
// order internal::getTerminals returns them
// throws std::runtime_error("No colouring.") if the multipole has no colouring
inline std::pair<std::vector<int>, ColourSet>
getCanonicalColourSet(std::vector<int> &terminalIds, Graph &g) {
  ColouringBitArray cba = getAllColourings(g);
  if (cba.all_false()) {
    throw std::runtime_error("No colouring.");
  }

  return getCanonicalColourSet(terminalIds, cba);
}

// takes a multipole
// returns a pair - minimal ColourSet from all permutations and the ids
// (Vertex::to_int) of the terminals in corresponding order
// throws std::runtime_error("No colouring.") if the multipole has no colouring
inline std::pair<std::vector<int>, ColourSet> getCanonicalColourSet(Graph &g) {
  ColouringBitArray cba = getAllColourings(g);
  if (cba.all_false()) {
    throw std::runtime_error("No colouring.");
  }

  // terminals are listed after g has been restored by getAllColourings
  std::vector<int> terminalIds;
  for (auto &vertex : internal::getTerminals(g))
    terminalIds.push_back(vertex.to_int());

  return getCanonicalColourSet(terminalIds, cba);
}

} // namespace ba_graph
#endif
