import argparse
import sys
import os
from scipy.special import softmax
import numpy as np
import yaml
import hornet_hmm
import time
from peak_filter import peak_filter



def build_model(yaml_path):
    """Build the HMM described by a YAML model configuration file.

    The configuration must provide ``min_period``, ``max_period``, ``decay``,
    ``p_nucleotide_stay_background``, ``p_nucleotide_stay_repeat``,
    ``p_nucleotide_is_indel``, ``match_matrix``, ``priors`` and a ``type``
    key selecting the model class:

    * ``'simple'`` -> ``hornet_hmm.SimpleModel``
    * ``'full'``   -> ``hornet_hmm.FullModel``; this additionally requires
      ``p_linger_nucleotide``, ``p_linger_blank`` and
      ``nucleotide_distribution``.

    Args:
        yaml_path: path to the YAML configuration file.

    Returns:
        A ``(model, priors)`` tuple. ``priors`` is the raw ``priors`` entry
        of the configuration, returned so the caller can hand it to the
        decoder.

    Raises:
        KeyError: if a required configuration key is missing.

    On an unknown ``type`` the function writes a message to stderr and
    exits the process with status 1.
    """
    with open(yaml_path, 'r') as file:
        # The config is plain YAML data, so safe_load is sufficient.
        # yaml.load() without an explicit Loader is unsafe, and on
        # PyYAML >= 6 it raises TypeError.
        config = yaml.safe_load(file)

    model_type = config['type']
    min_period = config['min_period']
    max_period = config['max_period']
    decay = config['decay']
    p_nucleotide_stay_background = config['p_nucleotide_stay_background']
    p_nucleotide_stay_repeat = config['p_nucleotide_stay_repeat']
    p_nucleotide_is_indel = config['p_nucleotide_is_indel']
    match_matrix = config['match_matrix']
    priors = config['priors']

    if model_type == 'simple':
        # NOTE: SimpleModel receives `priors` positionally; FullModel
        # (below) does not. Both branches still return priors to the caller.
        model = hornet_hmm.SimpleModel(min_period,
                                       max_period,
                                       decay,
                                       priors,
                                       p_nucleotide_stay_background,
                                       p_nucleotide_stay_repeat,
                                       p_nucleotide_is_indel,
                                       match_matrix)
        return model, priors
    elif model_type == 'full':
        p_linger_nucleotide = config['p_linger_nucleotide']
        p_linger_blank = config['p_linger_blank']
        nucleotide_distribution = config['nucleotide_distribution']
        model = hornet_hmm.FullModel(min_period,
                                     max_period,
                                     decay,
                                     p_nucleotide_stay_background,
                                     p_nucleotide_stay_repeat,
                                     p_nucleotide_is_indel,
                                     p_linger_nucleotide,
                                     p_linger_blank,
                                     nucleotide_distribution,
                                     match_matrix)
        return model, priors
    else:
        # sys.stderr (not the nonexistent sys.err). The message reports the
        # offending value and lists every supported type.
        sys.stderr.write(
            "Unknown model type: {}; use one of 'simple', 'full'\n".format(model_type))
        sys.exit(1)



# Command-line interface. Parsing happens at import time because this file
# is run as a script.
parser = argparse.ArgumentParser(
    description="Decode CTC predictions (.signal files) with a hornet_hmm model.")

parser.add_argument("input", help="CTC predictions (.signal) for a read  or a directory containing read CTC predictions")
parser.add_argument("-o", "--output", help="output file or directory (defaults to stdout)")
# NOTE: the default model path is relative to the current working
# directory, not to this script's location.
parser.add_argument("-m", "--model",
                    help="a YAML specifying model to be used",
                    default=os.path.join("models", "simple.yaml"))
parser.add_argument("-d", "--decoder",
                    help="HMM decoder to use: 'forward-backward' (default) or 'viterbi'",
                    default="forward-backward")
parser.add_argument("-p", "--preprocess",
                    help="preprocess CTC predictions with peak filter",
                    action="store_true")

args = parser.parse_args()

# Build the model first: an invalid model config exits before the decoder
# option is checked.
model, priors = build_model(args.model)

# Instantiate the decoder selected by --decoder.
if args.decoder == 'forward-backward':
    decoder = hornet_hmm.ForwardBackwardDecoder()
elif args.decoder == 'viterbi':
    decoder = hornet_hmm.ViterbiDecoder()
else:
    # sys.stderr (not the nonexistent sys.err). The rejected value is
    # included in the message.
    sys.stderr.write(
        "Unknown decoder option: {}; use one of 'forward-backward', 'viterbi'\n".format(args.decoder))
    sys.exit(1)

# Pair every input .signal file with its output path. None means "write to
# stdout".
input_paths = []
output_paths = []

if os.path.isdir(args.input):
    if args.output:
        # makedirs(exist_ok=True): re-running into an existing output
        # directory, or one whose parents do not exist yet, must not crash.
        os.makedirs(args.output, exist_ok=True)
    # os.listdir order is arbitrary. Sort so files are processed (and
    # logged) in a reproducible order.
    for filename in sorted(os.listdir(args.input)):
        basename, extension = os.path.splitext(filename)
        if extension != '.signal':
            continue
        input_paths.append(os.path.join(args.input, filename))
        # Each output is named after its input's basename, with a .txt
        # extension, inside the output directory.
        output_paths.append(
            os.path.join(args.output, basename + '.txt') if args.output else None)
else:
    input_paths.append(args.input)
    output_paths.append(args.output if args.output else None)


# Decode every input file and write one prediction per output line.
# Decoding time and signal counts are accumulated to report an overall rate.
total_time = 0.0
total_signals = 0
for input_path, output_path in zip(input_paths, output_paths):
    sys.stderr.write("Processing {}\n".format(input_path))
    sys.stderr.write("Loading...\r")
    # Each line holds one row of whitespace-separated logits. Softmax is
    # applied per row (axis=1).
    with open(input_path, 'r') as in_file:
        logits = [list(map(float, line.split())) for line in in_file]
    if not logits:
        # An empty file yields a 1-D array, and softmax(axis=1) would raise.
        # Skip it rather than aborting the remaining files.
        sys.stderr.write("Skipping {}: no predictions\n".format(input_path))
        continue
    logits = np.array(logits)
    probs = softmax(logits, axis=1)

    if args.preprocess:
        probs = peak_filter(probs)

    sys.stderr.write("Decoding HMM...\n")
    # perf_counter is monotonic and high-resolution. time.time() can report
    # a zero interval for fast decodes, which broke the rate division below.
    start = time.perf_counter()
    result = decoder.Decode(probs, priors, model)
    end = time.perf_counter()
    total_time += end - start
    total_signals += len(probs)
    rate = total_signals / total_time if total_time > 0 else float('inf')
    sys.stderr.write("Done in {} seconds (overall average {} signals per second)\n".format(end - start, rate))

    sys.stderr.write("Saving output...\r")
    text = ''.join('{}\n'.format(prediction) for prediction in result)
    if output_path is None:
        sys.stdout.write(text)
    else:
        # `with` guarantees the file is closed even if the write fails.
        with open(output_path, 'w') as out_file:
            out_file.write(text)
