#!/usr/bin/python3.8

import sys, os, glob
from utils import ref, fetch_reads, write_vcf, alphabet, filter_positions
import argparse


class Pileup:
    """Class for variant calling using pileup approach"""

    def __init__(self, reference, reads, absolute, relative, vcffile, coverage, ignore_deletions):
        self._reference = reference
        self._reads = reads
        self._relative = relative
        self._absolute = absolute
        self._vcffile = vcffile
        self._coverage = coverage
        self.ignore_deletions = ignore_deletions
        self._positions = filter_positions(self._coverage, self._relative, self._absolute, self._reference)
        self.find_variants()

    def create_sequences(self):
        high_support = lambda p, base: base in self._coverage[p].keys() and self._coverage[p][
            base] >= self._absolute and self._coverage[p][base] / self._coverage[p]['sum'] >= self._relative and base != \
                                       self._reference[p] and base != 'sum'
        positions = [(p, base) for p in self._positions for base in self._coverage[p] if high_support(p, base)]
        sequences = dict()
        for p, base1 in positions:
            value = False
            for base2 in alphabet:
                if (p - 1, base2) in sequences.keys():
                    sequences[(p, base1)] = sequences[(p - 1, base2)] + [(p, base1)]
                    sequences.pop((p - 1, base2))
                    value = True
            if not value:
                sequences[(p, base1)] = [(p, base1)]
        return [sequences[k] for k in sequences.keys()]

    def ref_alt(self, variant):
        alt_ = ''.join([base for p, base in variant]).replace('D', '')
        ref_ = ''.join([self._reference[p] for p, base in variant])
        if len(alt_) == 0:
            alt_ = self._reference[variant[0][0] - 1]
            ref_ = "{}{}".format(self._reference[variant[0][0] - 1], ref_)
        return ref_, alt_

    def variant_format(self, variant):
        chrom = self._reads[0].reference_name
        id_ = '.'
        ref_, alt = self.ref_alt(variant)
        pos = variant[0][0] - 1 if len(ref_) > len(alt) else variant[0][0]
        qual = 40
        supports = [self._coverage[p][base] for p, base in variant]
        qualities = [self._coverage[p][base] / self._coverage[p]['sum'] for p, base in variant]
        info = '{};{};{};{}'.format(str(min(supports)), str(sum(supports) / len(variant)), str(min(qualities)),
                                    str(sum(qualities) / len(variant)))
        return {'CHROM': chrom, 'POS': pos, 'ID': id_, 'REF': ref_, 'ALT': alt, 'QUAL': qual, 'INFO': info}

    def find_variants(self):
        is_deletion = lambda v: len(v['REF']) > len(v['ALT']) and not self.ignore_deletions
        formatted = [self.variant_format(s) for s in self.create_sequences()]
        filtered = [v for v in formatted if not is_deletion]
        write_vcf(self._vcffile, filtered, False)


def run_pileup(reference, reads, relative, absolute, vcffile, coverage=None, ignore_deletions=False):
    """Main function for running pileup algorithm"""
    Pileup(reference, reads, relative, absolute, vcffile, coverage, ignore_deletions).find_variants()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--reference', '-ref', help='Reference fasta file', metavar='in')
    parser.add_argument('--directory', '-dir', help='Read directory')
    parser.add_argument('--absolute', help='Number of supportive reads, necessary for variant to be reported',
                        default=100)
    parser.add_argument('--relative', help='Fraction of supportive reads, necessary fo variant to be reported',
                        default=0.3)
    parser.add_argument('--output', '-out', default='output.vcf')
    parser.add_argument('--ignore_deletions', default='False')
    args = parser.parse_args()

    reference = args.reference
    if not os.path.isfile(reference):
        sys.exit("File {} does not exist".format(reference))
    reference = ref(reference)

    reads_dir = args.directory
    if not os.path.isdir(reads_dir):
        sys.exit("Directory {} does not exist".format(reads_dir))

    ignore_deletions = args.ignore_deletions == 'True'

    files = glob.glob("{}/*.bam".format(reads_dir))

    run_pileup(reference=reference, reads=fetch_reads(files), relative=float(args.relative),
               absolute=int(args.absolute), vcffile=args.output, ignore_deletions=ignore_deletions)
