from pyrsistent import pmap, pset, freeze
from graph import Graph
from types import MethodType

class LabeledGraph:
    """Immutable graph whose vertices and edges can carry labels.

    Structure is delegated to a ``Graph`` instance.  Labels live in a
    pyrsistent ``pmap`` keyed by ``pset({v})`` for the label of vertex ``v``
    and by ``pset({v1, v2})`` for the label of edge ``v1``-``v2``.  Edge
    labels therefore do not depend on the order of the endpoints.

    Every "mutating" method returns a LabeledGraph and leaves the receiver
    untouched.  It returns ``self`` when nothing changed.

    NOTE(review): ``Graph`` is assumed to be persistent too.  Its
    ``add*``/``remove*`` methods are used here as if they return new Graph
    objects; confirm against the ``graph`` module.

    NOTE: for a self-loop ``pset({v, v}) == pset({v})``, so the loop's edge
    label and the vertex label of ``v`` share one key.
    """

    def __init__(self, _graph=None, _labels=pmap()):
        # Use ``None`` rather than a ``Graph()`` default.  A default
        # expression is evaluated once, at class-definition time, and shared
        # by every instance, which is only safe if Graph is truly immutable.
        # ``pmap()`` is immutable, so sharing that default is harmless.
        self._graph = Graph() if _graph is None else _graph
        self._labels = _labels

    def _maybemakenew(self, graph, labels):
        """Return ``self`` if nothing changed, otherwise a new LabeledGraph.

        Labels are compared by identity.  pyrsistent returns the very same
        map for no-op operations (``discard`` of a missing key, ``set`` of an
        identical value, an evolver with no changes).  Comparing with ``==``
        instead would compare label values, treating e.g. ``1`` and ``True``
        as "unchanged" (silently keeping the old label).  It would also
        raise for values such as numpy arrays, whose ``==`` is not a plain
        bool.
        """
        if labels is self._labels and (graph is self._graph or graph == self._graph):
            return self
        return LabeledGraph(graph, labels)

    def hasvertex(self, v):
        """Return whether ``v`` is a vertex of the underlying graph."""
        return self._graph.hasvertex(v)

    def hasedge(self, v1, v2):
        """Return whether the underlying graph has the edge ``v1``-``v2``."""
        return self._graph.hasedge(v1, v2)

    def vertices(self):
        """Return the underlying graph's vertices."""
        return self._graph.vertices()

    def neighbours(self, v):
        """Return the underlying graph's neighbours of ``v``."""
        return self._graph.neighbours(v)

    def addvertex(self, v):
        """Return a graph that also contains vertex ``v`` (labels unchanged)."""
        return self._maybemakenew(self._graph.addvertex(v), self._labels)

    def addedge(self, v1, v2):
        """Return a graph that also contains edge ``v1``-``v2`` (labels unchanged)."""
        return self._maybemakenew(self._graph.addedge(v1, v2), self._labels)

    def removeedge(self, v1, v2):
        """Return a graph without edge ``v1``-``v2`` and without its label.

        Returns ``self`` if the edge does not exist.
        """
        if not self._graph.hasedge(v1, v2):
            return self
        return LabeledGraph(self._graph.removeedge(v1, v2),
                            self._labels.discard(pset({v1, v2})))

    def removevertex(self, v):
        """Return a graph without vertex ``v``, its label and its edges' labels.

        The edge labels dropped are those of every neighbour reported by the
        underlying graph before ``v`` is removed.  Returns ``self`` if ``v``
        is not a vertex.
        """
        if not self._graph.hasvertex(v):
            return self
        evo = self._labels.evolver()
        stale = [pset({v})] + [pset({v, w}) for w in self._graph.neighbours(v)]
        for key in stale:
            # The membership test is needed: the evolver's ``remove`` raises
            # KeyError for absent keys, and a self-loop repeats ``pset({v})``.
            if key in evo:
                evo.remove(key)
        return LabeledGraph(self._graph.removevertex(v), evo.persistent())

    def setvertexlabel(self, v, label):
        """Return a graph where vertex ``v`` is labelled ``label``.

        Raises:
            KeyError: if ``v`` is not a vertex of the graph.
        """
        if not self.hasvertex(v):
            raise KeyError(f"Graph does not have vertex {v}.")
        return self._maybemakenew(self._graph, self._labels.set(pset({v}), label))

    def getvertexlabel(self, v):
        """Return the label of vertex ``v``; KeyError if it has none."""
        return self._labels[pset({v})]

    def removevertexlabel(self, v):
        """Return a graph without the label of ``v`` (no error if absent)."""
        return self._maybemakenew(self._graph, self._labels.discard(pset({v})))

    def hasvertexlabel(self, v):
        """Return whether vertex ``v`` has a label."""
        return pset({v}) in self._labels

    def setedgelabel(self, v1, v2, label):
        """Return a graph where edge ``v1``-``v2`` is labelled ``label``.

        Raises:
            KeyError: if the edge is not in the graph.
        """
        if not self.hasedge(v1, v2):
            raise KeyError(f"Graph does not have edge {v1}, {v2}.")
        return self._maybemakenew(self._graph, self._labels.set(pset({v1, v2}), label))

    def getedgelabel(self, v1, v2):
        """Return the label of edge ``v1``-``v2``; KeyError if it has none."""
        return self._labels[pset({v1, v2})]

    def hasedgelabel(self, v1, v2):
        """Return whether edge ``v1``-``v2`` has a label."""
        return pset({v1, v2}) in self._labels

    def removeedgelabel(self, v1, v2):
        """Return a graph without the label of edge ``v1``-``v2`` (no error if absent)."""
        return self._maybemakenew(self._graph, self._labels.discard(pset({v1, v2})))

    def __getitem__(self, items):
        """Look up a label by an iterable of vertices.

        ``lg[(v,)]`` returns the label of vertex ``v``.  ``lg[(v1, v2)]``
        returns the label of edge ``v1``-``v2``.  ``items`` is converted with
        ``pset``, so order and duplicates do not matter.  A string counts as
        the iterable of its characters.

        Raises:
            KeyError: if ``items`` is not iterable, or no label is stored
                under the resulting key.
        """
        try:
            iter(items)
        except TypeError:
            raise KeyError(items) from None
        return self._labels[pset(items)]

    def transform(self, *args):
        """Apply ``pmap.transform`` to the labels and return the result.

        ``args`` alternate path, command, path, command, ...  The paths
        (even-indexed arguments) are passed through ``freeze``, so plain
        containers inside them (e.g. a ``set`` of vertices) become the
        hashable pyrsistent keys used here.  Commands are passed unchanged.
        """
        frozen = [freeze(arg) if i % 2 == 0 else arg for i, arg in enumerate(args)]
        return self._maybemakenew(self._graph, self._labels.transform(*frozen))

    class _Evolver:
        """Mutable helper for batching label changes on a LabeledGraph.

        Wraps the labels map's pyrsistent evolver; the graph structure
        cannot be changed through it.  Keys are handed to the underlying
        evolver unchanged, so use ``pset({v})`` / ``pset({v1, v2})`` as
        LabeledGraph does.  Unlike ``setvertexlabel`` / ``setedgelabel``, no
        check is made that the vertex or edge exists.
        """

        def __init__(self, lgraph):
            self._lgraph = lgraph
            self._evo = lgraph._labels.evolver()

        def persistent(self):
            """Return a LabeledGraph carrying the accumulated label changes."""
            return self._lgraph._maybemakenew(self._lgraph._graph, self._evo.persistent())

        def set(self, key, val):
            """Set ``key`` to ``val`` and return this wrapper so calls chain."""
            # The inner evolver's ``set`` returns the *inner* evolver; returning
            # it would make ``.persistent()`` yield a bare pmap.
            self._evo.set(key, val)
            return self

        def remove(self, key):
            """Remove ``key`` (KeyError if absent); return this wrapper so calls chain."""
            self._evo.remove(key)
            return self

        # Other non-magic evolver methods (e.g. ``is_dirty``) are forwarded.
        # Magic methods are looked up on the type and bypass ``__getattr__``,
        # so the ones needed are forwarded explicitly below.
        def __getattr__(self, attr):
            # Without this guard, a missing ``_evo`` (e.g. an instance created
            # by ``copy`` without running __init__) would recurse forever.
            if attr in ("_lgraph", "_evo"):
                raise AttributeError(attr)
            return getattr(self._evo, attr)

        def __getitem__(self, idx):
            return self._evo[idx]

        def __setitem__(self, idx, value):
            self._evo[idx] = value

        def __delitem__(self, idx):
            del self._evo[idx]

        def __contains__(self, idx):
            return idx in self._evo

        def __len__(self):
            return len(self._evo)

    def evolver(self):
        """Return an evolver for batching label updates.

        Call ``persistent()`` on it to obtain the resulting LabeledGraph.
        """
        return self._Evolver(self)
        
    