import networkx as nx
import copy
import os
from ctypes import cdll
import time
from subprocess import DEVNULL, STDOUT, check_call

# Shared C library loaded at import time; provides readGraph6, isColourable and
# testFormula used below. The path is relative, so it is resolved against the
# process's current working directory, not this file's location.
# NOTE(review): no argtypes/restype are declared for these functions anywhere in
# this file, so ctypes treats every return value as a C int -- confirm this matches
# the C signatures (a pointer return would be truncated on 64-bit platforms).
cgraph = cdll.LoadLibrary('../gen_c/jmgraph.so')


def returnSnarks():
    """Return a fresh list of networkx graphs decoded from a fixed table of graph6 strings.

    The table holds one 10-vertex, two 18-vertex and six 20-vertex graphs (the first
    graph6 character encodes the vertex count: 'I' = 10, 'Q' = 18, 'S' = 20).
    NOTE(review): presumably these are all snarks on at most 20 vertices; confirm.
    """
    encodings = (
        b"I?h]@eOWG",
        b"Q?hY@eOGG??B_??@g???T?a??@g",
        b"Q?gY@eOGGC?B_??@g_??DO?O?GW",
        b"S?hW@eOGG??B_?O@g???C?a???wC@??Oc",
        b"S?hW@eOGGC?A_A?@gO??GO???GW?AO?@c",
        b"S?GW@E?GG?GB_AO_g_?CP_??P?G`??O?S",
        b"S?`W@e?GG?GA_A??g?GCP?_O?KHO??`?C",
        b"S?hW@eOGG?GB_A?_G?GC???_??w?O_?AS",
        b"S?gQ@eOOGC?AP??BO@@?GB????o?E???[",
    )
    return [nx.from_graph6_bytes(code) for code in encodings]


def saveGraphContainer(graphContainer, name):
    """Write the edge list of every graph in *graphContainer* to the text file *name*.

    *graphContainer* is an iterable of groups of graphs (e.g. per-size buckets). Each
    graph contributes one line, ``str(graph.edges())`` followed by a newline. An existing
    file is overwritten. The file is closed (and flushed) before returning, even if a
    write fails.
    """
    with open(name, "w") as out:
        for group in graphContainer:
            for gr in group:
                out.write(str(gr.edges()) + "\n")


def returnSimpleCanonic(graph):
    """Return nauty's canonical graph6 line for the simple graph underlying *graph*.

    Only ``graph.edges()`` is used, so parallel edges are merged and isolated vertices are
    dropped. The header-less graph6 text is written to ``../gen_c/nauty/canon_input``,
    ``labelg`` is run on it, and the first line of ``../gen_c/nauty/canon_output`` is
    returned exactly as read (including a trailing newline if present). All paths are
    relative to the current working directory, and both files are overwritten on every call.

    :raises subprocess.CalledProcessError: if ``labelg`` exits with a non-zero status.
    """
    inputPath = '../gen_c/nauty/canon_input'
    outputPath = '../gen_c/nauty/canon_output'

    simple = nx.Graph()
    simple.add_edges_from(graph.edges())
    encoded = nx.to_graph6_bytes(simple, header=False).decode()

    with open(inputPath, 'w') as handle:
        handle.write(encoded)

    # labelg's own console output is discarded; only its exit status is checked.
    check_call(['../gen_c/nauty/labelg', inputPath, outputPath], stdout=DEVNULL, stderr=DEVNULL)

    with open(outputPath, 'r') as handle:
        return handle.readline()


def insertGraph(container, canonicCont, graph):
    """Store *graph* if it is not colourable and not already present up to isomorphism.

    The graph is relabelled to consecutive integers starting at 1. Its canonical form from
    ``returnSimpleCanonic`` is used as the duplicate key. That key is computed on the
    underlying simple graph, so graphs differing only in parallel edges collide.
    *container* and *canonicCont* are both indexed by ``number_of_nodes() // 2 - 1``
    (bucket 0 for 2-3 vertices, bucket 1 for 4-5, ...). A new, non-colourable graph has
    its relabelled copy appended to ``container[bucket]`` and its key appended to
    ``canonicCont[bucket]``. A ``None`` graph is ignored.
    """
    if graph is None:
        return

    bucket = graph.number_of_nodes() // 2 - 1
    relabelledGraph = nx.convert_node_labels_to_integers(graph, 1)
    canonic = returnSimpleCanonic(relabelledGraph)

    # ``and`` short-circuits: the colourability test only runs for graphs not seen before.
    if canonic not in canonicCont[bucket] and not isColorable(relabelledGraph):
        canonicCont[bucket].append(canonic)
        container[bucket].append(relabelledGraph)


