import base64
import time

import certstruct
import requests
import typer

from cryptography import x509
from cryptography.hazmat._oid import NameOID, ExtensionOID
from cryptography.hazmat.backends import default_backend
from datetime import datetime
from typing import List, Tuple

from ct_logs import CT_LOGS

# Certificate Transparency (RFC 6962) API paths, appended to a log's base URL.
# get-sth returns the log's current signed tree head (including tree_size).
GET_STH_SUFFIX = "/ct/v1/get-sth"
# get-entries takes 0-based entry indices; RFC 6962 defines `end` as inclusive.
GET_ENTRIES_SUFFIX = "/ct/v1/get-entries?start={}&end={}"


def base32_safe_decode(base32_url: bytes) -> bytes:
    """Decode base32 data whose trailing ``=`` padding may have been stripped.

    ``base64.b32decode`` rejects input whose length is not a multiple of 8.
    The missing padding is therefore appended before decoding.

    Args:
        base32_url: Upper-case base32 text, with or without padding.

    Returns:
        The decoded bytes.
    """
    missing = -len(base32_url) % 8  # 0 when already a multiple of 8
    padded = base32_url + b"=" * missing
    return base64.b32decode(padded)


def get_tree_size(log_url: str, *, timeout: float = 30.0):
    """Return the tree size from the log's current signed tree head.

    Args:
        log_url: Base URL of the CT log; the ``get-sth`` path is appended.
        timeout: Seconds ``requests`` waits for connect and for each read.
            Without a timeout, a stalled log server hangs the script forever.

    Returns:
        The ``tree_size`` field of the ``get-sth`` JSON response, as an int.

    Raises:
        Exception: if the log answers with a status other than 200.
        requests.exceptions.RequestException: on network errors or timeouts.
    """
    r = requests.get(f"{log_url}{GET_STH_SUFFIX}", timeout=timeout)
    if r.status_code != 200:
        raise Exception("Failed in retrieving CT log info.")
    return int(r.json()["tree_size"])


def get_max_block_size(log_url: str, *, timeout: float = 30.0):
    """Probe how many entries the log returns for a single ``get-entries`` call.

    The function requests indices 0..2048 and counts the entries actually returned.
    CT logs cap the batch size, so this count is used as the batch size for
    later requests.

    Args:
        log_url: Base URL of the CT log; the ``get-entries`` path is appended.
        timeout: Seconds ``requests`` waits for connect and for each read.
            Without a timeout, a stalled log server hangs the script forever.

    Returns:
        Number of entries in the probe response.

    Raises:
        Exception: if the log answers with a status other than 200.
        requests.exceptions.RequestException: on network errors or timeouts.
    """
    r = requests.get(
        f"{log_url}{GET_ENTRIES_SUFFIX.format(0, 2048)}", timeout=timeout
    )
    if r.status_code != 200:
        raise Exception("Failed in retrieving CT log info.")
    return len(r.json()["entries"])


def get_log_entries(log_url: str, start, max_block_size, *, timeout: float = 30.0):
    """Fetch one batch of raw log entries beginning at index ``start``.

    RFC 6962 ``get-entries`` treats ``end`` as the index of the *last* entry to
    return, i.e. inclusive. The request therefore ends at
    ``start + max_block_size - 1``. The server may still return fewer entries
    than asked.

    Args:
        log_url: Base URL of the CT log; the ``get-entries`` path is appended.
        start: 0-based index of the first entry to fetch.
        max_block_size: Number of entries to ask for. Values below 1 are
            treated as 1, so the range is never empty or inverted.
        timeout: Seconds ``requests`` waits for connect and for each read.
            Without a timeout, a stalled log server hangs the script forever.

    Returns:
        The ``entries`` list from the JSON response.

    Raises:
        Exception: if the log answers with a status other than 200.
        requests.exceptions.RequestException: on network errors or timeouts.
    """
    end = start + max(max_block_size, 1) - 1  # inclusive end index
    r = requests.get(
        f"{log_url}{GET_ENTRIES_SUFFIX.format(start, end)}", timeout=timeout
    )
    if r.status_code != 200:
        raise Exception("Failed in retrieving CT log info.")
    return r.json()["entries"]


