import sys
import os
from nadavca.read import Read
from scipy.special import softmax
import numpy as np
import signal_dtw
import time

# The script takes exactly eight positional arguments (plus the program name).
# On misuse, report the usage on stderr and exit with a non-zero status so that
# calling shells / pipelines can tell the invocation failed.
if len(sys.argv) != 9:
    sys.stderr.write("usage: {} reads_dir alignments_dir output_dir min_dist max_dist bonus max_speed_ratio window_size\n".format(sys.argv[0]))
    sys.exit(1)


# Positional command-line arguments, in the order listed by the usage message.
# The three directory arguments stay as strings.
reads_dir, alignments_dir, output_dir = sys.argv[1:4]
# min_dist / max_dist bound the per-row column range built for each read.
min_dist, max_dist = int(sys.argv[4]), int(sys.argv[5])
# bonus is handed to signal_dtw.SimpleScorer.
bonus = float(sys.argv[6])
# max_speed_ratio goes to signal_dtw.MovementModel, window_size to
# signal_dtw.local_alignment.
max_speed_ratio, window_size = int(sys.argv[7]), int(sys.argv[8])

# Create the output directory; os.mkdir fails if it already exists or if its
# parent directory is missing.
os.mkdir(output_dir)

# Running totals across all reads, used for the throughput message:
# seconds spent inside the alignment call and number of signal samples aligned.
total_time = total_signals = 0

# For every read file in reads_dir: load and normalise the read, cut out the
# part of the signal covered by its alignment file, locally align that signal
# against itself, and write a 0/1 mask (one value per line) to output_dir.
# Files are visited in sorted order so that logs are reproducible
# (os.listdir returns entries in arbitrary order).
for read_filename in sorted(os.listdir(reads_dir)):
    print("processing {}".format(read_filename))
    basename = os.path.splitext(read_filename)[0]
    # The alignment input and the mask output share this "<basename>.txt" name.
    other_filename = basename + '.txt'

    read = Read.load_from_fast5(os.path.join(reads_dir, read_filename), 'Analyses/Basecall_1D_000')
    Read.normalize_reads([read])

    # Find the smallest event start and largest event end over all alignment
    # records; None means no record has been seen yet.
    aligned_start, aligned_end = None, None
    alignment_path = os.path.join(alignments_dir, other_filename)
    with open(alignment_path, 'r') as alfile:
        # The first two lines are skipped unparsed
        # (presumably header lines -- TODO confirm against the file format).
        alfile.readline()
        alfile.readline()
        for line in alfile:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            # Each record is three integers; the first one is not needed here.
            _, event_start, event_end = map(int, line.split())
            if aligned_start is None:
                aligned_start, aligned_end = event_start, event_end
            else:
                aligned_start = min(aligned_start, event_start)
                aligned_end = max(aligned_end, event_end)

    if aligned_start is None:
        # Without any record there is no signal range to slice; skip the read
        # instead of failing on a non-integer slice bound.
        sys.stderr.write("skipping {}: no alignment records in {}\n".format(read_filename, alignment_path))
        continue

    signal = read.normalized_signal[aligned_start : aligned_end]

    length = len(signal)

    # Per-row column range handed to the aligner: row i gets columns from
    # i + min_dist up to i + max_dist + 1, capped at length + 1.
    # NOTE(review): whether the end bound is exclusive, and why the cap is
    # length + 1 rather than length, depends on signal_dtw -- confirm.
    n_rows = length - min_dist
    row_starts = [i + min_dist for i in range(n_rows)]
    row_ends = [min(i + max_dist + 1, length + 1) for i in range(n_rows)]

    # The signal is aligned against itself.
    scorer = signal_dtw.SimpleScorer(signal, signal, bonus)
    movement_model = signal_dtw.MovementModel(max_speed_ratio, length)

    # Only the alignment call itself is timed.
    start = time.time()
    alignments = signal_dtw.local_alignment(length, length, scorer, movement_model, row_starts, row_ends, window_size)
    end = time.time()
    elapsed = end - start
    total_time += elapsed
    total_signals += length
    # The accumulated time can be exactly zero (coarse timer resolution or
    # trivial inputs); avoid a ZeroDivisionError in that case.
    average = total_signals / total_time if total_time > 0 else float('inf')
    sys.stderr.write("processed in {} (overall average {} signals per second)\n".format(elapsed, average))

    # Mask over the sliced signal: 1.0 where an alignment was detected.
    detected = np.zeros(length, dtype=float)

    for al in alignments:
        # Mark from the first element of the alignment's first entry up to
        # (not including) the last element of its last entry.
        # NOTE(review): presumably each entry is an (i, j) index pair -- confirm.
        detected[al[0][0] : al[-1][-1]] = 1.0

    out_path = os.path.join(output_dir, other_filename)
    with open(out_path, 'w') as outfile:
        outfile.write('\n'.join(map(str, detected)))