def simplifyForCore(graph):
    """Return a MultiGraph copy of *graph*'s edges with every self-loop vertex removed.

    Only vertices that occur in some edge are carried over, and each entry reported by
    ``graph.edges()`` becomes its own edge, so parallel edges are kept. Every vertex of
    *graph* that has a self-loop is deleted from the copy together with all its incident
    edges. *graph* itself is not modified.
    """
    result = nx.MultiGraph()
    result.add_edges_from(graph.edges())
    # Every loop vertex is present in ``result`` because its loop edge was copied above.
    loopVertices = [v for v in graph.nodes if graph.has_edge(v, v)]
    result.remove_nodes_from(loopVertices)
    return result


def insertGraphCores(coreContainer, canonicCoreCont, graph, depth):
    """Recursively collect the edge-minimal non-colourable subgraphs ("cores") of *graph*.

    *graph* is first reduced to a simple graph on its edges. If that graph has no edges or
    is colourable, nothing happens. Otherwise each edge is deleted in turn. Every connected
    component of the result that still has edges and is not colourable is searched
    recursively. The graph is stored via ``insertGraph`` only if no single-edge deletion
    leaves a non-colourable component, i.e. it is minimal under edge deletion. That is
    only known after every edge has been tried.

    NOTE(review): there is no memoisation, so the same subgraph can be explored many
    times; duplicates are only filtered when stored.

    :param coreContainer: per-size buckets of core graphs (filled by ``insertGraph``).
    :param canonicCoreCont: per-size buckets of canonical keys matching *coreContainer*.
    :param graph: graph to search; only ``graph.edges()`` is read.
    :param depth: current recursion depth; passed down incremented but otherwise unused.
    """
    graphCopy = nx.Graph()
    graphCopy.add_edges_from(graph.edges())
    if graphCopy.number_of_edges() == 0:
        return

    if isColorable(graphCopy):
        return

    subGraphsAreColourable = True
    for u, v in graphCopy.edges:
        reduced = nx.Graph()
        reduced.add_edges_from(graph.edges())
        reduced.remove_edge(u, v)

        for component in nx.connected_components(reduced):
            subGr = reduced.subgraph(component).copy()
            if subGr.number_of_edges() > 0 and not isColorable(subGr):
                subGraphsAreColourable = False
                insertGraphCores(coreContainer, canonicCoreCont, subGr, depth + 1)

    # Decided only after the loop: one non-colourable remainder anywhere means *graph*
    # is not minimal and must not be stored as a core.
    if subGraphsAreColourable:
        insertGraph(coreContainer, canonicCoreCont, graphCopy)


def insertCores(coreContainer, canonicCoreCont, graph):
    """Collect the cores of *graph* into *coreContainer* / *canonicCoreCont*.

    Vertices carrying a self-loop are removed first (``simplifyForCore``), then the
    recursive core search (``insertGraphCores``) starts at depth 0.
    """
    insertGraphCores(coreContainer, canonicCoreCont, simplifyForCore(graph), 0)


def addGraphToCoreContainers(coreContainer, canonicCoreCont, container):
    """Run ``insertCores`` on every graph in *container*, printing progress per bucket.

    *container* is a sequence of buckets of graphs. After each bucket, a line reports
    ``2 * k`` (k = 1-based bucket position; presumably the vertex count of that bucket
    under insertGraph's ``number_of_nodes() // 2 - 1`` indexing) and the elapsed wall time.
    """
    print("Adding cores from container")
    for position, bucket in enumerate(container, start=1):
        began = time.time()
        for graph in bucket:
            insertCores(coreContainer, canonicCoreCont, graph)
        elapsed = time.time() - began
        print("Finished core insertion {} --- {} seconds ---".format(2 * position, elapsed))


