from cipher_description import CipherDescription
from sympy import *
import numpy as np
import logging
import sys

# Module-level logging setup: send all records at DEBUG level and above to stderr.
# NOTE(review): basicConfig configures the root logger at import time, which affects
# every program importing this module; consider moving it to the script entry point.
logging.basicConfig(stream=sys.stderr, level=logging.DEBUG)
def make_temp(temp_count):
    """Return a new unique temporary-variable name and advance the counter.

    :param temp_count: single-element list acting as a mutable counter shared
        by callers; its element is incremented in place.
    :returns: ``'tspec'`` followed by the new counter value, zero-padded to at
        least three digits (e.g. ``'tspec001'``).
    """
    temp_count[0] = temp_count[0] + 1
    suffix = '{:03d}'.format(temp_count[0])
    return 'tspec' + suffix


def name_to_var(varname, revbkp):
    """Translate a sympy symbol name into a cipher variable name.

    Names starting with 't' (temporaries such as ``tspec001`` or ``tkey0``)
    and 's' (state bits) are returned unchanged.  Names starting with 'c' or
    'a' are mapped to the state bit ``s<N>``, where N is the integer suffix
    (so ``'c05'`` becomes ``'s5'``).

    :param varname: variable name to translate.
    :param revbkp: reverse backup map (temporary -> original bit name); not
        used by this function, kept for interface compatibility.
    :returns: the translated variable name.
    :raises ValueError: if ``varname`` is empty, has an unsupported prefix, or
        is a 'c'/'a' name whose suffix is not an integer.
    """
    if not varname:
        # Previously this surfaced as an IndexError from varname[0].
        logging.error("zla premenna {}".format(varname))
        raise ValueError("empty variable name")
    prefix = varname[0]
    if prefix in ('t', 's'):
        return varname
    if prefix in ('c', 'a'):
        # int() normalises leading zeros and raises ValueError on a bad suffix.
        return "s{}".format(int(varname[1:]))
    logging.error("zla premenna {}".format(varname))
    raise ValueError("unsupported variable name {!r}".format(varname))


def sbox_to_anf(sbox):
    """Convert a 4-bit S-box lookup table into per-output-bit ANF expressions.

    :param sbox: sequence of 16 ints in ``range(16)``; ``sbox[x]`` is the
        output for input ``x``.
    :returns: ``(sboxfuncs, sboxvars)``. ``sboxfuncs`` is a list of 4 sympy
        boolean expressions, one per output bit (most significant bit first),
        each an XOR of AND-monomials of the input symbols. ``sboxvars`` is the
        list of input symbols, most significant first
        (``[sbox3, sbox2, sbox1, sbox0]``).
    :raises ValueError: if ``sbox`` is not 16 values in ``range(16)``.
    """
    sbox = list(sbox)
    if len(sbox) != 16 or any(not 0 <= v < 16 for v in sbox):
        raise ValueError("sbox must be a table of 16 values in range(16)")

    s0, s1, s2, s3 = symbols("sbox0 sbox1 sbox2 sbox3")
    sboxvars = [s0, s1, s2, s3]
    # Placeholders for the constants 0 and 1. The XOR accumulation below starts
    # from a sympy expression; the placeholders are replaced at the end.
    const0, const1 = symbols("const0 const1")

    def generatemoebius(n):
        """Return the 2^n x 2^n binary Moebius (ANF) transform matrix."""
        m2 = np.array(((1, 1), (0, 1)))
        compound = m2
        for _ in range(n - 1):
            compound = np.kron(m2, compound)
        return compound

    # Truth table with one row per output bit (MSB first) and one column per
    # input x. A list comprehension is used instead of map() so NumPy gets a
    # concrete list of ints on Python 3 too, where map() returns an iterator.
    s = np.array([[int(bit) for bit in "{:04b}".format(x)] for x in sbox]).T
    # Row vector times the Moebius matrix, reduced mod 2, gives ANF coefficients.
    sboxmatrix = s.dot(generatemoebius(4)) % 2
    # Monomial for coefficient index u: bit k of u selects input sbox<k>.
    sboxline = [const1, s0, s1, s0 & s1, s2, s0 & s2, s1 & s2, s0 & s1 & s2, s3, s0 & s3, s1 & s3, s0 & s1 & s3,
                s2 & s3, s0 & s2 & s3, s1 & s2 & s3, s0 & s1 & s2 & s3]
    sboxfuncs = []
    for line in sboxmatrix.tolist():
        compoundfunc = const0
        for coeff, monomial in zip(line, sboxline):
            if coeff == 1:
                compoundfunc = compoundfunc ^ monomial
        sboxfuncs.append(compoundfunc)
    # Replace the constant placeholders once, after all rows are built.
    sboxfuncs = [x.subs([(const0, False), (const1, True)]) for x in sboxfuncs]
    return sboxfuncs, sboxvars[::-1]


