
import pysam
import fastaparser, os

# Default data directory name and reference FASTA file name.
# NOTE(review): neither is referenced by the code visible in this section;
# presumably they are consumed by callers elsewhere - confirm before changing.
directory = "data-covid"
fasta_ = "reference.fasta"

# Allele symbols. "D" is the deletion marker emitted by count_coverage();
# "I" presumably stands for insertion - TODO confirm against the consumers.
alphabet = ["A", "C", "G", "T", "D", "I"]
# Plain nucleotide symbols only (note the A, C, T, G ordering).
short_alphabet = ["A", "C", "T", "G"]


def fetch_reads(filenames):
    """Load every alignment record from the given alignment files.

    Args:
        filenames: Iterable of paths that ``pysam.AlignmentFile`` can open.

    Returns:
        One list holding the records of all files, in the order the files
        were given.
    """
    reads = []
    for filename in filenames:
        # Use the file as a context manager so the underlying handle is
        # closed as soon as its records have been materialised, instead of
        # staying open until the object happens to be garbage collected.
        with pysam.AlignmentFile(filename, "r", ignore_truncation=True) as alignment:
            reads.extend(alignment)
    return reads


def ref(file):
    """Return the sequence of the first record of a FASTA file as a string.

    Args:
        file: Path of the FASTA file to read.
    """
    with open(file) as handle:
        # Only the first record yielded by the reader is used.
        first_record = next(fastaparser.Reader(handle))
        return first_record.sequence_as_string()


def count_coverage(reads):
    """Count how often each allele is observed at every reference position.

    Each read is walked along the (read position, reference position) pairs
    returned by ``aligned_positions``:

    * a pair with both positions counts the read base at that reference
      position;
    * a pair without a read position (deletion) counts ``"D"`` at that
      reference position;
    * a run of pairs without a reference position (insertion) is counted
      once, at the last reference position seen before the run, as the read
      base preceding the run followed by all inserted bases.

    Args:
        reads: Iterable of objects exposing ``seq`` and whatever
            ``aligned_positions`` needs (``get_aligned_pairs()``).

    Returns:
        ``{reference_position: {allele: count, ..., 'sum': total}}`` where
        ``'sum'`` is the total of all allele counts at that position. A read
        carrying an insertion contributes twice to the position it is
        anchored on (once for the base, once for the insertion allele).
    """
    coverage = {}

    def anchor_base(seq, aligned_pos, start):
        """Return the read base closest before ``aligned_pos[start]``."""
        # Walking back (instead of blindly using the previous pair) keeps an
        # insertion that directly follows a deletion from indexing the read
        # sequence with None.
        for j in range(start - 1, -1, -1):
            read_pos = aligned_pos[j][0]
            if read_pos is not None:
                return seq[read_pos]
        return ""

    def insertion_alternative(seq, aligned_pos, start):
        """Build the allele of the insertion run beginning at ``start``."""
        end = start
        while end < len(aligned_pos) and aligned_pos[end][1] is None:
            end += 1
        inserted = "".join(seq[aligned_pos[j][0]] for j in range(start, end))
        return anchor_base(seq, aligned_pos, start) + inserted

    def count_coverage_read(read):
        aligned_pos = aligned_positions(read)
        seq = read.seq
        last_ref = 0  # last reference position seen; anchors insertions

        for i, (p, q) in enumerate(aligned_pos):
            if p is None:
                # Deletion: reference base without a read base.
                alt = "D"
                index = q
                last_ref = q
            elif q is None:
                # Insertion: only the first pair of a run is counted; the
                # allele already contains every base of the run.
                if aligned_pos[i - 1][1] is None:
                    continue
                alt = insertion_alternative(seq, aligned_pos, i)
                index = last_ref
            else:
                index = q
                last_ref = q
                alt = seq[p]

            counts = coverage.setdefault(index, {})
            counts[alt] = counts.get(alt, 0) + 1

    for r in reads:
        count_coverage_read(r)

    # Totals are added last so that 'sum' never counts itself.
    for counts in coverage.values():
        counts['sum'] = sum(counts.values())

    return coverage


def filter_positions(coverage, relative, absolute, reference):
    """Return the sorted positions that carry a well-supported non-reference allele.

    Args:
        coverage: Mapping ``position -> {allele: count, ..., 'sum': total}``.
        relative: Minimum fraction ``count / total`` an allele must reach.
        absolute: Minimum count an allele must reach.
        reference: Sequence indexed by position, giving the reference allele.

    Returns:
        Sorted list of positions where at least one allele other than the
        reference allele (and other than the ``'sum'`` entry) meets both
        thresholds.
    """
    selected = set()
    for position, counts in coverage.items():
        for allele, count in counts.items():
            # 'sum' is the bookkeeping total, not an allele.
            is_variant = allele != 'sum' and allele != reference[position]
            if is_variant and count >= absolute and count / counts['sum'] >= relative:
                selected.add(position)
    return sorted(selected)


