#! /usr/bin/env python
#coding=utf-8

from __future__ import print_function
from __future__ import division
import os,sys
# Refuse to run on anything other than Python 2.7.
if sys.version_info[:2] != (2, 7):
    # The message must go to stderr via file=...; passing sys.stderr as a
    # positional argument would print its repr to stdout instead.
    print("\nYou are using python" + str(sys.version_info[0]) + '.' + str(sys.version_info[1]) + " RSeQC needs python2.7!\n", file=sys.stderr)
    # Exit with a non-zero status so callers can detect the failure.
    sys.exit(1)

from optparse import OptionParser
import collections
import math
import pysam

__version__="1.0"


def _parse_borders(field):
    """Turn a comma-separated coordinate column into a list of ints."""
    # Trailing commas/tabs are stripped first so no empty item reaches int().
    return [int(value) for value in field.rstrip(',\t').split(',')]


def genebody_percentile(ref_file, sec):
    """Choose the sampling positions for every gene listed in ``ref_file``.

    Each non-blank line is split on whitespace and must provide at least
    eight columns: chromosome, gene name, transcript start, transcript end,
    strand, and three comma-separated coordinate lists (upstream, gene body,
    downstream) holding alternating start/end values.

    Args:
        ref_file: Path of the reference gene file.
        sec: Sequence of three integers giving how many sites to sample in
            the upstream, gene-body and downstream regions, in that order.

    Returns:
        A dict mapping ``chrom_txStart_txEnd_geneName_strand`` to a tuple
        ``(chrom, strand, positions)``.  A gene is only included when all
        three regions produced sites, so every ``positions`` list has the
        same length.

    Raises:
        IndexError: If a line has fewer than eight columns or ``sec`` has
            fewer than three items.
        ValueError: If a coordinate column is not made of integers.
    """
    g_percentiles = {}
    # The context manager guarantees the reference file is closed.
    with open(ref_file, 'r') as ref_handle:
        for line in ref_handle:
            fields = line.rstrip().split()
            if not fields:
                # Tolerate blank lines instead of failing on fields[0].
                continue
            chrom = fields[0]
            geneName = fields[1]
            tx_start = int(fields[2])
            tx_end = int(fields[3])
            strand = fields[4]
            geneID = '_'.join([str(j) for j in (chrom, tx_start, tx_end, geneName, strand)])
            upstream = _parse_borders(fields[5])
            gene_body = _parse_borders(fields[6])
            downstream = _parse_borders(fields[7])

            up_sites = sites_for_cov(upstream, strand, sec[0])
            g_sites = sites_for_cov(gene_body, strand, sec[1])
            down_sites = sites_for_cov(downstream, strand, sec[2])

            # Check the sampled downstream sites, not the raw border list:
            # the raw list is never empty, so testing it would let genes
            # without downstream sites through with too few positions.
            if up_sites and g_sites and down_sites:
                if strand == '+':
                    positions = up_sites + g_sites + down_sites
                else:
                    # NOTE(review): presumably this keeps positions in
                    # ascending coordinate order for '-' genes -- confirm
                    # against the reference file layout.
                    positions = down_sites + g_sites + up_sites
                g_percentiles[geneID] = (chrom, strand, positions)
    return g_percentiles

def sites_for_cov(borders, strand, site_num):
    """Sample ``site_num`` positions from the intervals described by ``borders``.

    ``borders`` is a flat sequence of alternating start/end coordinates; both
    ends of each interval are inclusive.  An empty list is returned when the
    first border is 0 or when the intervals hold fewer than ``site_num``
    bases; otherwise the result of ``percentile_list`` is returned.
    """
    # A leading 0 means there is nothing to sample in this region.
    if borders[0] == 0:
        return []

    # Even slots are interval starts, odd slots the matching ends.
    covered = []
    for left, right in zip(borders[0::2], borders[1::2]):
        covered += range(left, right + 1)

    # Not enough bases to place the requested number of sites.
    if len(covered) < site_num:
        return []
    return percentile_list(covered, strand, site_num)

def percentile_list(gene_base_list, strand, point_number):
    """Pick ``point_number`` evenly spaced entries from ``gene_base_list``.

    For the '+' strand the entries are counted from the front of the list.
    For any other strand they are counted from the back, and for '-' the
    picked entries are additionally reversed before being returned.
    """
    total = len(gene_base_list)
    # 1-based rank of each sample; the division is evaluated per step, so a
    # point_number of 0 simply yields no ranks.
    ranks = [int(round(total / point_number * step))
             for step in range(1, point_number + 1)]

    if strand == '+':
        return [gene_base_list[rank - 1] for rank in ranks]

    # Negative indexing walks the list from its end.
    picked = [gene_base_list[-rank] for rank in ranks]
    if strand == '-':
        picked.reverse()
    return picked

