#!/usr/bin/python3.8

import numpy as np
import random
import deepnano2
from ont_fast5_api.fast5_interface import get_fast5_file
import glob, argparse, os, sys
from utils import align_fasta
from fastaparser import Writer, FastaSequence

def med_mad(x, factor=1.4826):
    """
    Return the median of ``x`` and its scaled median absolute deviation.

    The deviation is the median of ``|x - median|`` multiplied by ``factor``.
    """
    center = np.median(x)
    deviations = np.abs(x - center)
    spread = np.median(deviations) * factor
    return center, spread


def rescale_signal(signal):
    """
    Return a float32 copy of ``signal`` shifted by its median and divided by
    its scaled median absolute deviation (both taken from ``med_mad``).
    """
    scaled = signal.astype(np.float32)
    center, spread = med_mad(scaled)
    # Write the results back into the float32 buffer so the dtype is kept.
    np.subtract(scaled, center, out=scaled)
    np.divide(scaled, spread, out=scaled)
    return scaled

# Output alphabet: sample() maps a drawn column index (0-4) to the character at
# that position. Index 0 ("N") is also sample()'s fallback when no threshold
# matches, and shorten() uses "N" as the separator it splits on and discards
# (presumably the blank symbol of the network output -- TODO confirm).
letters = "NACGT"

def read_probs(filename):
    """
    Read a probability table from a text file.

    Each non-blank line holds comma-separated numbers and becomes one list of
    floats in the result. Blank lines are skipped, so the last row is kept
    whether or not the file ends with a newline.

    :param filename: path of the file to read
    :return: list of rows, each row a list of floats
    :raises ValueError: if a field cannot be parsed as a float
    """
    with open(filename) as file:
        # float() tolerates the trailing newline left on the last field.
        return [[float(value) for value in line.split(",")]
                for line in file if line.strip()]


def prefix_sums(probs):
    """
    Turn every row of ``probs`` into its running totals.

    Only the first five entries of each row are used; entry ``k`` of an output
    row is the sum of entries ``0..k`` of the corresponding input row.
    """
    cumulative_rows = []
    for row in probs:
        running = 0
        totals = []
        for column in range(5):
            running += row[column]
            totals.append(running)
        cumulative_rows.append(totals)
    return cumulative_rows


def sample(prefix):
    """
    Draw one character per row of cumulative sums in ``prefix``.

    For each row a uniform random number is drawn and the first of the five
    columns whose cumulative value exceeds it selects the character from
    ``letters``; if none does, index 0 is used.
    """
    drawn = []
    for cumulative in prefix:
        threshold = random.random()
        chosen = next(
            (column for column in range(5) if threshold < cumulative[column]),
            0,
        )
        drawn.append(letters[chosen])
    return "".join(drawn)


def duplicates(s):
    """
    Collapse every run of identical consecutive characters in ``s`` to a
    single character. An empty string is returned unchanged.
    """
    if s == "":
        return ""
    kept = [s[0]]
    # A character starts a new run exactly when it differs from its predecessor.
    for previous, current in zip(s, s[1:]):
        if current != previous:
            kept.append(current)
    return "".join(kept)

def shorten(s):
    """
    Split ``s`` on "N", collapse repeated characters inside every piece with
    ``duplicates`` and concatenate the pieces (the "N" separators are dropped).
    """
    pieces = []
    for segment in s.split("N"):
        pieces.append(duplicates(segment))
    return "".join(pieces)


def setup_caller():
    """
    Build a ``deepnano2.Caller`` for the "96" network, loading its weights
    file from the ``weights`` directory inside the deepnano2 package.
    """
    network_type = "96"
    weights_name = "rnn%s.txt" % network_type
    package_dir = deepnano2.__path__[0]
    weights = os.path.join(package_dir, "weights", weights_name)
    # Trailing positional arguments: beam size 1, beam cut threshold 0.01.
    return deepnano2.Caller(network_type, weights, 1, 0.01)


def mapping(mapping_file):
    """
    Load a tab-separated mapping file into a dictionary.

    The first line is treated as a header and skipped. For every other
    non-blank line the first column becomes the key and the second column the
    value. Only the line terminator is stripped, so the value is intact even
    when the last line has no trailing newline or further columns follow.

    :param mapping_file: path of the tab-separated file
    :return: dict mapping first-column values to second-column values
    :raises IndexError: if a non-blank line has fewer than two columns
    """
    result = {}
    with open(mapping_file) as f:
        next(f, None)  # skip the header line (if any)
        for line in f:
            if not line.strip():
                continue
            fields = line.rstrip("\n").split("\t")
            result[fields[0]] = fields[1]
    return result