def cubicCheck(graph):
    """Raise ``ValueError('Not a cubic graph')`` unless every vertex has degree 3.

    Degrees are read from ``graph.degree``; a graph with no vertices passes.
    """
    if any(graph.degree[vertex] != 3 for vertex in graph.nodes):
        raise ValueError('Not a cubic graph')


def isColorable(graph):
    """Return the C library's colourability verdict for *graph*.

    A graph with any self-loop is rejected immediately with ``False``. Otherwise the graph
    is reduced to a simple graph on its edges (parallel edges merged, isolated vertices
    dropped), encoded as header-less graph6, parsed by ``cgraph.readGraph6`` and checked
    by ``cgraph.isColourable``. The C function's raw return value is passed through
    unchanged, so callers should rely on its truthiness rather than on ``is True``.
    """
    if nx.number_of_selfloops(graph) > 0:
        return False

    simple = nx.Graph()
    simple.add_edges_from(graph.edges())
    # Trailing whitespace (the line terminator) is stripped before handing the bytes to C.
    encoded = nx.to_graph6_bytes(simple, header=False).rstrip()

    # NOTE(review): no restype is declared for cgraph.readGraph6 in this file, so ctypes
    # returns it as a C int; if it is really a pointer it may be truncated on 64-bit
    # platforms. The object it returns is also never released here -- confirm whether
    # the C library expects a matching free call.
    handle = cgraph.readGraph6(encoded)
    return cgraph.isColourable(handle)


def addTwoNodesOnEdge(graph, edge):
    """Return a copy of *graph* with *edge* subdivided by two new vertices joined twice.

    For ``edge = (u, v)`` with ``n`` vertices in *graph*: the edge is removed, new vertices
    ``n + 1`` and ``n + 2`` are added, and the edges ``u-(n+1)``, ``v-(n+2)`` and
    ``(n+1)-(n+2)`` are added, the last one twice. On a MultiGraph this yields two parallel
    edges, so both new vertices get degree 3. On a simple Graph the second addition is a
    no-op. The new labels are fresh only if *graph* does not already use ``n + 1`` /
    ``n + 2`` (e.g. labels ``1..n``). *graph* itself is not modified.
    """
    u, v = edge[0], edge[1]
    count = graph.number_of_nodes()
    left, right = count + 1, count + 2

    result = copy.deepcopy(graph)
    result.remove_edge(u, v)
    result.add_nodes_from([left, right])
    result.add_edges_from([(u, left), (v, right), (left, right), (left, right)])
    return result


def addTwoLoopsInsteadEdge(graph, edge):
    """Return a copy of *graph* where *edge* is cut and each end gets a new looped pendant vertex.

    For ``edge = (u, v)`` with ``n`` vertices in *graph*: the edge is removed, new vertices
    ``n + 1`` and ``n + 2`` are attached to ``u`` and ``v`` respectively, and each new
    vertex receives a self-loop. The result is validated with ``cubicCheck`` (raising
    ValueError if it is not cubic). It is returned only if it is connected; otherwise
    None is returned. *graph* itself is not modified.
    """
    u, v = edge[0], edge[1]
    count = graph.number_of_nodes()
    left, right = count + 1, count + 2

    result = copy.deepcopy(graph)
    result.remove_edge(u, v)
    result.add_nodes_from([left, right])
    result.add_edges_from([(u, left), (v, right), (left, left), (right, right)])

    cubicCheck(result)
    return result if nx.is_connected(result) else None


