#!/usr/bin/env python
# -*- coding:utf-8 -*-

from __future__ import print_function
from __future__ import division
import sys
import collections
import pysam
import re
import os

def gene_range(infile):
    """Build a lookup from every covered position to its gene id.

    Each non-blank line of *infile* is whitespace-separated: the first field
    is the chromosome, the second the gene id, and the remaining fields are
    consecutive (start, end) pairs. Both ends of a pair are included.

    Returns a dict mapping "<chromosome>_<position>" to the gene id. When a
    position is covered by more than one line, the last line read wins.

    Raises IndexError when a line has fewer than two fields or an unpaired
    start, and ValueError when a coordinate is not an integer.
    """
    g_range = {}
    # The context manager closes the handle even if a line fails to parse.
    with open(infile, "r") as handle:
        for line in handle:
            fields = line.split()
            if not fields:
                # Blank (or whitespace-only) lines carry no record.
                continue
            chrom = fields[0]
            g_id = fields[1]
            for se in range(2, len(fields), 2):
                first = int(fields[se])
                last = int(fields[se + 1])
                for i in range(first, last + 1):
                    g_range[chrom + "_" + str(i)] = g_id
    return g_range


def snp_site(infile, g_range):
    """Group the variant sites of *infile* by the gene that contains them.

    Each data line is whitespace-separated: chromosome, position, reference
    allele, mutant allele (further fields are ignored). Sites whose
    "<chromosome>_<position>" key is not in *g_range* are dropped.

    Returns a defaultdict(list) mapping gene id to a list of
    (site key, reference allele, mutant allele) tuples in file order.
    """
    gene_snp = collections.defaultdict(list)
    # The context manager closes the handle even if a line fails to parse.
    with open(infile, "r") as handle:
        for line in handle:
            # Lines starting with 'chr' are skipped (presumably the header).
            # NOTE(review): this also skips data rows whose chromosome name
            # starts with lowercase "chr" -- confirm the input naming.
            if line.startswith('chr'):
                continue
            fields = line.split()
            if not fields:
                # Blank (or whitespace-only) lines carry no record.
                continue
            pos = fields[0] + "_" + fields[1]
            ref = fields[2]
            mut = fields[3]
            if pos in g_range:
                gene_snp[g_range[pos]].append((pos, ref, mut))
    return gene_snp


def identify_indel(mut, pileupcolumn):
    """Classify the reads of one pileup column against an expected indel.

    mut is the expected indel written as '+<length><inserted bases>' or
    '-<length><deleted bases>'. pileupcolumn is a pileup column whose
    ``pileups`` yield reads exposing ``is_refskip``, ``is_del``, ``indel``,
    ``query_position`` and ``alignment``.

    Returns (err, mut, ref) collections of de-duplicated read names:
      * err -- the column falls in a deletion of the read, or the indel that
               follows the column differs from *mut*;
      * mut -- the indel that follows the column equals *mut*;
      * ref -- no indel follows the column.
    Reads flagged ``is_refskip`` are ignored.
    """
    # CIGAR operation codes: 0=M 1=I 2=D 3=N 4=S 5=H 6=P 7== 8=X.
    # query_position indexes query_sequence, which includes soft-clipped
    # bases, so S (and =/X) must be counted along with M and I.
    query_ops = (0, 1, 4, 7, 8)
    # NOTE(review): assumes get_reference_sequence() holds exactly one
    # character per M/D/=/X base (no inserted, clipped or N-skipped bases)
    # -- confirm against the pysam version in use.
    ref_seq_ops = (0, 2, 7, 8)

    err_read_n = []
    mut_read_n = []
    ref_read_n = []
    for read in pileupcolumn.pileups:
        if read.is_refskip:
            continue
        name = read.alignment.query_name
        if read.is_del:
            err_read_n.append(name)
            continue
        if not read.indel:
            ref_read_n.append(name)
            continue

        cigar = read.alignment.cigartuples
        # Number of query bases up to and including the pileup base; the
        # CIGAR operation that brings this to zero ends at the pileup base,
        # so the operation after it is the indel.
        remaining = read.query_position + 1
        ref_idx = 0  # characters of the reference sequence consumed so far
        for num, (op, length) in enumerate(cigar):
            if op in ref_seq_ops:
                ref_idx += length
            if op not in query_ops:
                continue
            remaining -= length
            if remaining > 0:
                continue
            if remaining < 0:
                # The pileup base is not at the end of an operation; the
                # read is left unclassified.
                break

            print(name, cigar, remaining, num)
            real_indel = ''
            # enumerate() gives the position of *this* operation even when
            # an equal (op, length) pair occurs earlier in the CIGAR.
            if num + 1 < len(cigar):
                indel_type, indel_len = cigar[num + 1]
                if indel_type == 1:
                    first = read.query_position + 1
                    indel_base = read.alignment.query_sequence[first:first + indel_len]
                    real_indel = '+' + str(indel_len) + indel_base
                elif indel_type == 2:
                    # Deleted bases are absent from the read, so they are
                    # taken from the reference sequence by reference index,
                    # not by query index.
                    ref_seq = read.alignment.get_reference_sequence()
                    indel_base = ref_seq[ref_idx:ref_idx + indel_len]
                    real_indel = '-' + str(indel_len) + indel_base
            if real_indel != mut:
                err_read_n.append(name)
            else:
                mut_read_n.append(name)
            break

    # dict keys drop names that were seen more than once
    err_s_rdp = {}.fromkeys(err_read_n).keys()
    mut_s_rdp = {}.fromkeys(mut_read_n).keys()
    ref_s_rdp = {}.fromkeys(ref_read_n).keys()
    return err_s_rdp, mut_s_rdp, ref_s_rdp