def genebody_coverage(bam, position_list):
    """Count read coverage at the sampled positions of every gene.

    For each gene the BAM file is piled up over the span of its positions and
    the number of distinct read names covering each sampled position is
    recorded.  Deletions and alignments flagged as QC-fail, secondary,
    unmapped or duplicate are not counted.

    Side effect: writes ``<basename of bam>.each_gene_cov`` relative to the
    current working directory, with two lines per gene (the gene id, then the
    tab-separated coverage values).

    Args:
        bam: Path of the BAM file to read.
        position_list: Dict mapping a gene id to ``(chrom, strand,
            positions)``, where ``positions`` are 1-based coordinates.

    Returns:
        A tuple ``(aggreagated_cvg, gene_finished)``: a defaultdict mapping
        the 0-based position index to the coverage summed over all genes, and
        the number of genes processed.
    """
    aggreagated_cvg = collections.defaultdict(int)
    gene_finished = 0
    each_gene_cov_name = os.path.basename(bam) + '.each_gene_cov'
    samfile = pysam.Samfile(bam, "rb")
    try:
        # Both handles are released even if processing a gene fails.
        with open(each_gene_cov_name, 'w') as tf:
            for gene, (chrom, strand, positions) in position_list.items():
                coverage = dict.fromkeys(positions, 0)
                # pileup() takes a 0-based start, positions are 1-based.
                # NOTE(review): assumes positions are in ascending order --
                # confirm against how the caller builds them.
                chrom_start = positions[0] - 1
                chrom_end = positions[-1]

                for pileupcolumn in samfile.pileup(chrom, chrom_start, chrom_end, truncate=True):
                    ref_pos = pileupcolumn.pos + 1
                    # Dict lookup instead of scanning the positions list.
                    if ref_pos not in coverage:
                        continue
                    if pileupcolumn.n == 0:
                        coverage[ref_pos] = 0
                        continue
                    # Names already counted at this column; a set keeps the
                    # membership test constant-time.
                    counted_names = set()
                    for pileupread in pileupcolumn.pileups:
                        alignment = pileupread.alignment
                        if alignment.query_name in counted_names:
                            continue
                        if pileupread.is_del:
                            continue
                        if (alignment.is_qcfail or alignment.is_secondary
                                or alignment.is_unmapped or alignment.is_duplicate):
                            continue
                        counted_names.add(alignment.query_name)
                    # Every counted read contributed exactly one new name.
                    coverage[ref_pos] = len(counted_names)

                tmp = [coverage[k] for k in sorted(coverage)]
                # '-' strand genes are reversed before aggregation.
                if strand == '-':
                    tmp = tmp[::-1]
                for i, depth in enumerate(tmp):
                    aggreagated_cvg[i] += depth
                gene_finished += 1
                print(gene, "\t".join(str(x) for x in tmp), sep="\n", file=tf)
    finally:
        samfile.close()
    return aggreagated_cvg, gene_finished


if __name__ == '__main__':
    # Command-line entry point: parse options, sample positions for every
    # reference gene, measure coverage in the BAM file and write the
    # aggregated table to the output file.
    parser = OptionParser()
    parser.add_option("-i","--input",action="store",type="string",dest="input_file",help='Input file in BAM format. [required]')
    parser.add_option("-r","--reference",action="store",type="string",dest="ref_gene",help='Reference genes. [required]')
    parser.add_option("-n","--point_number",action="store",type="string",dest="point_number",help="The number of point in the gene. [required]")
    parser.add_option("-o","--outfile",action="store",type="string",dest="output_file",help="output file. [required]")
    (options,args)=parser.parse_args()

    if not (options.output_file and options.input_file and options.ref_gene and options.point_number):
        parser.print_help()
        sys.exit(0)

    ref = options.ref_gene

    # -n carries one site count per region (upstream, gene body, downstream).
    # Validate it here so a bad value gives a usage error instead of a
    # traceback from deep inside the processing.
    try:
        section_sites = [int(item) for item in options.point_number.split(",")]
    except ValueError:
        parser.error("-n/--point_number must be comma-separated integers "
                     "(upstream,gene_body,downstream)")
    if len(section_sites) < 3:
        parser.error("-n/--point_number needs three comma-separated integers "
                     "(upstream,gene_body,downstream)")

    # Open the output first so an unwritable path fails before the long
    # computation; the context manager closes it even on error.
    with open(options.output_file, 'w') as f:
        gene_percentiles = genebody_percentile(ref, section_sites)
        print ("Get BAM file and process the reads coverage ...")
        (cvg, count) = genebody_coverage(options.input_file, gene_percentiles)
        print ("positions" + "\t" + "aggregated_reads" + "\t" + "intensity", file=f)
        for key in sorted(cvg):
            # Mean coverage per processed gene at this position index.
            site_intensity = cvg[key] / count
            print (str(key+1) + "\t" + str(cvg[key]) + "\t" + str(site_intensity), file=f)