# def flatten(expr):
#     if expr.func == Symbol:
#         return expr
#     elif expr.func == Not:
#         return Xor(flatten(expr.args[0]), const1)
#     elif expr.func == Xor:
#         all_subs = [flatten(x) for x in expr.args]
#         all_args = [list(x.args) if x.func == Xor else [x] for x in all_subs]
#         flat_args = [item for sublist in all_args for item in sublist]
#         output = Xor(*flat_args)
# #         print("Povodny {}".format(expr))
# #         print("Flat args {}".format(flat_args))
# #         print("output {}".format(output))
#         return output
#     elif expr.func == And:
#         all_subs = [flatten(x) for x in expr.args]
# #         print("Povodny {}".format(expr))
#         compound = flatten(all_subs[0])
# #         print("Flat args {}".format(all_subs))
# #         print("compound {}".format(compound))
#         for subf in all_subs[1:]:
#             compound = anfmul(compound, subf)
# #             print("\t subf {}".format(subf))
# #         print("\t compound {}".format(compound))
#         return compound
#     else:
#         print(expr)
#         raise ValueError
#
# def anfmul(expr1, expr2):
#     # assumption: both expressions are already in ANF
#     multis = []
#     compound = []
#     for e in (expr1,expr2):
#         if e.func == Xor:
#             multis.append(e.args)
#         else:
#             multis.append([e])
#     for i in multis[0]:
#         for j in multis[1]:
#             compound.append(i&j)
#     return Xor(*compound)

def apply_anf(cipher, expression, origvars, temp_count):
    """Emit cipher operations that compute a sympy ANF expression.

    Walks ``expression`` recursively. Every And/Xor node gets a fresh temporary
    from ``make_temp``, and the node is emitted with ``cipher.apply_and`` or
    ``cipher.apply_xor``. Symbols are translated with ``name_to_var``.

    :param cipher: object exposing ``apply_and(a, b, target)`` and
        ``apply_xor(a, b, target)`` (a ``CipherDescription`` in this module).
    :param expression: sympy expression built from Symbol, And, Xor and Not.
    :param origvars: passed through to ``name_to_var``.
    :param temp_count: one-element list used as the temporary-name counter.
    :returns: name of the variable holding the value of ``expression``.
    :raises ValueError: for any other node type, e.g. a constant true/false.
    """
    logging.debug("Aplikujem {}".format(expression))
    func = expression.func
    if func == Symbol:
        varname = name_to_var(expression.name, origvars)
        logging.debug("output {} -> {}".format(expression, varname))
        return varname
    if func == And or func == Xor:
        compoundfunc = cipher.apply_and if func == And else cipher.apply_xor
        # The target temporary is allocated before the operands are visited;
        # this fixes the temporary numbering order, so keep it first.
        compound = make_temp(temp_count)
        leftcomp = apply_anf(cipher, expression.args[0], origvars, temp_count)
        rightcomp = apply_anf(cipher, expression.args[1], origvars, temp_count)
        compoundfunc(leftcomp, rightcomp, compound)
        # sympy And/Xor are n-ary: fold the remaining operands into the target.
        for arg in expression.args[2:]:
            compoundee = apply_anf(cipher, arg, origvars, temp_count)
            compoundfunc(compoundee, compound, compound)
        logging.debug("output {} -> {}".format(expression, compound))
        return compound
    if func == Not:
        # The negation itself is not emitted: Not(x) maps to the same variable
        # as x. NOTE(review): presumably intentional because the target
        # analysis is unaffected by XOR with a constant; confirm.
        inside = apply_anf(cipher, expression.args[0], origvars, temp_count)
        logging.debug("output {} -> {}".format(expression, inside))
        return inside
    logging.error("Zloba")
    logging.error(func)
    logging.error(type(func))
    raise ValueError("unsupported expression node {!r}".format(expression))


# Building the ANF (algebraic normal form) of the cipher rounds
def anf_round(ot, key, size, lshift, ssum):
    """Build the ANF of one round: key XOR, then S-box, then left rotation.

    :param ot: list of state-bit expressions, most significant bit first
        (4 bits, matching the 4-bit S-box from ``sbox_to_anf``).
    :param key: key bits (sympy symbols or names), in the same order as ``ot``.
    :param size: state size; not used by this function, kept for interface
        compatibility.
    :param lshift: rotate the bit list left by this many positions, taken
        modulo the number of bits (negative values rotate right).
    :param ssum: S-box constant; the S-box maps x to (x + ssum) mod 16.
    :returns: list of sympy expressions for the new state bits.
    """
    # Key addition.
    s = [Xor(*x) for x in zip(ot, key)]
    # S-box x -> (x + ssum) mod 16, i.e. addition of a constant, in ANF form.
    # Its input symbols (MSB first) are replaced by the key-added state bits.
    sbox = [(x + ssum) % 16 for x in range(16)]
    sboxfuncs, sboxvars = sbox_to_anf(sbox)
    s = [x.subs(list(zip(sboxvars, s))) for x in sboxfuncs]
    # Rotate left. Reduce the amount modulo the width so shifts >= the width
    # still rotate; plain slicing would leave the list unrotated.
    shift = lshift % len(s)
    s = s[shift:] + s[:shift]
    return s