def get_all_domains(cert) -> list:
    """Collect the DNS names a certificate was issued for.

    The result gathers every subject Common Name attribute and every DNS-type
    entry of the Subject Alternative Name extension. Each source is read on a
    best-effort basis: a missing or malformed subject or SAN extension simply
    contributes no names.

    Args:
        cert: A parsed ``cryptography`` X.509 certificate.

    Returns:
        Sorted list of unique names. Sorting gives a deterministic order for
        the comma-joined string built from it.
    """
    all_domains = set()

    try:
        for attr in cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME):
            all_domains.add(attr.value)
    except Exception:
        # Best effort: a bad subject must not prevent reading the SANs.
        # (Unlike a bare ``except:``, this lets KeyboardInterrupt/SystemExit
        # propagate.)
        pass

    try:
        ext = cert.extensions.get_extension_for_oid(
            ExtensionOID.SUBJECT_ALTERNATIVE_NAME
        )
        all_domains.update(ext.value.get_values_for_type(x509.DNSName))
    except Exception:
        # Covers a missing SAN extension (x509.ExtensionNotFound) as well as
        # extensions that fail to parse; either way there are no SAN names.
        pass

    return sorted(all_domains)


def parse_cert(entry):
    """Decode one ``get-entries`` item into a ``certstruct.CertData``.

    The base64 ``leaf_input`` is decoded and parsed with
    ``certstruct.MerkleTreeHeader``, and the log timestamp is recorded. For
    ``X509LogEntryType`` leaves, the DER certificate is loaded and its domain
    names are collected. Every other entry type is only tagged ``"precert"``,
    and its domains are not extracted.

    Args:
        entry: One element of the ``entries`` list returned by ``get-entries``
            (a dict with at least a ``leaf_input`` key).

    Returns:
        The populated ``CertData``. Returns None when an X.509 leaf's
        certificate cannot be parsed.
    """
    cert_data = certstruct.CertData()

    leaf_input = entry["leaf_input"]
    mth = certstruct.MerkleTreeHeader.parse(base64.b64decode(leaf_input))

    cert_data.log_timestamp = mth.Timestamp
    if mth.LogEntryType == "X509LogEntryType":
        cert_data.type = "cert"
        try:
            cert = x509.load_der_x509_certificate(
                certstruct.CertEntry.parse(mth.Entry).CertData, default_backend()
            )
            cert_data.all_domains = get_all_domains(cert)
        except Exception:
            # Skip unparseable certificates, but (unlike a bare ``except:``)
            # let Ctrl-C / SystemExit stop the long-running scan.
            return None
    else:
        cert_data.type = "precert"
    return cert_data


def get_certs(
    log_url: str, id_from: int, id_to: int, max_block_size: int
) -> Tuple[int, list, int]:
    """Fetch and parse log entries from ``id_from`` until ``id_to`` is reached.

    Entries are requested in batches of up to ``max_block_size``. The log may
    return fewer entries than requested, so each new batch starts immediately
    after the last entry actually received. At least one request is always
    made.

    Args:
        log_url: Base URL of the CT log.
        id_from: Index of the first entry to fetch.
        id_to: Exclusive stop index. Fetching stops once the next index to fetch
            is >= ``id_to``; the last batch may run past it.
        max_block_size: Number of entries to request per batch.

    Returns:
        ``(next_index, certs, number_requests)``:
        ``next_index`` is the first index not yet fetched. It equals the
        starting ``id_from`` if the log returned no entries. ``certs`` holds
        ``{"id", "log_timestamp", "domains"}`` dicts for X.509 entries;
        precertificates and unparseable entries are skipped. ``number_requests``
        is the number of HTTP requests made.
    """
    certs = []
    number_requests = 0
    while True:
        log_entries = get_log_entries(log_url, id_from, max_block_size)
        number_requests += 1
        if not log_entries:
            # Nothing at this index (yet): report no progress to the caller
            # instead of stepping over indices that were never read.
            break
        for offset, entry in enumerate(log_entries):
            cert_data = parse_cert(entry)
            if cert_data is None or cert_data.type != "cert":
                continue
            certs.append(
                {
                    "id": id_from + offset,
                    "log_timestamp": cert_data.log_timestamp,
                    "domains": ",".join(cert_data.all_domains),
                }
            )
        # Resume right after the last entry received. The previous "+ 1" here
        # silently skipped one entry after every batch.
        id_from += len(log_entries)
        if id_from >= id_to:
            break
    return id_from, certs, number_requests


def _domains_match_subject(domains: str, subject_name: str) -> bool:
    """Return True if any comma-separated name equals ``subject_name`` or is a subdomain of it."""
    suffix = "." + subject_name
    return any(
        name == subject_name or name.endswith(suffix) for name in domains.split(",")
    )