def identity_snp(ref, mut, pileupcolumn):
    """Sort the reads of one pileup column by the base they show.

    A read whose base at the column equals *ref* counts as a reference read,
    one whose base equals *mut* as a mutant read; reads with a deletion at
    the column or any other base count as error reads. Reads flagged
    ``is_refskip`` are ignored.

    Returns (err, mut, ref) collections of de-duplicated read names.
    """
    names = {"err": [], "mut": [], "ref": []}
    for pileup_read in pileupcolumn.pileups:
        if pileup_read.is_refskip:
            continue
        segment = pileup_read.alignment
        if pileup_read.is_del:
            names["err"].append(segment.query_name)
            continue
        base = segment.query_sequence[pileup_read.query_position]
        # The reference allele is tested first, as before.
        if base == ref:
            bucket = "ref"
        elif base == mut:
            bucket = "mut"
        else:
            bucket = "err"
        names[bucket].append(segment.query_name)

    # Going through dict keys removes names that were seen more than once.
    unique_err = dict.fromkeys(names["err"]).keys()
    unique_mut = dict.fromkeys(names["mut"]).keys()
    unique_ref = dict.fromkeys(names["ref"]).keys()
    return unique_err, unique_mut, unique_ref


# A site is discarded for a sample when at least this fraction of its reads
# show neither the reference nor the expected mutant allele.
_MAX_ERR_FRACTION = 0.1
# Minimum number of reference + mutant reads (both samples combined) for a
# site to count as informative.
_MIN_SITE_DEPTH = 20


def _classify_site(samfile, chrom, mut_site, ref, mut):
    """Return the (err, mut, ref) read names of *samfile* at 1-based *mut_site*.

    All three collections are empty when the pileup yields no column at the
    site.
    """
    err_reads, mut_reads, ref_reads = [], [], []
    # A mutant allele starting with '+' or '-' denotes an indel.
    is_indel = mut.startswith(('+', '-'))
    # Column positions are 0-based while mut_site is 1-based; only the
    # column that matches the site is classified.
    for column in samfile.pileup(chrom, mut_site - 1, mut_site + 1,
                                 truncate=True, max_depth=500000):
        if column.pos + 1 != mut_site:
            continue
        if is_indel:
            err_reads, mut_reads, ref_reads = identify_indel(mut, column)
        else:
            err_reads, mut_reads, ref_reads = identity_snp(ref, mut, column)
    return err_reads, mut_reads, ref_reads


def _collect_gene_reads(samfile1, samfile2, gene, sites, ds1, ds2):
    """Gather mutant/reference read names of one gene over its informative sites.

    Sites rejected for a sample are logged to that sample's discard file.
    Returns (mut1, ref1, mut2, ref2, informative_snp) where the first four
    are lists of read names and the last is the number of sites kept.
    """
    mut_reads1, ref_reads1, mut_reads2, ref_reads2 = [], [], [], []
    informative_snp = 0
    for site, ref, mut in sites:
        # Split on the last underscore so chromosome names that contain
        # underscores are kept whole.
        chrom, position = site.rsplit("_", 1)
        mut_site = int(position)  # 1-based

        er1, mr1, rr1 = _classify_site(samfile1, chrom, mut_site, ref, mut)
        real_depth1 = len(mr1) + len(rr1)
        all_read1 = len(er1) + real_depth1
        if all_read1 != 0 and len(er1) / all_read1 >= _MAX_ERR_FRACTION:
            print(gene, site, ref, mut, len(er1), all_read1, sep="\t", file=ds1)
            continue

        er2, mr2, rr2 = _classify_site(samfile2, chrom, mut_site, ref, mut)
        real_depth2 = len(mr2) + len(rr2)
        all_read2 = len(er2) + real_depth2
        if all_read2 != 0 and len(er2) / all_read2 >= _MAX_ERR_FRACTION:
            print(gene, site, ref, mut, len(er2), all_read2, sep="\t", file=ds2)
            continue

        if real_depth1 + real_depth2 >= _MIN_SITE_DEPTH:
            mut_reads1.extend(mr1)
            ref_reads1.extend(rr1)
            mut_reads2.extend(mr2)
            ref_reads2.extend(rr2)
            informative_snp += 1
    return mut_reads1, ref_reads1, mut_reads2, ref_reads2, informative_snp