def addTriangle(graph, node):
    """Return a copy of *graph* in which *node* is expanded into a triangle.

    Two new vertices ``num + 1`` and ``num + 2`` (``num`` = current vertex count) are
    joined with *node* into a triangle. The edges originally incident to *node* are then
    redistributed over the three triangle vertices. The redistribution adds edges that may
    already exist and afterwards removes one copy of each original incident edge. It
    therefore relies on parallel edges, so *graph* is presumably a networkx MultiGraph
    with *node* of degree 3 -- confirm. The work is done on a deep copy; *graph* itself
    is not modified.

    Raises:
        ValueError: from ``cubicCheck`` if the resulting graph is not cubic.
    """
    # One entry per incident edge reported by graph.edges(node): the endpoint other than
    # *node*, or *node* itself for a self-loop. Presumably parallel edges repeat a
    # neighbour (MultiGraph behaviour).
    neighborEdges = graph.edges(node)
    neighborNodes = []
    for nbEdge in neighborEdges:
        if (nbEdge[0] == node):
            neighborNodes.append(nbEdge[1])
        else:
            neighborNodes.append(nbEdge[0])
    num = graph.number_of_nodes()
    g1 = copy.deepcopy(graph)
    # Distinct neighbours of *node* (adjacency keys), including *node* if it has a loop.
    nb = list(g1[node])
    numOfNeighbors = len(nb)
    # Build the triangle node - (num+1) - (num+2).
    g1.add_edge(node, num + 1)
    g1.add_edge(node, num + 2)
    g1.add_edge(num + 1, num + 2)

    hasSelfLoop = node in nb

    if (numOfNeighbors == 1):
        # Single distinct neighbour (a triple edge if *node* has degree 3): it keeps one
        # edge to *node* and gains one edge to each new vertex.
        g1.add_edges_from([(neighborNodes[0], node), (neighborNodes[0], num + 1), (neighborNodes[0], num + 2)])
    if (numOfNeighbors == 2):
        if (hasSelfLoop):
            # Loop plus one other edge: the two entries of neighborNodes (the other
            # neighbour and *node* itself, in adjacency order) go to the new vertices.
            g1.add_edges_from([(neighborNodes[0], num + 1), (neighborNodes[1], num + 2)])
        else:
            # Presumably a double edge plus a single edge, so neighborNodes has three
            # entries; handled exactly like the three-neighbour case below.
            g1.add_edges_from([(neighborNodes[0], node), (neighborNodes[1], num + 1), (neighborNodes[2], num + 2)])
    if (numOfNeighbors == 3):
        # One original edge stays on each triangle vertex.
        g1.add_edges_from([(neighborNodes[0], node), (neighborNodes[1], num + 1), (neighborNodes[2], num + 2)])

    # Remove one copy of every original incident edge. Where an edge to *node* was
    # re-added above, exactly one copy remains.
    for n in neighborNodes:
        g1.remove_edge(n, node)

    cubicCheck(g1)
    return g1


def addSquare(graph, edge1, edge2):
    """Return a copy of *graph* with a new 4-cycle spliced into *edge1* and *edge2*.

    New vertices ``n+1 .. n+4`` (``n`` = current vertex count) form the cycle
    ``n+1 - n+2 - n+3 - n+4 - n+1``. Both given edges are removed and their endpoints are
    reattached to the cycle: ``edge1[0]-(n+1)``, ``edge1[1]-(n+4)``, ``edge2[0]-(n+2)`` and
    ``edge2[1]-(n+3)``. *graph* itself is not modified.
    """
    base = graph.number_of_nodes()
    a, b, c, d = (base + k for k in range(1, 5))

    result = copy.deepcopy(graph)
    result.add_edges_from([(a, b), (b, c), (c, d), (a, d)])
    for cut in (edge1, edge2):
        result.remove_edge(cut[0], cut[1])
    result.add_edges_from([(edge1[0], a), (edge1[1], d), (edge2[0], b), (edge2[1], c)])
    return result


def joinGraphsOnEdges(g1, g2, e1, e2):
    """Join *g1* and *g2* by cutting *e1* in g1 and *e2* in g2 and cross-connecting the ends.

    Vertices of g2 are shifted by ``g1.number_of_nodes()``. Presumably g1 is labelled
    ``1..n`` (or ``0..n-1``) so that the shifted labels do not collide -- confirm. After
    both edges are removed, ``e1[0]`` is joined to the shifted ``e2[0]`` and ``e1[1]`` to
    the shifted ``e2[1]``. Returns None if either edge is a self-loop. The inputs are not
    modified.
    """
    if any(edge[0] == edge[1] for edge in (e1, e2)):
        return None

    left = copy.deepcopy(g1)
    right = copy.deepcopy(g2)
    offset = g1.number_of_nodes()
    left.remove_edge(e1[0], e1[1])
    right.remove_edge(e2[0], e2[1])
    shifted = nx.relabel_nodes(right, lambda label: label + offset)

    joined = nx.compose(left, shifted)
    joined.add_edges_from([(e1[0], e2[0] + offset), (e1[1], e2[1] + offset)])
    return joined


