import sys
from solvatore import Solvatore
from pycryptosat import Solver
import logging


class SolvatoreBlue1(Solvatore):
    """
    Solvatore variant in which chosen state bits are constant "blue" bits.

    A blue bit is not backed by a solver variable.  Instead the slot that
    normally holds a (positive) variable number holds a negative marker:
    -2 for blue 0 and -3 for blue 1.  The apply_* methods below handle the
    cases in which at least one operand is blue and fall back to the
    superclass implementation otherwise.
    """

    # Markers stored in place of a solver variable for blue bits.
    _BLUE0 = -2
    _BLUE1 = -3

    def __init__(self):
        """
        Create a new SolvatoreBlue1 instance with no blue bits selected.
        """
        super(SolvatoreBlue1, self).__init__()
        self.blue0 = []
        self.blue1 = []

    def set_blue_bits(self, blue0, blue1):
        """
        Sets which state bits should have which blue value.

        blue0 -- indices of state bits fixed to blue 0
        blue1 -- indices of state bits fixed to blue 1

        Prints a message and exits the process (status 1) if no cipher is
        loaded or if an index lies outside [0, state_size).  Marks the
        conditions as stale so that create_conditions rebuilds them.
        """
        if self.cipher is None:
            print("No cipher loaded")
            sys.exit(1)
        for label, bits in (("blue0", blue0), ("blue1", blue1)):
            if not all(self.state_size > x >= 0 for x in bits):
                # Report the state size, not the per-bit variable list.
                print("Some bit designated as {} is larger than {} "
                      "(state bits)".format(label, self.state_size))
                sys.exit(1)
        overlap = set(blue0).intersection(set(blue1))
        if overlap:
            # NOTE(review): unlike the checks above this only warns and
            # carries on (create_conditions then lets blue1 win for such
            # bits) -- confirm whether this should abort as well.
            print("Bits {} assigned both as blue0 and "
                  "blue1".format(overlap))
        self.blue0 = blue0
        self.blue1 = blue1
        self.fresh_conditions = False

    def get_old_variable(self, bit):
        """
        Return the value currently stored for `bit` without creating a
        new variable.

        get_variables is split into get_old_variable and get_new_variable
        so that callers can check whether a bit is blue 0 (-2) or
        blue 1 (-3) first.

        bit -- name of the bit: 's<number>' is looked up in
               self.state_bit, names starting with 't' or 'b' in
               self.temporary, anything else in self.sbox_tmps

        Returns the stored value, or None for an unknown sbox temporary.
        """
        prefix = bit[0]
        if prefix == 's':
            return self.state_bit[int(bit[1:])]
        if prefix in ('t', 'b'):
            return self.temporary[bit]
        return self.sbox_tmps.get(bit)

    def get_new_variable(self, bit):
        """
        Allocate the next free variable number, store it for `bit` and
        return it.

        The value is stored in the same place get_old_variable reads
        from (see set_new_variable).
        """
        new_bit = self.next_variable
        self.next_variable += 1
        self.set_new_variable(bit, new_bit)
        return new_bit

    def set_new_variable(self, bit, new_bit):
        """
        Like get_new_variable, but stores the given value `new_bit`
        for `bit` instead of allocating a fresh variable.
        """
        prefix = bit[0]
        if prefix == 's':
            self.state_bit[int(bit[1:])] = new_bit
        elif prefix in ('t', 'b'):
            self.temporary[bit] = new_bit
        else:
            self.sbox_tmps[bit] = new_bit

    def create_conditions(self):
        """
        Create conditions for the solver from the cipher description.

        According to the original author's note the only change against
        the superclass is that the selected state bits are made blue.

        Prints a message and exits the process (status 1) if no cipher
        is loaded or the number of rounds is not set.  Does nothing if
        the conditions are already fresh.
        """
        if self.cipher is None:
            print('You need to load a cipher.')
            sys.exit(1)
        if self.rounds is None:
            print('You need to specify the number of rounds.')
            sys.exit(1)
        if self.fresh_conditions:
            return
        self.solver = Solver()
        # Variables 1..state_size represent the input state bits.
        self.state_bit = list(range(1, self.state_size + 1))
        # The input state is recorded before the blue markers are applied.
        # NOTE(review): round_states is only appended to here; presumably
        # it is reset elsewhere before conditions are rebuilt -- verify.
        self.round_states.append(list(self.state_bit))
        self.next_variable = self.state_size + 1

        # set certain bits as blue
        for i in self.blue0:
            self.state_bit[i] = self._BLUE0
        for i in self.blue1:
            self.state_bit[i] = self._BLUE1

        self.temporary = {i: None for i in self.cipher.temporaries}
        self.sbox_tmps = {}
        for _ in range(self.rounds):
            for step in self.cipher.transition:
                # The operation name is the last element of each step;
                # unknown operations are skipped.
                operation = step[-1]
                if operation == 'XOR':
                    self.apply_xor(step[0], step[1], step[2])
                elif operation == 'AND':
                    self.apply_and(step[0], step[1], step[2])
                elif operation == 'PERM':
                    self.apply_permutation(step[0])
                elif operation == 'SBOX':
                    self.apply_sbox(step[0], step[1], step[2])
                elif operation == 'MOV':
                    self.apply_mov(step[0], step[1])
                elif operation == 'TEMP':
                    self.apply_add_temp(step[0])
            self.set_temporaries_to_zero()
            state = list(self.state_bit)
            logging.debug(state)
            self.round_states.append(state)
        self.fresh_conditions = True
        logging.debug(self.round_states)

    def apply_mov(self, target, source):
        """
        Apply operation MOV with regard to blue bits.

        target -- name of the bit that is written
        source -- name of the bit that is read
        """
        old_source = self.get_old_variable(source)
        old_target = self.get_old_variable(target)
        # if no variable is blue, proceed as before
        if old_source > 0 and (old_target is None or old_target > 0):
            return super(SolvatoreBlue1, self).apply_mov(target, source)
        # if source is not blue, generate new variable with value 0
        # and proceed as before
        if old_source > 0:
            new_target = self.get_new_variable(target)
            self.addclause([-new_target])
            return super(SolvatoreBlue1, self).apply_mov(target, source)
        # nothing to be done
        if target == source:
            return
        # ensure old target was 0, if it were 1,
        # minimal choice vector does not exist, because it is overwritten
        if old_target is not None and old_target > 0:
            self.addclause([-old_target])

        # the target simply takes over the blue marker of the source
        self.set_new_variable(target, old_source)

    def apply_xor(self, target, source_1, source_2):
        """
        Apply operation XOR with regard to blue bits.

        Raises ValueError if a blue target holds a marker other than
        blue 0 or blue 1 when it has to be inverted.
        """
        old_source_1 = self.get_old_variable(source_1)
        old_source_2 = self.get_old_variable(source_2)

        # if no variable is blue, proceed as before
        if old_source_1 > 0 and old_source_2 > 0:
            return super(SolvatoreBlue1, self).apply_xor(
                target, source_1, source_2)

        # reduce to the two-operand form "target ^= source"
        if source_1 != target:
            self.apply_mov(target, source_2)
            source = source_1
        else:
            source = source_2

        old_source = self.get_old_variable(source)
        old_target = self.get_old_variable(target)
        # one of source_1 and source_2 is not blue, therefore source or
        # target is not blue

        # if source is blue, nothing happens to choice vector
        if old_source <= 0:
            # except when source is blue1, it inverts blue target
            if old_source == self._BLUE1 and old_target <= 0:
                if old_target == self._BLUE0:
                    flipped = self._BLUE1
                elif old_target == self._BLUE1:
                    flipped = self._BLUE0
                else:
                    message = "Blue bit {} is not 0 or 1".format(old_target)
                    print(message)
                    raise ValueError(message)
                self.set_new_variable(target, flipped)
            return

        # if target is blue, make it into new variable with value 0
        # and proceed as without blue bits
        new_target = self.get_new_variable(target)
        self.addclause([-new_target])

        old_source, new_source = self.get_variables(source)
        old_target, new_target = self.get_variables(target)
        self.addclause([-new_source, old_source])
        self.addclause([new_target, -old_target])
        self.addclause([-new_source, -new_target, old_target])
        self.addclause([new_source, new_target, -old_source])
        self.addclause([-new_target, old_source, old_target])
        self.addclause([new_source, -old_source, -old_target])

    def apply_and(self, target, source_1, source_2):
        """
        Apply operation AND with regard to blue bits.

        If neither source is blue the superclass implementation is used.
        Otherwise source_2 is first moved to target (unless source_1
        already is the target) and the remaining source is combined with
        target:
          - source blue 0: a non-blue target is forced to 0 and target
            becomes blue 0
          - source blue 1: target is left unchanged
          - target blue 0: target stays blue 0
          - target blue 1: target is replaced by a fresh variable with
            value 0 and the regular AND clauses are added

        Raises ValueError if a blue bit holds a marker other than
        blue 0 or blue 1.
        """
        old_source_1 = self.get_old_variable(source_1)
        old_source_2 = self.get_old_variable(source_2)

        # if no variable is blue, proceed as before
        if old_source_1 > 0 and old_source_2 > 0:
            return super(SolvatoreBlue1, self).apply_and(
                target, source_1, source_2)

        # x AND x is just x
        if source_1 == source_2:
            self.apply_mov(target, source_1)
            return

        # reduce to the two-operand form "target &= source"
        if source_1 != target:
            self.apply_mov(target, source_2)
            source = source_1
        else:
            source = source_2

        old_source = self.get_old_variable(source)
        old_target = self.get_old_variable(target)

        # if source is blue
        if old_source <= 0:
            # if target is not blue it has to be 0, because
            # this "target" function is zeroed-out by AND with blue 0
            if old_source == self._BLUE0:
                if old_target > 0:
                    self.addclause([-old_target])
                self.set_new_variable(target, self._BLUE0)
            elif old_source != self._BLUE1:
                message = "Blue bit {} is not 0 or 1".format(old_source)
                print(message)
                raise ValueError(message)
            return

        # source is not blue, so target has to be
        if old_target == self._BLUE0:
            self.set_new_variable(target, self._BLUE0)
            return
        if old_target != self._BLUE1:
            message = "Blue bit {} is not 0 or 1".format(old_target)
            print(message)
            print(old_source_1, old_source_2, old_source, old_target)
            raise ValueError(message)
        # case where old_target is blue1, source is not blue
        # target behaves as if it was variable with value 0
        new_target = self.get_new_variable(target)
        self.addclause([-new_target])

        old_source, new_source = self.get_variables(source)
        old_target, new_target = self.get_variables(target)
        self.addclause([-new_source, old_source])
        self.addclause([new_target, -old_target])
        self.addclause([-new_source, -new_target])
        self.addclause([new_source, new_target, -old_source])
        self.addclause([-new_target, old_source, old_target])

    def apply_add_temp(self, target):
        """
        Add new temporary variable for `target` and force it to 0.
        """
        new_target = self.get_new_variable(target)
        self.addclause([-new_target])

    def set_temporaries_to_zero(self):
        """
        To guarantee that temporaries are zero at end of rounds.

        Temporaries that were never assigned (None) and blue temporaries
        (negative markers) have no solver variable and are skipped.
        """
        for tmp in self.temporary.values():
            # Comparing None with an int raises TypeError on Python 3.
            if tmp is not None and tmp >= 0:
                self.addclause([-tmp])

    def is_bit_balanced(self, bit, rnd, active):
        """
        Check that no active bit is also blue, then defer to the
        superclass.

        bit, rnd, active -- passed on unchanged to the superclass;
                            `active` is additionally compared against
                            the blue0/blue1 bit indices

        Raises ValueError if some bit is both active and blue.
        """
        clash = set(active).intersection(set(self.blue0) | set(self.blue1))
        if clash:
            message = ("Bits {} designated both as active and blue"
                       "".format(clash))
            print(message)
            raise ValueError(message)
        # TODO if in rnd are blue bits, mask them with new temporaries
        # (a queried bit that is constant blue needs no computation)
        return super(SolvatoreBlue1, self).is_bit_balanced(bit, rnd, active)

    def apply_sbox(self, sbox_name, input_bits, output_bits):
        """
        S-boxes are not supported together with blue bits.

        Raises NotImplementedError always.
        """
        raise NotImplementedError(
            "SBOX is not supported by SolvatoreBlue1")