def anf(ot, key, size, rounds, lshift, ssum):
    """Return the state after applying ``anf_round`` ``rounds`` times to ``ot``.

    Every round uses the same ``key``, ``size``, ``lshift`` and ``ssum``.
    When ``rounds`` is 0 or negative, ``ot`` is returned unchanged.
    """
    result = ot
    for _ in range(rounds):
        result = anf_round(result, key, size, lshift=lshift, ssum=ssum)
    return result


def generate_RES(size, blue0, blue1, lshift, ssum):
    """Build a ``CipherDescription`` computing one ANF round over ``size`` bits.

    Steps:
    1. Copy the input bits into temporaries.
    2. Create the key bits.
    3. Build the ANF with ``anf`` (one round).
    4. Fix the "blue" bits to constants.
    5. Translate each output bit with ``apply_anf`` and move the results back
       into the state bits.

    :param size: number of state bits, named ``s0`` .. ``s{size-1}``.
        NOTE(review): the S-box in ``anf_round`` is 4-bit, so presumably only
        ``size == 4`` is meaningful; confirm.
    :param blue0: indices of state bits fixed to 0 (index x means bit ``s{x}``).
    :param blue1: indices of state bits fixed to 1 (index x means bit ``s{x}``).
    :param lshift: rotation amount passed to the round function.
    :param ssum: S-box constant; the S-box is x -> (x + ssum) mod 16.
    :returns: the populated ``CipherDescription``.
    :raises ValueError: if a blue index is outside ``range(size)`` or the same
        bit appears in both ``blue0`` and ``blue1``.
    """
    blue0 = list(blue0)
    blue1 = list(blue1)
    for idx in blue0 + blue1:
        # Out-of-range indices would otherwise wrap via negative indexing and
        # silently fix the wrong bit.
        if not 0 <= idx < size:
            raise ValueError("blue bit index {} out of range for size {}".format(idx, size))
    conflicting = set(blue0) & set(blue1)
    if conflicting:
        raise ValueError("bits {} are fixed to both 0 and 1".format(sorted(conflicting)))

    specRES = CipherDescription(size)
    # State bit names, most significant first: ['s{size-1}', ..., 's0'].
    bits = ["s{}".format(x) for x in range(size)][::-1]
    # Counter to ensure uniqueness of temporary variables.
    temp_count = [0]

    # Back up the input variables into temporaries; bkp[i] holds bits[i].
    rev_bkp = {}
    bkp = [make_temp(temp_count) for _ in range(size)]
    for i, s in enumerate(bits):
        newbkpvar = bkp[i]
        rev_bkp[newbkpvar] = s
        specRES.apply_mov(s, newbkpvar)
    # Make the keys new temporary variables. apply_xor with two identical
    # sources just makes the new target constant.
    key = ["tkey{}".format(x) for x in range(size)]
    for k in key:
        specRES.apply_xor(bkp[0], bkp[0], k)
    symbols_bkp = symbols(bkp)
    # Compute the ANF of one round.
    pokole = anf(symbols_bkp, key, size, 1, lshift=lshift, ssum=ssum)

    # Substitute constants for the blue bits, simplifying the functions.
    logging.debug(pokole)
    # bkp[size-1-x] backs up bits[size-1-x] == 's{x}', so index x means bit s{x}.
    substitutions = [(symbols_bkp[size - 1 - x], False) for x in blue0]
    substitutions += [(symbols_bkp[size - 1 - x], True) for x in blue1]
    pokole = [bitfunc.subs(substitutions) for bitfunc in pokole]
    logging.debug("Po substitucii")
    logging.debug(pokole)
    # Translate each output function from sympy into cipher operations.
    outtemps = []
    for x in pokole:
        logging.debug("Spracovavam output {}".format(x))
        outtemps.append(apply_anf(specRES, x, rev_bkp, temp_count))
    logging.debug(outtemps)
    # Move the temporary results back into the state bits.
    for i, out in enumerate(outtemps):
        specRES.apply_mov(out, bits[i])
        logging.debug("Mov {} -> {}".format(out, bits[i]))
    return specRES