def joinGraphsOnNodes(g1, g2, n1, n2):
    """Glue *g1* and *g2* together by deleting vertex *n1* of g1 and *n2* of g2.

    Two configurations are supported:

    * both vertices carry a self-loop: after dropping the loops each has one remaining
      neighbour, and those two neighbours are joined by a new edge;
    * both vertices have three distinct neighbours: the i-th neighbour of *n1* is joined
      to the i-th neighbour of *n2* (adjacency order), for i = 0, 1, 2.

    Vertices of g2 are shifted by ``g1.number_of_nodes()`` before composing, and the result
    is relabelled to consecutive integers starting at 1. The inputs are not modified.

    :returns: the joined graph, or None if either vertex has exactly one distinct neighbour
        or the pair matches neither configuration.
    :raises ValueError: from ``cubicCheck`` if the joined graph is not cubic.
    """
    neighbours1 = list(g1.neighbors(n1))
    neighbours2 = list(g2.neighbors(n2))
    neighbourNum1 = len(neighbours1)
    neighbourNum2 = len(neighbours2)
    hasSelfLoop1 = n1 in neighbours1
    hasSelfLoop2 = n2 in neighbours2
    bothLoops = hasSelfLoop1 and hasSelfLoop2

    # Reject unsupported inputs before paying for the deep copies and relabelling below.
    if neighbourNum1 == 1 or neighbourNum2 == 1:
        return None
    if not (bothLoops or (neighbourNum1 == 3 and neighbourNum2 == 3)):
        return None

    gr1 = copy.deepcopy(g1)
    gr2 = copy.deepcopy(g2)

    nodeOffset = g1.number_of_nodes()
    n2 = n2 + nodeOffset
    gr2aux = nx.relabel_nodes(gr2, lambda x: x + nodeOffset)

    # The pairing below depends on neighbour order, so read it from the copies themselves.
    # relabel_nodes builds a new graph whose adjacency order may differ from g2's.
    nb1 = list(gr1.neighbors(n1))
    nb2 = list(gr2aux.neighbors(n2))

    if hasSelfLoop1:
        gr1.remove_edge(n1, n1)
        nb1.remove(n1)
    if hasSelfLoop2:
        gr2aux.remove_edge(n2, n2)
        nb2.remove(n2)
    gr1.remove_node(n1)
    gr2aux.remove_node(n2)

    newGraph = nx.compose(gr1, gr2aux)
    if bothLoops:
        newGraph.add_edge(nb1[0], nb2[0])
    else:
        newGraph.add_edges_from([(nb1[0], nb2[0]), (nb1[1], nb2[1]), (nb1[2], nb2[2])])

    newGraph = nx.convert_node_labels_to_integers(newGraph, 1)
    cubicCheck(newGraph)
    return newGraph


def test(graph):
    """Debug helper: print whatever ``cgraph.testFormula`` returns for *graph*.

    *graph* is passed straight to ``nx.to_graph6_bytes``; unlike isColorable, there is no
    reduction to a simple graph. The header-less graph6 bytes, with trailing whitespace
    stripped, are parsed by ``cgraph.readGraph6``, and the result is handed to
    ``testFormula``. Returns None.
    """
    encoded = nx.to_graph6_bytes(graph, header=False).rstrip()
    # NOTE(review): no restype is declared for readGraph6/testFormula in this file, so
    # ctypes returns both as C int; if they are pointers, the handle and the printed value
    # may be truncated on 64-bit platforms -- confirm.
    result = cgraph.testFormula(cgraph.readGraph6(encoded))
    print(result)

