#!/usr/bin/python3.8

import argparse
import glob
import os
import sys
from utils import *


class Freebayes:
    """Variant caller using a pileup approach over padded reference windows.

    A window of ``padding`` bases on each side is opened around every
    suspicious position chosen by ``filter_positions``. For each read, the
    window is widened until both edges sit on bases where the read agrees with
    the reference. The read sequence between those two anchors is then counted
    as one allele observation. Alleles that differ from the reference and pass
    the ``absolute`` / ``relative`` thresholds are written with ``write_vcf``.
    """

    def __init__(self, reference, reads, padding, absolute, relative, vcffile, coverage,
                 prefilter_absolute=10, prefilter_relative=0.1):
        """Prepare the caller and select the suspicious positions.

        Args:
            reference: reference sequence, indexed by reference position.
            reads: indexable collection of aligned reads. Each read must provide
                ``seq``, ``get_aligned_pairs()`` and ``reference_name``
                (presumably pysam ``AlignedSegment`` objects).
            padding: number of reference bases added on each side of a
                suspicious position.
            absolute: minimum number of supporting reads for a reported variant.
            relative: minimum supporting fraction (support / maximum depth in
                the window) for a reported variant.
            vcffile: destination passed to ``write_vcf``.
            coverage: coverage table (``coverage[pos]['sum']`` is the depth), or
                ``None`` to compute it with ``count_coverage(reads)``.
            prefilter_absolute: ``absolute`` threshold given to
                ``filter_positions`` when picking suspicious positions.
            prefilter_relative: ``relative`` threshold given to
                ``filter_positions`` when picking suspicious positions.
        """
        self._reads = reads
        self._reference = reference
        self._padding = padding
        self._absolute = absolute
        self._relative = relative
        self._coverage = coverage if coverage is not None else count_coverage(reads)
        # The prefilter is intentionally independent of the reporting thresholds.
        self._suspicious_positions = filter_positions(coverage=self._coverage,
                                                      reference=self._reference,
                                                      absolute=prefilter_absolute,
                                                      relative=prefilter_relative)
        # (ref_low, ref_high) -> {alt_sequence: number of supporting reads}
        self._count = {}
        self._vcffile = vcffile

    def begin_end_indexes(self, read):
        """Return indices into ``read.get_aligned_pairs()`` of the first and last
        pairs that have a reference position.

        Raises IndexError if no pair is aligned to the reference.
        """
        pairs = read.get_aligned_pairs()
        aligned = [i for i, (_, ref_pos) in enumerate(pairs) if ref_pos is not None]
        return aligned[0], aligned[-1]

    def mapping(self, read):
        """Return ``{reference_position: query_position}`` for the aligned bases of *read*."""
        return {ref_pos: query_pos for query_pos, ref_pos in read.get_aligned_pairs()
                if query_pos is not None and ref_pos is not None}

    def find_ranges(self, read):
        """Return the padded windows around suspicious positions, anchored for *read*.

        Each window is widened outwards until both of its edges are reference
        positions where the read base equals the reference base. Overlapping
        windows are merged.

        Returns:
            A list of ``(ref_low, ref_high, query_low, query_high)`` tuples, sorted
            by ``ref_low``. ``reference[ref_low:ref_high]`` and
            ``read.seq[query_low:query_high]`` are the REF/ALT pair. The result is
            empty when the read matches the reference nowhere.
        """
        ref_to_query = self.mapping(read)
        matching = {ref_pos for ref_pos, query_pos in ref_to_query.items()
                    if read.seq[query_pos] == self._reference[ref_pos]}
        if not matching:
            return []
        first, last = min(matching), max(matching)

        def anchor(low, high):
            # Clamp into the matched span, then widen to the nearest matching bases.
            # Termination is guaranteed because `first` and `last` are matching.
            low = max(low, first)
            high = min(high, last)
            while low not in matching:
                low -= 1
            while high not in matching:
                high += 1
            return low, high

        # Sorting makes the merge below independent of the iteration order of
        # self._suspicious_positions.
        windows = sorted(anchor(pos - self._padding, pos + self._padding + 1)
                         for pos in self._suspicious_positions if first <= pos < last)

        merged = []
        for low, high in windows:
            if merged and low < merged[-1][1]:
                # Overlaps the previous window (windows are sorted by low): extend it.
                merged[-1] = (merged[-1][0], max(merged[-1][1], high))
            else:
                merged.append((low, high))
        return [(low, high, ref_to_query[low], ref_to_query[high]) for low, high in merged]

    def process_read(self, read):
        """Add one observation per anchored window of *read* to ``self._count``."""
        for ref_low, ref_high, query_low, query_high in self.find_ranges(read):
            # ref_low < ref_high always holds, so the query slice is well defined.
            alt = read.seq[query_low:query_high]
            alleles = self._count.setdefault((ref_low, ref_high), {})
            alleles[alt] = alleles.get(alt, 0) + 1

    def variant_format(self, variant):
        """Turn ``(low, high, alt, support)`` into a VCF record dict.

        INFO carries ``S`` (supporting reads) and ``Q`` (support divided by the
        maximum coverage depth inside ``[low, high)``). ``Q`` is 0.0 when no depth
        is known for the window.
        """
        low, high, alt, support = variant
        depths = [self._coverage[pos]['sum'] for pos in range(low, high) if pos in self._coverage]
        max_depth = max(depths, default=0)
        fraction = support / max_depth if max_depth else 0.0
        info = 'S={};Q={}'.format(str(support), str(fraction))
        # NOTE(review): POS is the 0-based reference position; VCF POS is 1-based.
        # Confirm that write_vcf performs the conversion.
        return {'CHROM': self._reads[0].reference_name, 'POS': low, 'ID': '.',
                'REF': self._reference[low:high], 'ALT': alt, 'QUAL': 40, 'INFO': info}

    def valid_variant(self, variant):
        """Return True when ALT differs from REF and S/Q meet the reporting thresholds.

        S and Q are parsed back from the INFO string built by ``variant_format``.
        """
        info = variant['INFO'].split(';')
        support = int(info[0].split('=')[1])
        fraction = float(info[1].split('=')[1])
        return (variant['REF'] != variant['ALT']
                and support >= self._absolute
                and fraction >= self._relative)

    def find_variants(self):
        """Count alleles over all reads, keep the valid ones and write them via ``write_vcf``.

        Counts are reset at the start of each call, so repeated calls do not
        accumulate support.
        """
        self._count = {}
        for read in self._reads:
            self.process_read(read)
        candidates = [self.variant_format((low, high, alt, support))
                      for (low, high), alleles in self._count.items()
                      for alt, support in alleles.items()]
        write_vcf(self._vcffile, [v for v in candidates if self.valid_variant(v)], False)


