from sympy import *
from cipher_description import CipherDescription
from ciphers import speck
import logging
import numpy as np
from math import log


# Default the root logger to ERROR so INFO/DEBUG tracing below stays quiet
# (basicConfig does nothing if the root logger already has handlers).
logging.basicConfig(level="ERROR")
class CipherDescriptionBlue2(CipherDescription):
    """
    CipherDescription rebuilt from the sympy boolean functions of another
    CipherDescription, optionally simplified by fixing selected input bits
    ("blue" bits) to the constants 0 or 1.
    """

    def __init__(self, orig_cipher, blue0=None, blue1=None, rounds=1, anf=False):
        """
        Creates boolean functions based on CipherDescription orig_cipher.

        Every step of ``orig_cipher.transition`` is evaluated symbolically,
        ``rounds`` times in a row, starting from one sympy symbol per state
        bit.  The per-bit results are stored in ``self.bool_functions`` and
        then handed to ``set_blue_bits``, which rebuilds this object as a
        CipherDescription.

        @type orig_cipher: CipherDescription
        @param blue0: bit indices substituted with False; defaults to none
        @param blue1: bit indices substituted with True; defaults to none
        @param rounds: number of times the transition is applied
        @param anf: if True, AND and SBOX results are kept in algebraic normal
            form, using the placeholder symbol ``const1`` for the constant 1
            (it is replaced by True after all rounds are evaluated)
        """
        # None defaults avoid one mutable list being shared between calls.
        blue0 = [] if blue0 is None else blue0
        blue1 = [] if blue1 is None else blue1

        self.orig_cipher = orig_cipher
        self.state_size = orig_cipher.state_size
        # Names run from the highest index down: ['s{n-1}', ..., 's0'].
        self.init_state = ['s{}'.format(i) for i in range(orig_cipher.state_size)][::-1]
        self.init_state_sym = symbols(self.init_state)
        variables = dict(zip(self.init_state, self.init_state_sym))
        variables.update(dict.fromkeys(orig_cipher.temporaries))

        # Each S-box becomes (list of output-bit functions, input symbols).
        self.sboxes_bool = dict()
        self.sboxes = dict()
        for sb, sbtab in orig_cipher.sboxes.items():
            logging.info("Creating sbox {}".format(sb))
            self.sboxes_bool[sb] = self.sbox_to_anf(sbtab)

        logging.info("Creating boolean functions")
        for round_index in range(rounds):
            for step in orig_cipher.transition:
                op = step[-1]
                if op == 'XOR':
                    variables[step[0]] = Xor(variables[step[1]], variables[step[2]])
                elif op == 'AND':
                    if anf:
                        variables[step[0]] = self.anfmul(variables[step[1]], variables[step[2]])
                    else:
                        variables[step[0]] = And(variables[step[1]], variables[step[2]])
                elif op == 'PERM':
                    # Rotate values along the listed bits: each bit takes the
                    # previous bit's value, the first takes the last one's.
                    permutation = step[0]
                    tmp = variables[permutation[-1]]
                    for bit in permutation:
                        variables[bit], tmp = tmp, variables[bit]
                elif op == 'SBOX':
                    sbfunc, sbvars = self.sboxes_bool[step[0]]
                    input_bits = step[1]
                    output_bits = step[2]
                    input_funcs = [variables[x] for x in input_bits]
                    # list(): in Python 3 zip() is a one-shot iterator and
                    # subs() consumes it, so without list() only the first
                    # output function would get its inputs substituted.
                    subs = list(zip(sbvars, input_funcs))
                    sub_funcs = [f.subs(subs) for f in sbfunc]
                    if anf:
                        sub_funcs = [self.flatten(f) for f in sub_funcs]
                    for j, o in enumerate(output_bits):
                        variables[o] = sub_funcs[j]
                elif op == 'MOV':
                    variables[step[0]] = variables[step[1]]
                elif op == 'TEMP':
                    # TODO: make the symbols distinct per round, e.g.
                    # symbols(step[0] + "round{}".format(round_index))
                    variables[step[0]] = symbols(step[0])

        self.bool_functions = [variables[x] for x in self.init_state]
        if anf:
            const1 = symbols("const1")
            self.bool_functions = [x.subs([(const1, True)]) for x in self.bool_functions]

        logging.debug(self.bool_functions)
        self.set_blue_bits(blue0, blue1)

    def set_blue_bits(self, blue0, blue1):
        """
        Based on given blue bits, simplify self.bool_functions and create
        CipherDescription transformations based on that.

        Index ``i`` in either list refers to symbol ``s{i}``; bits in
        ``blue0`` are substituted with False and bits in ``blue1`` with True.
        The CipherDescription base is then re-initialised and filled with the
        operations computing each simplified function into a temporary,
        followed by MOVs of those temporaries back into the state bits.

        @type blue0: list
        @type blue1: list
        """
        logging.info("Simplifying boolean functions")
        # init_state_sym runs s{n-1}..s0; reversed, position i holds s{i}.
        bit_syms = self.init_state_sym[::-1]
        subs = [(bit_syms[i], False) for i in blue0]
        subs += [(bit_syms[i], True) for i in blue1]
        functions_sub = [func.subs(subs) for func in self.bool_functions]
        logging.debug("After substitution")
        logging.debug(functions_sub)

        # creating CipherDescription based on simplified functions
        super(CipherDescriptionBlue2, self).__init__(self.orig_cipher.state_size)
        self.temp_count = 0
        outtemps = []
        for i, x in enumerate(functions_sub):
            logging.info("Creating CipherDescription for {}.bit".format(i))
            outtemps.append(self.apply_function(x))
        logging.debug(outtemps)
        # transfer of temporary results back into input bits
        for out, target in zip(outtemps, self.init_state):
            self.apply_mov(out, target)
            logging.debug("Mov {} -> {}".format(out, target))

    def anfmul(self, expr1, expr2):
        """
        Multiplication that conserves ANF of final expression.

        Each operand is treated as a XOR of terms (a non-Xor expression is a
        single term).  The product is distributed term by term, and the
        partial products are XORed together.
        """
        terms1, terms2 = [e.args if e.func == Xor else (e,) for e in (expr1, expr2)]
        return Xor(*[i & j for i in terms1 for j in terms2])

    def flatten(self, expr):
        """
        Rewrite a boolean expression into ANF (XOR of AND-monomials).

        The constant 1 is represented by the placeholder symbol ``const1``,
        so ``Not(x)`` becomes ``x XOR const1`` and ``true`` becomes ``const1``.

        @raise ValueError: if expr contains an operation other than
            Symbol, true/false, Not, Xor or And
        """
        const1 = symbols("const1")
        if isinstance(expr, Symbol):
            return expr
        if expr is S.true:
            return const1
        if expr is S.false:
            # false is the empty XOR and is already in ANF.
            return expr
        if expr.func == Not:
            return Xor(self.flatten(expr.args[0]), const1)
        if expr.func == Xor:
            flat_args = []
            for sub in (self.flatten(x) for x in expr.args):
                if sub.func == Xor:
                    flat_args.extend(sub.args)
                else:
                    flat_args.append(sub)
            return Xor(*flat_args)
        if expr.func == And:
            factors = [self.flatten(x) for x in expr.args]
            compound = factors[0]
            for factor in factors[1:]:
                compound = self.anfmul(compound, factor)
            return compound
        logging.error("Unknown function in flatten")
        logging.error(expr)
        raise ValueError("Cannot convert {} to ANF".format(expr))

    def get_anf(self, sbox):
        """
        Return the ANF coefficient table of ``sbox`` (same as solvatore.get_anf).

        This applies the binary Moebius transform to the lookup table.  Entry
        ``m`` of the result has bit ``k`` set iff the monomial formed by the
        input bits set in ``m`` occurs in the ANF of output bit ``k``.
        """
        # bit_length() - 1 is the exact log2 for power-of-two lengths and
        # avoids float rounding in math.log.
        n = len(sbox).bit_length() - 1
        anf = list(sbox)
        for i in range(n):
            mask = 1 << i
            for j in range(len(anf)):
                if j & mask:
                    anf[j] ^= anf[j ^ mask]
        return anf

    def sbox_to_anf(self, sbox):
        """
        Build sympy functions for the output bits of ``sbox``.

        Returns ``(funcs, sbvars)``.  ``funcs[k]`` is output bit ``k``
        expressed over the input symbols ``sbvars``
        (``sbox0``, ``sbox1``, ...).  The number of output functions equals
        the number of input bits.
        """
        sblen = len(sbox).bit_length() - 1
        sbvars = symbols(["sbox{}".format(x) for x in range(sblen)])
        sbmatrix = self.get_anf(sbox)
        # sympy false (not Python False) so all-zero outputs still support .subs().
        sbfuncs = [S.false] * sblen
        for monom in range(2 ** sblen):
            # creating monomial: AND of the input bits set in monom
            monomfunc = True
            for i in range(sblen):
                if (monom >> i) & 1:
                    monomfunc = And(monomfunc, sbvars[i])
            # add the monomial to every output bit whose coefficient is 1
            for row in range(sblen):
                if (sbmatrix[monom] >> row) & 1:
                    sbfuncs[row] = Xor(sbfuncs[row], monomfunc)
        return sbfuncs, sbvars

    def apply_function(self, expression):
        """
        Emit CipherDescription operations that compute ``expression``.

        Returns the name of the variable holding the result:
        - ``t*`` symbols are used directly and declared as temporaries if
          needed.
        - ``s*`` symbols are copied into a fresh temporary.
        - And/Xor are folded into a fresh temporary.
        - Not is not emitted; the operand's variable is returned unchanged.
        - true/false yield a freshly declared temporary.
          NOTE(review): presumably unconstrained in the base class; confirm.

        @raise ValueError: for unknown symbol names or unsupported operations
        """
        logging.debug("Applying {}".format(expression))
        if isinstance(expression, Symbol):
            name = expression.name
            if name[0] == 't':
                if not name.startswith("tkey"):
                    logging.error("Unknown variable {}".format(name))
                if name not in self.temporaries:
                    self.apply_add_temp(name)
                return name
            if name[0] == 's':
                new_var = self.make_temp()
                self.apply_mov(name, new_var)
                logging.debug("output {} -> {}".format(expression, new_var))
                return new_var
            # Unknown symbol: falls through to the error below.
            logging.error("Unknown variable {}".format(name))
        elif expression.func in (And, Xor):
            compoundfunc = self.apply_and if expression.func == And else self.apply_xor
            # The result temp is allocated before the operands, which fixes
            # the temp numbering.
            compound = self.make_temp()
            leftcomp = self.apply_function(expression.args[0])
            rightcomp = self.apply_function(expression.args[1])
            compoundfunc(leftcomp, rightcomp, compound)
            for f in expression.args[2:]:
                compoundee = self.apply_function(f)
                compoundfunc(compoundee, compound, compound)
            logging.debug("output {} -> {}".format(expression, compound))
            return compound
        elif expression.func == Not:
            inside = self.apply_function(expression.args[0])
            logging.debug("output {} -> {}".format(expression, inside))
            return inside
        elif expression is S.true or expression is S.false:
            newvar = self.make_temp()
            self.apply_add_temp(newvar)
            return newvar
        logging.error("Unknown function")
        logging.error(expression.func)
        logging.error(type(expression.func))
        raise ValueError("Unknown function {}".format(expression.func))

    def make_temp(self):
        """Return a new, sequentially numbered temporary name ('tspec001', ...)."""
        self.temp_count += 1
        return 'tspec{:03d}'.format(self.temp_count)