def generate_probs(fast5_dir, probs_dir, mapping_file):
    """
    For every read listed in ``mapping_file``, load its raw signal from the
    mapped fast5 file in ``fast5_dir``, rescale it and let the caller save its
    output to ``<probs_dir>/<read id>.txt``.
    """
    read_to_file = mapping(mapping_file)
    basecaller = setup_caller()
    for read_id, fast5_name in read_to_file.items():
        fast5_path = "{}/{}".format(fast5_dir, fast5_name)
        probs_path = "{}/{}.txt".format(probs_dir, read_id)
        with get_fast5_file(fast5_path) as f5:
            raw = f5.get_read(read_id).get_raw_data()
            basecaller.compute_save(rescale_signal(raw), probs_path)


def generate_samples(output, probs_dir, n):
    """
    Sample ``n`` sequences for every probability file in ``probs_dir`` and
    write them to ``output`` in FASTA format.

    Each record id is ``<probability file path without extension>-<sample index>``.

    :param output: path of the FASTA file to create (overwritten if present)
    :param probs_dir: directory whose files are read with ``read_probs``
    :param n: number of sequences to sample per probability file
    """
    reads = glob.glob("{}/*".format(probs_dir))
    # The context manager closes the output even if reading or sampling fails.
    with open(output, "w") as fasta:
        writer = Writer(fasta)
        for r in reads:
            # Strip only the final extension: splitting on the first "." broke
            # ids whenever the directory part contained a dot (e.g. "./probs").
            # The directory prefix stays part of the id, as before.
            read_id = os.path.splitext(r)[0]
            prefix = prefix_sums(read_probs(r))
            for i in range(n):
                writer.writefasta(FastaSequence(id_='{}-{}'.format(read_id, i),
                                                sequence=shorten(sample(prefix))))

def generate_reads(fast5_dir, output, mapping_file):
    """
    Call every read listed in ``mapping_file`` and write the results to
    ``output`` as ``>read id`` / sequence records separated by a blank line.

    :param fast5_dir: directory containing the fast5 files named in the mapping
    :param output: path of the file to create (overwritten if present)
    :param mapping_file: tab-separated file parsed by ``mapping``
    """
    mapping_ = mapping(mapping_file)
    caller = setup_caller()
    # The context manager closes the output even if a read fails to load or call.
    with open(output, "w") as fasta:
        for read_id, fast5_name in mapping_.items():
            with get_fast5_file("{}/{}".format(fast5_dir, fast5_name)) as f5:
                signal = rescale_signal(f5.get_read(read_id).get_raw_data())
                fasta.write(">{}\n{}\n\n".format(read_id, caller.call_raw_signal(signal)))


# Command-line entry point: optionally generate probability files, then
# optionally sample sequences and/or call reads, aligning each output to the
# reference when one is given.
if __name__ == '__main__':
    def _parse_bool(value):
        """Return False for explicit false-like strings, True for anything else."""
        return value.strip().lower() not in ('', '0', 'false', 'f', 'no', 'n')

    parser = argparse.ArgumentParser()
    parser.add_argument('--fast5dir', '-f5', help='Read directory', required=True)
    parser.add_argument('--probsdir', help='Directory with files containing probabilities.', default='probs')
    # Previously bool() was applied to the raw string, so "--skip-probs False"
    # still skipped generation. The value is now parsed, and a bare
    # "--skip-probs" (no value) means True.
    parser.add_argument('--skip-probs', help='If False, probabilities will be generated in probsdir.',
                        nargs='?', const=True, default=False, type=_parse_bool)
    parser.add_argument('--mapping', '-map', help='Input file containing mapping reads to fas5 files',
                        default='files.txt')
    parser.add_argument('--reference', '-ref', help='Reference fasta file')

    # type=int makes argparse reject non-numeric values with a usage error.
    parser.add_argument('--number', '-n', help='Number of samples to generate for one read', default=10,
                        type=int)
    parser.add_argument('--sample', '-s', help='Output file for sampling')
    parser.add_argument('--generate-reads', help='Output file for reads generation')
    args = parser.parse_args()

    fast5 = args.fast5dir
    if not os.path.isdir(fast5):
        sys.exit("Directory {} does not exist".format(fast5))

    probs_dir = args.probsdir
    # makedirs also creates missing parent directories and tolerates the
    # directory already existing.
    os.makedirs(probs_dir, exist_ok=True)

    files_mapping = args.mapping
    if not os.path.isfile(files_mapping):
        sys.exit("File {} does not exist".format(files_mapping))

    n = args.number

    generate_probs_ = not args.skip_probs
    generate_samples_ = args.sample
    generate_reads_ = args.generate_reads

    reference = args.reference
    if reference is not None and not os.path.isfile(reference):
        sys.exit("File {} does not exist".format(reference))

    if generate_probs_:
        generate_probs(fast5_dir=fast5, probs_dir=probs_dir, mapping_file=files_mapping)

    if generate_samples_ is not None:
        generate_samples(probs_dir=probs_dir, output=generate_samples_, n=n)
        if reference is not None:
            align_fasta(fastas=[generate_samples_], reference=reference)

    if generate_reads_ is not None:
        generate_reads(mapping_file=files_mapping, output=generate_reads_, fast5_dir=fast5)
        if reference is not None:
            align_fasta(fastas=[generate_reads_], reference=reference)


