from labeled_graph import LabeledGraph
from test_graph import test_addvertex, test_addedge, test_removevertex, test_removeedge

def assertkeyerror(fn, *args):
    """Assert that calling ``fn(*args)`` raises ``KeyError``.

    Returns normally when ``KeyError`` is raised. Any other exception raised
    by ``fn`` propagates unchanged.

    Raises:
        AssertionError: if ``fn(*args)`` returns without raising. This is
            raised explicitly rather than via ``assert False`` so the check
            still works when Python runs with ``-O``, which strips ``assert``
            statements.
    """
    try:
        fn(*args)
    except KeyError:
        return
    raise AssertionError(
        "{!r} did not raise KeyError for args {!r}".format(fn, args))

def assertkeyerror2(fn, arg):
    """Assert that subscripting ``fn[arg]`` raises ``KeyError``.

    Despite its name, ``fn`` is the object being indexed (e.g. a graph), not
    a callable. Returns normally when ``KeyError`` is raised. Any other
    exception propagates unchanged.

    Raises:
        AssertionError: if ``fn[arg]`` succeeds. This is raised explicitly
            rather than via ``assert False`` so the check still works under
            ``python -O``.
    """
    try:
        fn[arg]
    except KeyError:
        return
    raise AssertionError("{!r}[{!r}] did not raise KeyError".format(fn, arg))

def test_vertexlabels():
    """Exercise the vertex-label API of LabeledGraph.

    Checks labels that were never set (on an existing vertex and on a vertex
    absent from the graph) and setting a label on a missing vertex. It also
    checks that setting and removing labels produce graphs whose earlier
    versions still report their old labels.
    """
    unlabeled = LabeledGraph().addvertex(1).addvertex(2).addvertex(3)

    # Vertex 1 exists but carries no label; vertex 4 is not in the graph.
    for vertex in (1, 4):
        assert unlabeled.removevertexlabel(vertex) == unlabeled
    for vertex in (1, 4):
        assert not unlabeled.hasvertexlabel(vertex)
    for vertex in (1, 4):
        assertkeyerror(unlabeled.getvertexlabel, vertex)
    # Labelling a vertex that is not in the graph must be rejected.
    assertkeyerror(unlabeled.setvertexlabel, 4, None)

    first_label = object()
    labeled = unlabeled.setvertexlabel(1, first_label)
    # The graph we started from must still lack the label.
    assertkeyerror(unlabeled.getvertexlabel, 1)
    assert labeled.getvertexlabel(1) == first_label
    assert labeled.hasvertexlabel(1)
    assertkeyerror(labeled.getvertexlabel, 2)

    second_label = object()
    relabeled = labeled.setvertexlabel(1, second_label)
    assert labeled.getvertexlabel(1) == first_label
    assert relabeled.getvertexlabel(1) == second_label
    # Setting the label it already has compares equal to the current graph.
    assert relabeled.setvertexlabel(1, second_label) == relabeled

    stripped = relabeled.removevertexlabel(1)
    assertkeyerror(stripped.getvertexlabel, 1)
    
def test_edgelabels():
    """Exercise the edge-label API of LabeledGraph.

    Checks labels that were never set (on an existing edge and on endpoints
    absent from the graph) and setting a label on a missing edge. It also
    checks that edge labels are visible from either endpoint order, and that
    earlier graph versions keep their old labels.
    """
    graph = LabeledGraph()
    for vertex in (1, 2, 3):
        graph = graph.addvertex(vertex)
    for u, v in ((1, 2), (2, 3)):
        graph = graph.addedge(u, v)

    # (1, 2) is a real but unlabeled edge; the other pairs involve vertices
    # that are not in the graph.
    for u, v in ((1, 2), (4, 3)):
        assert graph.removeedgelabel(u, v) == graph
    for u, v in ((1, 2), (4, 5)):
        assert not graph.hasedgelabel(u, v)
    for u, v in ((1, 2), (4, 3)):
        assertkeyerror(graph.getedgelabel, u, v)
    assertkeyerror(graph.setedgelabel, 4, 3, None)

    first_label = object()
    labeled = graph.setedgelabel(1, 2, first_label)
    # The graph we started from must still lack the label.
    assertkeyerror(graph.getedgelabel, 1, 2)
    # The label must be reachable regardless of endpoint order.
    for u, v in ((1, 2), (2, 1)):
        assert labeled.getedgelabel(u, v) == first_label
    for u, v in ((1, 2), (2, 1)):
        assert labeled.hasedgelabel(u, v)
    assertkeyerror(labeled.getedgelabel, 2, 3)

    second_label = object()
    relabeled = labeled.setedgelabel(1, 2, second_label)
    assert labeled.getedgelabel(1, 2) == first_label
    assert relabeled.getedgelabel(1, 2) == second_label
    assert relabeled.setedgelabel(1, 2, second_label) == relabeled

    stripped = relabeled.removeedgelabel(1, 2)
    assertkeyerror(stripped.getedgelabel, 1, 2)
    