def separate_read(bam1, bam2, gene_snp):
    """Count reference-allele and mutant-allele reads per gene in two BAM files.

    bam1 and bam2 are paths to the BAM files; gene_snp maps a gene id to a
    list of ("<chromosome>_<position>", ref allele, mutant allele) tuples.

    For every gene the read names supporting each allele are pooled over its
    informative sites and turned into sets, so a read that covers several
    sites is counted once. Reads that appear as mutant at one site and as
    reference at another are ambiguous and are removed from both counts.

    A site is skipped, and logged to "<basename of the BAM>.discard_sites_v3"
    in the current directory, when at least 10% of its reads in that sample
    show neither allele. A site is informative when the two samples together
    have at least 20 reference + mutant reads. Genes without an informative
    site are left out of the results.

    Returns (ref_g_read1, mut_g_read1, ref_g_read2, mut_g_read2, snp_number):
    four defaultdict(int) of gene id to read count, and a dict of gene id to
    (number of sites, number of informative sites).
    """
    ref_g_read1 = collections.defaultdict(int)
    mut_g_read1 = collections.defaultdict(int)
    ref_g_read2 = collections.defaultdict(int)
    mut_g_read2 = collections.defaultdict(int)
    snp_number = {}

    # Name the discard files after the arguments rather than sys.argv, so
    # the function also works when it is not called from the command line.
    discard_file1 = os.path.basename(bam1) + ".discard_sites_v3"
    discard_file2 = os.path.basename(bam2) + ".discard_sites_v3"

    samfile1 = pysam.Samfile(bam1, "rb")
    try:
        samfile2 = pysam.Samfile(bam2, "rb")
        try:
            with open(discard_file1, "w") as ds1, open(discard_file2, "w") as ds2:
                for g in sorted(gene_snp):
                    mut1, ref1, mut2, ref2, informative_snp = _collect_gene_reads(
                        samfile1, samfile2, g, gene_snp[g], ds1, ds2)
                    if informative_snp == 0:
                        continue
                    mut_set1, ref_set1 = set(mut1), set(ref1)
                    mut_set2, ref_set2 = set(mut2), set(ref2)
                    # Set differences drop the ambiguous reads.
                    mut_g_read1[g] = len(mut_set1 - ref_set1)
                    ref_g_read1[g] = len(ref_set1 - mut_set1)
                    mut_g_read2[g] = len(mut_set2 - ref_set2)
                    ref_g_read2[g] = len(ref_set2 - mut_set2)
                    snp_number[g] = (len(gene_snp[g]), informative_snp)
        finally:
            samfile2.close()
    finally:
        samfile1.close()
    return ref_g_read1, mut_g_read1, ref_g_read2, mut_g_read2, snp_number


if __name__ == '__main__':
    # Arguments: gene range file, variant site file, two BAM files, output.
    if len(sys.argv) < 6:
        sys.exit("usage: %s <gene_ranges> <snp_sites> <bam1> <bam2> <output>"
                 % os.path.basename(sys.argv[0]))
    gene_map = gene_range(sys.argv[1])
    g_snp = snp_site(sys.argv[2], gene_map)
    refr1, mutr1, refr2, mutr2, snp = separate_read(sys.argv[3], sys.argv[4], g_snp)
    # The context manager flushes and closes the table even on error.
    with open(sys.argv[5], "w") as of:
        print('gene', 'ref_read_num1', 'mutants1', 'ref_read_num2', 'mutants2', 'total_snp', 'informative_snp', sep='\t', file=of)
        for gene in sorted(refr1):
            print(gene, refr1[gene], mutr1[gene], refr2[gene], mutr2[gene], snp[gene][0], snp[gene][1], sep='\t', file=of)