def find_cert(
    log_url: str, log_name: str, subject_name: str, start_index: int = 593495948
) -> dict:
    """Scan the log forward for a certificate issued for (a subdomain of) ``subject_name``.

    Every parsed certificate is printed as ``id | UTC log time | domains``.
    The scan starts at ``start_index``. It stops when a matching certificate is
    found, when the log returns no entries, or once the scan index is more than
    20000 past the tree size observed at the start.

    Args:
        log_url: Base URL of the CT log.
        log_name: Human-readable log name, used only in the progress message.
        subject_name: Domain to look for. A certificate matches if one of its
            names is this domain or ends with ``"." + subject_name``.
        start_index: Log index at which to begin scanning. The default keeps
            the previously hard-coded value.

    Returns:
        The matching cert dict (``id``, ``log_timestamp``, ``domains``), or
        ``{}`` if none was found.
    """
    tree_size = get_tree_size(log_url)
    max_block_size = get_max_block_size(log_url)

    current_index = start_index
    # NOTE(review): the reason for this fixed pause is not visible here
    # (perhaps to give a freshly submitted certificate time to be logged)
    # -- confirm.
    time.sleep(3)
    print(f"Looking for certificate inside {log_name} from index {current_index}.\n")
    all_requests = 0
    while True:
        next_index, certs, number_requests = get_certs(
            log_url,
            current_index,
            current_index + max_block_size,
            max_block_size,
        )
        all_requests += number_requests
        for cert in certs:
            # log_timestamp is in milliseconds; this formats exactly like the
            # str() of the former (deprecated) datetime.utcfromtimestamp().
            logged_at = time.strftime(
                "%Y-%m-%d %H:%M:%S",
                time.gmtime(int(cert["log_timestamp"] / 1000)),
            )
            print(f"{cert['id']} | {logged_at} | {cert['domains']}")
            if _domains_match_subject(cert["domains"], subject_name):
                print(
                    f"\nNumber of requests needed to find the certificate: {all_requests}"
                )
                return cert
        if next_index == current_index:
            # The log returned no entries at this index; nothing more to scan.
            return {}
        current_index = next_index
        # Give up once we are well past the tree size seen at the start.
        if current_index - tree_size > 20000:
            return {}


def decode_message_from_cert(sans: List[str], subject_name: str) -> str:
    """Reassemble the message carried in a certificate's SANs.

    The SANs are processed in sorted order. For each SAN, the labels that come
    before the first label of ``subject_name`` are collected. The first two
    characters of the SAN's first label and the first character of every later
    label are dropped; presumably these are ordering/prefix characters added by
    the encoder (confirm against the sender). A SAN that *is* the subject name
    contributes nothing. The collected text is upper-cased and base32-decoded
    (missing padding is tolerated) as UTF-8.

    Args:
        sans: DNS names from the certificate. The list is not modified.
        subject_name: The carrier domain, e.g. ``ctlr.example.com``.

    Returns:
        The decoded message text.

    Raises:
        binascii.Error: if the collected characters are not valid base32.
        UnicodeDecodeError: if the decoded bytes are not valid UTF-8.
    """
    first_subject_name_label = subject_name.split(".")[0]
    chunks = []

    # Sort a copy: the former in-place sans.sort() reordered the caller's list.
    for san in sorted(sans):
        for i, label in enumerate(san.split(".")):
            if label == first_subject_name_label:
                break
            chunks.append(label[2:] if i == 0 else label[1:])

    encoded_message = "".join(chunks)
    return base32_safe_decode(encoded_message.upper().encode("utf-8")).decode(
        "utf-8"
    )


def main(
    subject_name: str = typer.Argument(
        help="Subject Name (SN) of the wanted certificate, e.g. example.com.",
        default="ctlr.seclab.dcs.fmph.uniba.sk",
    ),
    log_name: str = typer.Option(
        help="The CT log in which the certificate should be found.",
        default="argon2023",
    ),
):
    """Find the carrier certificate in the chosen CT log and print its hidden message.

    Aborts (via ``typer.Abort``) when ``log_name`` is not a key of ``CT_LOGS``
    or when no matching certificate is found.
    """
    if log_name not in CT_LOGS:
        print(f"Error: {log_name} is not a valid CT log name.")
        raise typer.Abort()

    log_url = CT_LOGS[log_name]
    cert = find_cert(log_url, log_name, subject_name)
    if not cert:
        print("Failed to find the message :(")
        # Must be *raised*: the former bare ``typer.Abort()`` did nothing, so
        # execution fell through to cert["domains"] on {} and hit a KeyError.
        raise typer.Abort()
    sans = cert["domains"].split(",")

    message = decode_message_from_cert(sans, subject_name)
    pretty_message = typer.style(message, fg=typer.colors.GREEN, bold=True)
    typer.echo(
        f"Successfully found the message sent through the CT channel: {pretty_message}"
    )


if __name__ == "__main__":
    # Script entry point: typer builds a CLI from main()'s parameters and runs it.
    typer.run(main)