def test_deletinggraphandlabels():
    """Check how removing edges and vertices affects stored labels.

    Removing an edge must drop that edge's label but keep the others.
    Removing a vertex must drop its own label and the labels of every edge
    incident to it, while other vertex labels survive.
    """
    graph = LabeledGraph()
    for vertex in (1, 2, 3):
        graph = graph.addvertex(vertex)
    for u, v in ((1, 2), (2, 3)):
        graph = graph.addedge(u, v)
    for vertex, label in ((1, "1"), (2, "2")):
        graph = graph.setvertexlabel(vertex, label)
    for u, v, label in ((1, 2, "12"), (2, 3, "23")):
        graph = graph.setedgelabel(u, v, label)

    without_edge = graph.removeedge(1, 2)
    assert without_edge.getedgelabel(2, 3) == "23"
    assert without_edge.getvertexlabel(2) == "2"
    assertkeyerror(without_edge.getedgelabel, 1, 2)

    without_vertex = graph.removevertex(2)
    # Both edges touched vertex 2, so both labels must be gone.
    for u, v in ((1, 2), (2, 3)):
        assertkeyerror(without_vertex.getedgelabel, u, v)
    assertkeyerror(without_vertex.getvertexlabel, 2)
    assert without_vertex.getvertexlabel(1) == "1"


def test_getitem():
    """Check subscript access on LabeledGraph.

    ``g[{v}]`` yields a vertex label. ``g[{u, v}]`` and ``g[(u, v)]`` yield
    an edge label. Every other key shape below must raise KeyError.
    """
    g1 = LabeledGraph().addvertex(1).addvertex(2).addvertex(3).addedge(1, 2).addedge(2, 3) \
                       .setvertexlabel(1, "1").setvertexlabel(2, "2") \
                       .setedgelabel(1, 2, "12").setedgelabel(2, 3, "23")
    assert(g1[{1}] == "1")
    assert(g1[{1, 2}] == "12")
    assert(g1[(2, 1)] == "12")
    assertkeyerror2(g1, 1)          # bare vertex, not wrapped in a set
    assertkeyerror2(g1, set())      # empty vertex set
    assertkeyerror2(g1, {})         # NB: `{}` is an empty *dict*, not a set
    assertkeyerror2(g1, {4})        # vertex not in the graph
    assertkeyerror2(g1, {4, 5})     # edge between vertices not in the graph
    assertkeyerror2(g1, {1, 3})     # both vertices exist, but no edge joins them
    assertkeyerror2(g1, {1, 3, 4})  # more than two vertices
    assertkeyerror2(g1, [1, 3, 4])  # a list key is rejected too
    
 

if __name__ == "__main__":
    # Generic graph tests from test_graph take the graph class to exercise.
    for generic_test in (test_addvertex, test_addedge,
                         test_removevertex, test_removeedge):
        generic_test(LabeledGraph)

    # Label-specific tests defined in this module.
    for label_test in (test_vertexlabels, test_edgelabels,
                       test_deletinggraphandlabels, test_getitem):
        label_test()

    print("Tests complete.")