import sys
from itertools import combinations, groupby, chain

from solvatore import Solvatore
from cipher_description import CipherDescription
from cipher_description_blue import CipherDescriptionBlue2
from ciphers import RES_with_key, RES_blue_manual
from solvatore_blue2 import SolvatoreBlue2

# Number of state bits enumerated by the distinguisher search below.
statesize = 4

# Parameters forwarded to RES_with_key.generate_RES.
ls = 2
ssum = 6

# `imrounds` is passed to the blue cipher description; the solver is run on
# `normrounds` rounds, derived from `rounds` and `imrounds`.
rounds = 7
imrounds = 6
normrounds = rounds - imrounds + 1

# NOTE(review): the literal 4 presumably mirrors `statesize` above — confirm
# before changing either value independently.
cipher = RES_with_key.generate_RES(4, ls, ssum)
imcipher = CipherDescriptionBlue2(cipher, rounds=imrounds, anf=True)

def getSetOfBalancedBits(solver, constant_bits, bits_to_test, *, rounds=None, size=None):
    """Return the bits of ``bits_to_test`` that the solver reports as balanced.

    Every state bit in ``range(size)`` that is not listed in ``constant_bits``
    is treated as active, and each candidate bit is checked with
    ``solver.is_bit_balanced(bit, rounds, active_bits)``.

    Args:
        solver: object exposing ``is_bit_balanced(bit, rounds, active_bits)``.
        constant_bits: iterable of state-bit indices held constant.
        bits_to_test: iterable of bit indices to check, in the order given.
        rounds: round count handed to the solver; defaults to the
            module-level ``normrounds`` (looked up at call time).
        size: number of state bits; defaults to the module-level
            ``statesize`` (looked up at call time).

    Returns:
        list of the bits from ``bits_to_test``, in iteration order, for which
        ``solver.is_bit_balanced`` returned a truthy value.
    """
    if rounds is None:
        rounds = normrounds
    if size is None:
        size = statesize
    # Materialize once: O(1) membership, and one-shot iterables are safe.
    constant = set(constant_bits)
    active_bits = [i for i in range(size) if i not in constant]
    return [
        bit
        for bit in bits_to_test
        if solver.is_bit_balanced(bit, rounds, active_bits)
    ]

# Exhaustive first pass: split the state bits into active, blue0 and blue1
# bits. `range(1, statesize)` means at least one bit is blue and at least one
# bit stays active; every split of the chosen blue bits into blue0/blue1 is
# tried. Configurations with at least one balanced bit are recorded.
good_indices = []
for blue_bits in chain.from_iterable(combinations(range(statesize), i) for i in range(1, statesize)):
    for b0 in chain.from_iterable(combinations(blue_bits, i) for i in range(0, len(blue_bits) + 1)):
        blue0 = list(b0)
        blue1 = sorted(set(blue_bits) - set(blue0))
        active = sorted(set(range(statesize)) - set(blue_bits))
        # Every non-active bit is held constant.
        constants = sorted(set(range(statesize)) - set(active))
        solver = SolvatoreBlue2()
        imcipher.set_blue_bits(blue0, blue1)
        solver.load_simplified_cipher(cipher, imcipher)
        solver.set_rounds(normrounds)
        B = getSetOfBalancedBits(solver, constants, range(statesize))
        if B:
            good_indices.append({"blue0": blue0, "blue1": blue1, "active": active, "balanced": B})
if not good_indices:
    print("No distinguisher exists.")
    # sys.exit works regardless of whether the `site` module injected exit().
    sys.exit(1)
print(good_indices)

# Check all combinations of good indices and reduce.
# Each round pairs up distinguishers that share balanced bits, intersects
# their active/blue0/blue1 sets componentwise, and re-checks the shared
# balanced bits with the solver. A combined configuration is a subset of both
# parents, so configurations shrink until none retains a balanced bit.
while True:
    # key: (active, blue0, blue1) tuples; value: candidate balanced bits,
    # i.e. the union of the balanced-bit intersections of every pair that
    # produced this key.
    combination_indices = {}
    for first, second in combinations(good_indices, 2):
        intersection = set(first["balanced"]) & set(second["balanced"])
        if not intersection:
            continue
        new_active = tuple(sorted(set(first["active"]) & set(second["active"])))
        new_blue0 = tuple(sorted(set(first["blue0"]) & set(second["blue0"])))
        new_blue1 = tuple(sorted(set(first["blue1"]) & set(second["blue1"])))
        newkey = (new_active, new_blue0, new_blue1)
        # Merge candidates from every pair mapping to this key instead of
        # keeping only the first one, so no candidate bit is silently dropped
        # and the result does not depend on pair iteration order.
        combination_indices.setdefault(newkey, set()).update(intersection)

    print("Found {} distinguishers, searching {} combinations:".format(len(good_indices), len(combination_indices)))
    good_indices = []
    # Search through all combinations; dict keys are unique, so each
    # configuration is checked (and recorded) at most once.
    for combo, balanced in combination_indices.items():
        active, blue0, blue1 = combo
        constants = sorted(set(range(statesize)) - set(active))

        # loading cipher and imcipher into solver
        solver = SolvatoreBlue2()
        imcipher.set_blue_bits(blue0, blue1)
        solver.load_simplified_cipher(cipher, imcipher)
        solver.set_rounds(normrounds)
        # sorted() gives a deterministic test/report order for the bits.
        B = getSetOfBalancedBits(solver, constants, sorted(balanced))
        if B:
            good_indices.append({"blue0": blue0, "blue1": blue1, "active": active, "balanced": B})
    print(good_indices)
    if not good_indices:
        print("Finished Search.")
        break
