import sys
from itertools import chain, combinations
from pycryptosat import Solver
from solvatore import Solvatore

class SolvatoreBlue2(Solvatore):
    """
    Solvatore variant whose first round may use a separate, simplified
    cipher description.

    If a simplified description has been loaded with
    load_simplified_cipher(), create_conditions() encodes round 0 from
    ``simpl_cipher.transition`` and every later round from
    ``cipher.transition``. Otherwise every round is encoded from
    ``cipher.transition``. Transitions may additionally contain 'TEMP'
    steps, handled by apply_add_temp().
    """

    def __init__(self):
        super(SolvatoreBlue2, self).__init__()
        # Optional first-round description; None means all rounds use
        # self.cipher.
        self.simpl_cipher = None

    def load_simplified_cipher(self, cipher, simpl_cipher):
        """
        Load the main cipher description plus a simplified description
        that is used for the first round only.

        Marks the model as stale (``fresh_conditions = False``) so the next
        create_conditions() call rebuilds it.

        :param cipher: description used for every round after the first;
            its ``state_size`` and ``rounds`` attributes are read here.
        :param simpl_cipher: description used for round 0.
        """
        # NOTE(review): simpl_cipher is presumably expected to have the same
        # state size as cipher; this is not checked.
        self.model_size = 0
        self.cipher = cipher
        self.state_size = cipher.state_size
        self.fresh_conditions = False
        self.set_rounds(cipher.rounds)
        self.sbox_cnfs = {}
        self.simpl_cipher = simpl_cipher

    def get_new_variable(self, bit):
        """
        Allocate a fresh SAT variable for ``bit`` and register it.

        Unlike a lookup, this always creates a new variable. Where it is
        stored depends on the first character of ``bit``:

          's<N>'           -> self.state_bit[N]
          't...' / 'b...'  -> self.temporary[bit]
          anything else    -> self.sbox_tmps[bit]

        :param bit: variable name taken from a cipher transition step.
        :return: the newly allocated variable number.
        """
        new_var = self.next_variable
        self.next_variable += 1
        prefix = bit[0]
        if prefix == 's':
            self.state_bit[int(bit[1:])] = new_var
        elif prefix in ('t', 'b'):
            self.temporary[bit] = new_var
        else:
            self.sbox_tmps[bit] = new_var
        return new_var

    def apply_add_temp(self, target):
        """
        Handle a 'TEMP' step: allocate a fresh variable for ``target`` and
        add the unit clause ``[-var]`` via addclause().

        Assuming addclause() forwards to the SAT solver, this forces the new
        variable to be false.
        """
        new_target = self.get_new_variable(target)
        self.addclause([-new_target])

    def _encode_round(self, transition):
        """Encode one round given by ``transition`` and record its output state."""
        for step in transition:
            # The operation name is the last element; its operands precede it.
            operation = step[-1]
            if operation == 'XOR':
                self.apply_xor(step[0], step[1], step[2])
            elif operation == 'AND':
                self.apply_and(step[0], step[1], step[2])
            elif operation == 'PERM':
                self.apply_permutation(step[0])
            elif operation == 'SBOX':
                self.apply_sbox(step[0], step[1], step[2])
            elif operation == 'MOV':
                self.apply_mov(step[0], step[1])
            elif operation == 'TEMP':
                self.apply_add_temp(step[0])
            # NOTE(review): unrecognised operations are silently skipped, as
            # in the original code; consider raising instead.
        self.set_temporaries_to_zero()
        # Snapshot (copy) the current state variables for this round.
        self.round_states.append(
            [self.state_bit[i] for i in range(self.state_size)])

    def create_conditions(self):
        """
        Build the SAT model for ``self.rounds`` rounds.

        Round 0 is encoded from ``self.simpl_cipher`` when one is loaded;
        all remaining rounds are encoded from ``self.cipher``. After each
        round the current state variables are appended to
        ``self.round_states``. So ``round_states[0]`` holds the input state
        and ``round_states[r]`` the state after ``r`` rounds.

        Does nothing if the model is already built (``fresh_conditions``).
        Prints a message and exits the process if no cipher or no round
        count has been set.
        """
        if self.cipher is None:
            print('You need to load a cipher.')
            sys.exit(1)
        if self.rounds is None:
            print('You need to specify the number of rounds.')
            sys.exit(1)
        if self.fresh_conditions:
            return

        self.solver = Solver()
        # Variables 1..state_size are the input state bits.
        self.state_bit = list(range(1, self.state_size + 1))
        self.next_variable = self.state_size + 1
        # The solver and variable numbering restart from scratch. States
        # recorded by an earlier build would refer to a discarded solver,
        # so start a new list with the input state.
        self.round_states = [list(self.state_bit)]
        self.sbox_tmps = {}

        first_regular_round = 0
        if self.simpl_cipher is not None:
            # Special first round taken from the simplified description.
            self.temporary = dict.fromkeys(self.simpl_cipher.temporaries)
            self._encode_round(self.simpl_cipher.transition)
            first_regular_round = 1

        self.temporary = dict.fromkeys(self.cipher.temporaries)
        for _ in range(first_regular_round, self.rounds):
            self._encode_round(self.cipher.transition)
        self.fresh_conditions = True

    def extract_mcv(self, rnd, active):
        """
        Return the inclusion-minimal sets of state bits that can be exactly
        the 1-bits of ``round_states[rnd]`` for a given input pattern.

        The input state is fixed to 1 on the bits listed in ``active`` and
        to 0 elsewhere. Every non-empty subset of state-bit indices is then
        tried in increasing size. A subset is accepted when the model is
        satisfiable with the round-``rnd`` state equal to 1 exactly on that
        subset. Supersets of already accepted subsets are skipped, so only
        inclusion-minimal subsets are returned.

        Cost: up to ``2**state_size - 1`` candidates with one SAT call each
        (fewer when supersets are pruned), so this is only practical for
        small states.

        :param rnd: index into ``self.round_states`` (0 = input state).
        :param active: iterable of bit indices (anything ``int()`` accepts).
        :return: list of tuples of bit indices, in ascending size order.
        :raises ValueError: if an active bit is outside ``[0, state_size)``.
        """
        self.create_conditions()
        # Materialise the indices: they are validated and then used for
        # membership tests. A bare map() is a one-shot iterator on Python 3,
        # which made every later membership test fail.
        active_bits = [int(bit) for bit in active]
        for active_bit in active_bits:
            if not 0 <= active_bit < self.state_size:
                message = ('Bit {} designated as active bit, but there are only '
                           '{} state bits.'.format(active_bit, self.state_size))
                print(message)
                raise ValueError(message)
        active_set = set(active_bits)

        input_state = self.round_states[0]
        conditions = [input_state[i] if i in active_set else -input_state[i]
                      for i in range(self.state_size)]

        reachable_mcv = []
        candidates = chain.from_iterable(
            combinations(range(self.state_size), size)
            for size in range(1, self.state_size + 1))
        for mcv in candidates:
            chosen = set(mcv)
            # Skip supersets of an already found (smaller) reachable set.
            if any(chosen.issuperset(found) for found in reachable_mcv):
                continue
            target_state = self.round_states[rnd]
            mcv_conditions = [target_state[i] if i in chosen else -target_state[i]
                              for i in range(self.state_size)]
            reachable, _ = self.solver.solve(conditions + mcv_conditions)
            if reachable:
                reachable_mcv.append(mcv)
        return reachable_mcv