def header():
    """Return the VCF header text: fixed meta lines plus the column line.

    All values (file date, source, reference, INFO definitions) are fixed
    literals; the result ends with a newline.
    """
    meta_lines = (
        '##fileformat=VCFv4.0',
        '##fileDate=20201220',
        '##source=MN908947.3',
        '##reference=1000GenomesPilot-NCBI36',
        '##phasing=partial',
        '##INFO=<ID=S,Number=1,Type=Integer,Description="Minimal number of reads supporting variant">',
        '##INFO=<ID=AS,Number=1,Type=Integer,Description="Average number of reads supporting variant">',
        '##INFO=<ID=Q,Number=1,Type=Integer,Description="Variant probability">',
        '##INFO=<ID=AQ,Number=1,Type=Integer,Description="Average variant probability">',
    )
    column_names = ('CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'INFO')
    # The column line is the tab-separated names prefixed with a single '#'.
    column_line = '#' + '\t'.join(column_names)
    return '\n'.join(meta_lines) + '\n' + column_line + '\n'


def aligned_positions(read):
    """Return the aligned pairs of a read, trimmed to its reference-aligned span.

    Leading and trailing pairs that have no reference position are dropped;
    everything between the first and the last pair that has a reference
    position is kept, both ends included.

    Args:
        read: Object whose ``get_aligned_pairs()`` returns a sequence of
            ``(read_position, reference_position)`` tuples, either of which
            may be ``None``.

    Returns:
        List of pairs, or an empty list when no pair has a reference position.
    """
    pairs = read.get_aligned_pairs()
    ref_aligned = [i for i, pair in enumerate(pairs) if pair[1] is not None]
    if not ref_aligned:
        return []
    # The end bound is inclusive: the last reference-aligned pair belongs to
    # the alignment and must not be trimmed away with the trailing pairs.
    return list(pairs[ref_aligned[0]:ref_aligned[-1] + 1])


def align_fasta(fastas, reference):
    """Align each FASTA file to the reference and write a sorted, indexed BAM.

    For every input ``<fasta>`` the shell pipeline
    ``minimap2 | samtools view | samtools sort`` writes ``<fasta>.bam``, which
    is then indexed with ``samtools index``. Exit statuses of the commands
    are not checked.

    Args:
        fastas: Iterable of FASTA file paths to align.
        reference: Path of the reference FASTA passed to minimap2.
    """
    # Local import keeps this block self-contained.
    import shlex

    # Paths are quoted so that spaces or shell metacharacters in a file name
    # cannot split or alter the command line (POSIX shell quoting).
    quoted_reference = shlex.quote(str(reference))
    for fasta in fastas:
        path = "{}.bam".format(fasta)
        quoted_path = shlex.quote(path)
        os.system("minimap2 -x map-ont -a --secondary=no {} {} | samtools view -S -b - | samtools sort - -o {}".format(quoted_reference, shlex.quote(str(fasta)), quoted_path))
        os.system("samtools index {}".format(quoted_path))


def write_vcf(vcffile, variants, verbose, endline='\n'):
    """Write the VCF header followed by one line per variant.

    Args:
        vcffile: Path of the file to (over)write.
        variants: Iterable of mappings; the values of each mapping are
            written tab-separated, in the mapping's own order.
        verbose: When true, every record line is also printed.
        endline: Text appended to every record line in the file.
    """
    # Records are rendered before the file is opened, so a failure while
    # formatting does not leave a truncated file behind.
    rendered = []
    for variant in variants:
        rendered.append('\t'.join(str(value) for value in variant.values()))

    with open(vcffile, 'w') as out:
        out.write(header())
        for record in rendered:
            out.write(record + endline)
            if verbose:
                print(record)


class VCFReader:
    """Minimal reader for tab-separated VCF files.

    Lines starting with ``#`` form the header; the last of them names the
    columns. Every other non-blank line is parsed into a dict keyed by those
    column names, with ``POS`` converted to ``int`` and ``QUAL`` to ``float``.
    """

    def __init__(self, vcf):
        """Read and parse the file at path ``vcf``."""
        self._filename = vcf
        with open(self._filename) as vcffile:
            self._lines = vcffile.readlines()
        self._header = [line for line in self._lines if line.startswith('#')]
        # Strip only the line terminator: cutting the last character blindly
        # would truncate the final column name when the header line is the
        # last line of a file that has no trailing newline.
        self._fields = self._header[-1].rstrip('\r\n').split('\t')
        # Drop the leading '#' of the column line.
        self._fields[0] = self._fields[0][1:]
        self.parse_vcf()

    def parse_vcf(self):
        """Build the list of variant dicts from the non-header lines."""
        fields = self._fields
        self._variants = []
        for line in self._lines:
            if line.startswith('#'):
                continue
            # Remove the line terminator from the whole record, so that the
            # last column is clean whichever column it happens to be.
            record = line.rstrip('\r\n')
            if not record.strip():
                # Blank lines (e.g. at the end of the file) are not records.
                continue
            variant = dict(zip(fields, record.split('\t')))
            variant['POS'] = int(variant['POS'])
            variant['QUAL'] = float(variant['QUAL'])
            variant['INFO'] = variant['INFO'].replace('\n', '')
            self._variants.append(variant)

    def variants(self):
        """Return the parsed variants as a list of dicts."""
        return self._variants


def parse_vcf(vcffile):
    """Parse the VCF file at ``vcffile`` and return its variants as dicts."""
    reader = VCFReader(vcffile)
    return reader.variants()


def patch_vcf(vcffile):
    """Rewrite a VCF file in place with every ``POS`` decreased by one.

    The file is parsed, each variant's position is decremented, and the
    result is written back to the same path with the standard header.

    Args:
        vcffile: Path of the VCF file to patch.
    """
    variants = parse_vcf(vcffile)
    for v in variants:
        v['POS'] -= 1
    # Parsed records carry no line terminator, so the default newline of
    # write_vcf is required; an empty terminator would glue all records
    # onto a single line.
    write_vcf(vcffile=vcffile, variants=variants, verbose=False)