def run_freebayes(reference, reads, padding, relative, absolute, vcffile, coverage=None):
    """Run the Freebayes-style caller end to end.

    Builds a ``Freebayes`` instance from the given settings and immediately calls
    ``find_variants`` on it. ``find_variants`` writes the accepted variants to
    *vcffile*. When *coverage* is ``None``, the caller computes coverage itself.
    """
    caller = Freebayes(
        reference=reference,
        reads=reads,
        padding=padding,
        relative=relative,
        absolute=absolute,
        vcffile=vcffile,
        coverage=coverage,
    )
    caller.find_variants()


if __name__ == '__main__':
    # Command-line entry point: validate inputs, load the reference and reads,
    # then run the caller.
    parser = argparse.ArgumentParser()
    parser.add_argument('--reference', '-ref', help='Reference fasta file', required=True, metavar='in')
    parser.add_argument('--directory', '-dir', help='Read directory', required=True)
    # type= lets argparse report malformed numbers as usage errors instead of tracebacks.
    parser.add_argument('--absolute', type=int, default=100,
                        help='Number of supportive reads, necessary for variant to be reported')
    parser.add_argument('--relative', type=float, default=0.3,
                        help='Fraction of supportive reads, necessary for variant to be reported')
    parser.add_argument('--output', '-out', default='output.vcf', help='Output VCF file')
    parser.add_argument('--padding', type=int, default=3,
                        help='Number of reference bases added on each side of a suspicious position')
    args = parser.parse_args()

    if args.padding < 0:
        parser.error('--padding must be non-negative')

    if not os.path.isfile(args.reference):
        sys.exit("File {} does not exist".format(args.reference))
    reference = ref(args.reference)

    reads_dir = args.directory
    if not os.path.isdir(reads_dir):
        sys.exit("Directory {} does not exist".format(reads_dir))

    # Sorted so the order of files handed to fetch_reads is reproducible across runs.
    files = sorted(glob.glob(os.path.join(reads_dir, '*.bam')))
    if not files:
        sys.exit("Directory {} contains no .bam files".format(reads_dir))

    reads = fetch_reads(files)

    run_freebayes(reference=reference, reads=reads, padding=args.padding, relative=args.relative,
                  absolute=args.absolute, vcffile=args.output)
