from itertools import chain, combinations

from cipher_description_blue import CipherDescriptionBlue2
from ciphers import RES_with_key
from solvatore_blue2 import SolvatoreBlue2


def getSetOfBalancedBits(solver, normrounds, constant_bits, bits_to_test, state_size=None):
    """Return the bits of ``bits_to_test`` that ``solver`` reports as balanced.

    Every state bit that is not listed in ``constant_bits`` is handed to the
    solver as an active bit.

    Args:
        solver: object exposing ``is_bit_balanced(bit, rounds, active_bits)``
            (in this script a configured ``SolvatoreBlue2`` instance).
        normrounds: round count forwarded unchanged to ``is_bit_balanced``.
        constant_bits: iterable of state-bit indices held constant.
        bits_to_test: iterable of state-bit indices to check.
        state_size: number of state bits. Defaults to the module-level
            ``statesize`` so existing callers keep working unchanged.

    Returns:
        List of the balanced bit indices, in ``bits_to_test`` order.
    """
    if state_size is None:
        # Backward-compatible fallback to the script-level global.
        state_size = statesize
    # Materialize once: O(1) membership, and a generator argument is not
    # consumed by the first ``in`` test.
    constant = set(constant_bits)
    active_bits = [i for i in range(state_size) if i not in constant]
    return [i for i in bits_to_test if solver.is_bit_balanced(i, normrounds, active_bits)]

# (ls, ssum, rounds) -> smallest number of improved rounds for which at least
# one blue-bit configuration yielded a balanced bit.
roundsneeded = {}
# (ls, ssum, rounds) -> list of the configurations (blue0/blue1/active/balanced
# dicts) found for that smallest number of improved rounds.
distfound = {}
# Number of state bits that are enumerated and tested; also read as a global
# by getSetOfBalancedBits.
statesize = 4

# Cache of CipherDescriptionBlue2 objects keyed by (ls, ssum, imrounds), so the
# ANF description is built only once per parameter set.
memimcipher = {}
# All (ls, ssum, rounds) configurations for which a distinguisher exists.
# ls and ssum are passed to RES_with_key.generate_RES; rounds is the total
# round count to cover.
disttofind = [(1, 2, 8), (1, 3, 5), (1, 3, 6), (1, 5, 5), (1, 5, 6), (1, 6, 8), (1, 10, 8), (1, 11, 5), (1, 11, 6), (1, 13, 5), (1, 13, 6), (1, 14, 8), (2, 1, 4), (2, 1, 5), (2, 1, 6), (2, 1, 7), (2, 1, 8), (2, 2, 6), (2, 2, 7), (2, 3, 4), (2, 3, 5), (2, 3, 6), (2, 5, 4), (2, 5, 5), (2, 5, 6), (2, 6, 6), (2, 6, 7), (2, 7, 4), (2, 7, 5), (2, 7, 6), (2, 7, 7), (2, 7, 8), (2, 9, 4), (2, 9, 5), (2, 9, 6), (2, 9, 7), (2, 9, 8), (2, 10, 6), (2, 10, 7), (2, 11, 4), (2, 11, 5), (2, 11, 6), (2, 13, 4), (2, 13, 5), (2, 13, 6), (2, 14, 6), (2, 14, 7), (2, 15, 4), (2, 15, 5), (2, 15, 6), (2, 15, 7), (2, 15, 8), (3, 2, 8), (3, 2, 9), (3, 2, 10), (3, 2, 11), (3, 3, 5), (3, 5, 5), (3, 6, 8), (3, 6, 9), (3, 6, 10), (3, 6, 11), (3, 10, 8), (3, 10, 9), (3, 10, 10), (3, 10, 11), (3, 11, 5), (3, 13, 5), (3, 14, 8), (3, 14, 9), (3, 14, 10), (3, 14, 11)]
# Additional configurations with ls == 2 and larger round counts.
disttofind += [(2, 1, 9), (2, 1, 10), (2, 1, 11), (2, 1, 12), (2, 1, 13), (2, 7, 9), (2, 7, 10), (2, 7, 11), (2, 7, 12), (2, 7, 13), (2, 9, 9), (2, 9, 10), (2, 9, 11), (2, 9, 12), (2, 9, 13), (2, 15, 9), (2, 15, 10), (2, 15, 11), (2, 15, 12), (2, 15, 13)]
def _blue_configurations(n):
    """Yield every (blue0, blue1, active, constants) split of state bits 0..n-1.

    ``blue_bits`` ranges over the non-empty proper subsets of the state bits,
    so at least one bit is always left active. ``blue0`` ranges over every
    subset of ``blue_bits`` (empty and full included), and ``blue1`` is the
    remainder. ``constants`` is the union of blue0 and blue1. The yield order
    matches the original nested ``chain.from_iterable`` enumeration.
    """
    all_bits = set(range(n))
    for size in range(1, n):
        for blue_bits in combinations(range(n), size):
            blue_set = set(blue_bits)
            active = sorted(all_bits - blue_set)
            constants = sorted(all_bits - set(active))
            for k in range(len(blue_bits) + 1):
                for b0 in combinations(blue_bits, k):
                    blue0 = list(b0)
                    blue1 = sorted(blue_set - set(blue0))
                    # Fresh lists per configuration: they are stored in good_indices.
                    yield blue0, blue1, list(active), list(constants)


for ls, ssum, rounds in disttofind:
    print("Trying distinguisher {}".format((ls, ssum, rounds)))
    # NOTE(review): the leading 4 presumably is the state size; if so it
    # should track `statesize` -- TODO confirm against RES_with_key.
    cipher = RES_with_key.generate_RES(4, ls, ssum)
    # Increase the number of rounds described by the improved (ANF) cipher
    # description until some blue-bit configuration yields a balanced bit.
    for imrounds in range(1, rounds):
        print("Trying {} improved rounds".format(imrounds))
        key = (ls, ssum, imrounds)
        if key not in memimcipher:
            memimcipher[key] = CipherDescriptionBlue2(cipher, rounds=imrounds, anf=True)
        imcipher = memimcipher[key]
        # Rounds left to the solver; the +1 presumably reflects an overlap
        # between the improved and the normal part -- TODO confirm.
        normrounds = rounds - imrounds + 1

        # Check every split of the state into active, blue0 and blue1 bits.
        good_indices = []
        for blue0, blue1, active, constants in _blue_configurations(statesize):
            solver = SolvatoreBlue2()
            # The cached imcipher is shared, so its blue bits are reset on every use.
            imcipher.set_blue_bits(blue0, blue1)
            solver.load_simplified_cipher(cipher, imcipher)
            solver.set_rounds(normrounds)
            B = getSetOfBalancedBits(solver, normrounds, constants, range(statesize))
            if B:
                good_indices.append({"blue0": blue0, "blue1": blue1, "active": active, "balanced": B})

        if not good_indices:
            # Only this improved-round count failed; larger counts are still tried.
            print("No distinguisher found with {} improved rounds.".format(imrounds))
            continue
        roundsneeded[(ls, ssum, rounds)] = imrounds
        distfound[(ls, ssum, rounds)] = good_indices
        print(good_indices)
        break
    else:
        # Reached only when no improved-round count produced a distinguisher.
        print("No distinguisher found for {}".format((ls, ssum, rounds)))
print(roundsneeded)
print(distfound)