import os
import subprocess
import tempfile

import numpy as np

from nadavca.read import Read

def _signal_boundaries(read):
    """Map every base index of ``read`` to raw-signal positions on either side.

    Keys of ``read.sequence_to_signal_mapping`` are used as base indices and
    its values as raw-signal positions.

    Returns:
        ``(separator_left, separator_right)`` integer numpy arrays.

        * ``separator_right`` has length ``len(read.sequence)``. Entry ``i``
          is the position mapped to base ``i``. If base ``i`` is unmapped, it
          is back-filled from the next entry; the final default is
          ``len(read.raw_signal)``.
        * ``separator_left`` has length ``len(read.sequence) + 1``, so that an
          exclusive end index equal to ``len(read.sequence)`` can be looked
          up. Entry ``i`` is the position mapped to base ``i``. Zero entries
          are forward-filled from the previous entry; the initial default
          is 0.
    """
    seq_len = len(read.sequence)
    signal_len = len(read.raw_signal)
    separator_right = np.full(seq_len, signal_len, dtype=int)
    separator_left = np.full(seq_len + 1, 0, dtype=int)
    for key, val in read.sequence_to_signal_mapping.items():
        separator_right[key] = val
        separator_left[key] = val

    # Forward-fill: an entry still at its default (0) inherits the previous one.
    for i in range(1, len(separator_left)):
        if separator_left[i] == 0:
            separator_left[i] = separator_left[i - 1]

    # Back-fill: an entry still at its default (signal_len) inherits the next one.
    for i in range(len(separator_right) - 2, -1, -1):
        if separator_right[i] == signal_len:
            separator_right[i] = separator_right[i + 1]

    return separator_left, separator_right


def _tantan_repeats(sequence, gap_cost, mismatch_cost, max_period):
    """Run ``tantan -f 4`` on ``sequence`` and return its repeats as (start, end) pairs.

    The second and third whitespace-separated columns of each output line are
    taken as the repeat's start/end base coordinates. Lines with fewer than
    three columns are ignored. End is presumably exclusive (BED-style) —
    confirm against the tantan docs.

    Raises:
        FileNotFoundError: if the ``tantan`` executable cannot be found.
        subprocess.CalledProcessError: if tantan exits with non-zero status.
    """
    # A private temporary directory avoids name clashes between concurrent
    # runs and is removed automatically, unlike a fixed file name in the CWD.
    with tempfile.TemporaryDirectory() as tmp_dir:
        fasta_path = os.path.join(tmp_dir, 'read.fasta')
        with open(fasta_path, 'w') as fasta:
            fasta.write('>temporary\n')
            fasta.write(sequence)
            fasta.write('\n')
        # Argument list (no shell) so paths/values are never shell-interpreted.
        # Only stdout is parsed, so tantan's diagnostics on stderr cannot be
        # mistaken for repeat lines. check=True surfaces failures.
        result = subprocess.run(
            ['tantan', '-f', '4',
             '-b', str(gap_cost),
             '-j', str(mismatch_cost),
             '-w', str(max_period),
             fasta_path],
            stdout=subprocess.PIPE,
            universal_newlines=True,
            check=True,
        )

    repeats = []
    for line in result.stdout.splitlines():
        tokens = line.split()
        if len(tokens) < 3:
            continue
        repeats.append((int(tokens[1]), int(tokens[2])))
    return repeats


def _aligned_signal_span(alignment_file):
    """Return ``(min start, max end)`` over the data rows of ``alignment_file``.

    The first two lines are skipped (presumably a header — confirm). Each
    remaining non-blank line must hold exactly three integers. The second
    and third are taken as start/end and used as raw-signal indices by the
    caller. Blank lines are ignored.

    Raises:
        ValueError: if a row is malformed or there are no data rows.
    """
    aligned_start = aligned_end = None
    with open(alignment_file, 'r') as file:
        file.readline()
        file.readline()
        for line in file:
            if not line.strip():
                continue
            _, start, end = map(int, line.split())
            if aligned_start is None:
                aligned_start, aligned_end = start, end
            else:
                aligned_start = min(aligned_start, start)
                aligned_end = max(aligned_end, end)
    if aligned_start is None:
        raise ValueError(
            'alignment file {!r} contains no aligned rows'.format(alignment_file))
    return aligned_start, aligned_end


def predict_by_basecall(read_file, alignment_file, gap_cost, mismatch_cost, max_period):
    """Mark raw-signal samples covered by tandem repeats in the basecalled read.

    Steps:

    1. Load the read from ``read_file`` (basecall group
       ``Analyses/Basecall_1D_000``) and run tantan on its sequence.
    2. Set to 1.0 every raw-signal sample between the signal boundaries of
       each reported repeat; all other samples are 0.0.
    3. Crop the result to the span given by ``alignment_file``.

    Args:
        read_file: Path passed to ``Read.load_from_fast5``.
        alignment_file: Text file; two leading lines are skipped, then rows
            of three integers. The minimum of the 2nd column and the maximum
            of the 3rd column bound the returned slice.
        gap_cost: Passed to tantan as ``-b``.
        mismatch_cost: Passed to tantan as ``-j``.
        max_period: Passed to tantan as ``-w``.

    Returns:
        1-D float numpy array ``signal_map[aligned_start:aligned_end]``.

    Raises:
        FileNotFoundError: if the ``tantan`` executable cannot be found.
        subprocess.CalledProcessError: if tantan exits with non-zero status.
        ValueError: if the alignment file has a malformed row or no data rows.
    """
    read = Read.load_from_fast5(read_file, "Analyses/Basecall_1D_000")
    separator_left, separator_right = _signal_boundaries(read)

    signal_map = np.zeros(len(read.raw_signal), dtype=float)
    repeats = _tantan_repeats(''.join(read.sequence), gap_cost, mismatch_cost, max_period)
    for start, end in repeats:
        # Repeat covers bases [start, end): from the signal position of the
        # first mapped base at/after `start` to that of the last at/before `end`.
        signal_map[separator_right[start]:separator_left[end]] = 1.0

    aligned_start, aligned_end = _aligned_signal_span(alignment_file)
    return signal_map[aligned_start:aligned_end]
