import sys, os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../')))
from library.instance import *
from md5_test import digest_to_hex

# The SAT instance for this run: it is emitted as CNF further down
# (instance.emit) and later loaded with the solver's model (instance.read).
instance = Instance()

# Everything is little endian
def intToVector(x, size=32):
    """Return a ConstantVector holding the bits of ``x``, least significant first.

    ``bits[0]`` is the least significant bit, matching the little-endian bit
    order used throughout this script.

    :param x: non-negative integer to encode; must fit in ``size`` bits.
    :param size: width of the resulting vector in bits (default 32).
    :raises ValueError: if ``x`` is negative or does not fit in ``size`` bits
        (previously a negative value silently became all zeros and an
        oversized value failed with an IndexError).
    """
    if x < 0 or x >= 1 << size:
        raise ValueError('value {} does not fit in {} unsigned bits'.format(x, size))
    return ConstantVector([bool((x >> i) & 1) for i in range(size)])

# Original message length in bits (e.g. 8*8 for an 8-byte message).
# The single-block layout below requires a whole number of bytes: the padding
# "1" bit is placed at bit mlength + 7 and the message is read back byte by
# byte. It must also leave room for that padding byte before the 64-bit length
# field that starts at bit 448, so at most 440 bits (55 bytes).
mlength = 8

if mlength < 0 or mlength % 8 != 0 or mlength > 440:
    raise ValueError(
        'mlength must be a multiple of 8 in [0, 440] for a single MD4 block, '
        'got {}'.format(mlength))

###################################################

# A single 512-bit block, held as 16 little-endian 32-bit words:
# words 0-13 (bits 0..447) carry the message followed by padding, and
# words 14-15 carry the 64-bit message length.
Mvec = [BitVector(32) for _ in range(448 // 32)]

# Padding: one "1" bit immediately after the message, then zeros up to bit 448.
# The set bit is the top bit of the first padding byte (0x80). In this script's
# little-endian bit numbering that is absolute bit mlength + 7, which assumes
# mlength is a whole number of bytes.
one_bit = mlength + 7
for bit_pos in range(mlength, 448):
    word_idx, bit_idx = divmod(bit_pos, 32)
    Mvec[word_idx].bits[bit_idx] = bit_pos == one_bit

# The length fills the last 64 bits, little endian: low word first, then high word.
len_hi, len_lo = divmod(mlength, 2**32)
Mvec.append(intToVector(len_lo))
Mvec.append(intToVector(len_hi))

# Number of rounds, full MD4 is 48
rounds = 48
if not 0 <= rounds <= 48:
    raise ValueError('rounds must be in [0, 48], got {}'.format(rounds))

# MD4 initial chaining values A, B, C, D.
a0, b0, c0, d0 = [intToVector(x) for x in [0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476]]
A, B, C, D = a0, b0, c0, d0

# One step function per 16-step round; each computes
#   (a + F(b, c, d) + X[k] [+ round constant]) <<< s
# with F = (b & c) | (~b & d)         in round 1,
#      F = (b & c) | (b & d) | (c & d) in round 2 (+ 0x5A827999),
#      F = b ^ c ^ d                   in round 3 (+ 0x6ED9EBA1).
# The round constants are still converted with intToVector on every call.
step_fns = [
    lambda a, b, c, d, k, s: CyclicLeftShift(a + (b&c | ~b&d) + Mvec[k], s),
    lambda a, b, c, d, k, s: CyclicLeftShift(a + (b&c | b&d | c&d) + Mvec[k] + intToVector(0x5A827999), s),
    lambda a, b, c, d, k, s: CyclicLeftShift(a + (b^c^d) + Mvec[k] + intToVector(0x6ED9EBA1), s),
]
# Message word index used at each of the 16 steps of each round.
msg_order = [
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
    [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15],
    [0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15],
]
# Left-rotation amounts per round; they repeat with period 4 within a round.
shifts = [[3, 7, 11, 19], [3, 5, 9, 13], [3, 9, 11, 15]]

for i in range(rounds):
    rnd, step = divmod(i, 16)
    fn = step_fns[rnd]
    k = msg_order[rnd][step]
    s = shifts[rnd][i % 4]
    # The updated register cycles A, D, C, B, with operands passed in the
    # orders [abcd], [dabc], [cdab], [bcda].
    lane = i % 4
    if lane == 0:
        A = fn(A, B, C, D, k, s)
    elif lane == 1:
        D = fn(D, A, B, C, k, s)
    elif lane == 2:
        C = fn(C, D, A, B, k, s)
    else:
        B = fn(B, C, D, A, k, s)

# Feed-forward: add the initial chaining values to the final registers.
a0, b0, c0, d0 = a0+A, b0+B, c0+C, d0+D

###################################################

# Fix message/output bits here
#Mvec[0].bits = [True]*32
#a0.bits = [True]*8 + [None]*24

###################################################

# TODO: tidy up this emit / solve / read step
# Generate CNF instance, solve, read
print('Emit start')
instance.emit([a0, b0, c0, d0] + Mvec)
from subprocess import call
# MiniSat exits with 10 when the formula is SATISFIABLE and 20 when it is
# UNSATISFIABLE. Other codes mean the run was interrupted or failed. Only a
# SAT result leaves a model in instance.out that is worth reading back.
status = call(['minisat', 'instance.cnf', 'instance.out'])
if status != 10:
    sys.exit('minisat did not find a satisfying assignment (exit status {})'.format(status))
instance.read('instance.out')

# Little-endian bit list to int
def toInt(bits):
    """Interpret ``bits`` as an unsigned integer, least significant bit first.

    ``bits[0]`` contributes 1, ``bits[1]`` contributes 2, and so on. Any
    truthy entry counts as a set bit; falsy entries (False, 0, None) count
    as clear bits. An empty sequence yields 0.
    """
    return sum(1 << pos for pos, bit in enumerate(bits) if bit)

# Query the solver's assignment for each message word once. The per-bit and
# per-byte lookups below index into these lists instead of re-querying.
msg_words = [Mvec[w].getValuation(instance) for w in range((mlength + 31) // 32)]

# Get message bits (little endian within each 32-bit word)
Mbits = [msg_words[i // 32][i % 32] for i in range(mlength)]
print('Message length', mlength, 'bits') #, Mbits)

# Get digest bits: a0 occupies the lowest 32 bits, d0 the highest.
Dbits = []
for q in [a0, b0, c0, d0]:
    Dbits += q.getValuation(instance)
digest = toInt(Dbits)
print('Digest =', digest) #, Dbits)
print('Digest', digest_to_hex(digest)) #, Dbits)

# Assume 8bit multiple length, generate message
# Then test with reference implementation for match
# Byte j consists of message bits 8j..8j+7, least significant first.
message = bytes(toInt(Mbits[8 * j : 8 * j + 8]) for j in range(mlength // 8))
print('Message bytes:', message, 'rounds: ', rounds)

#reference = MD4(rounds=rounds)
#reference.update(message)
#print('MD4   ', reference.digest())
#assert reference.digest() == digest_to_hex(toInt(Dbits))
#print('MATCH!')