#!/usr/bin/env python

from __future__ import print_function
import pysam
import sys
import collections
import re


def _count_distinct_reads(pileupcolumn):
	"""Return the number of distinct read names covering one pileup column.

	Reads that carry a deletion at this position, or whose alignment is
	flagged unmapped, are ignored. A query name is counted once per column,
	however many pileup entries share it.
	"""
	# A set keeps the duplicate check O(1); columns can hold very many reads.
	seen_names = set()
	for pileupread in pileupcolumn.pileups:
		if pileupread.is_del or pileupread.alignment.is_unmapped:
			continue
		seen_names.add(pileupread.alignment.query_name)
	return len(seen_names)


def get_base_reads(bam, mreads, position_list, out_path=None):
	"""Write normalised per-base read coverage for every key of position_list.

	For each key, the coverage of all its intervals is concatenated in the
	order the intervals are stored and written as one tab-separated line:
	the key followed by one value per base. Positions without counted reads
	are written as the integer 0; all other positions are written as
	``distinct_reads * 10000000 / float(mreads)``.

	Args:
		bam: Path of the alignment file, opened with
			``pysam.AlignmentFile(bam, "rb")``.
		mreads: Normalisation denominator, anything ``float()`` accepts
			(presumably the total number of mapped reads -- confirm with
			the caller).
		position_list: Mapping of key -> iterable of ``(chrom, start, end)``.
			``start`` is decremented by one and ``end`` is used as-is, i.e.
			the input is treated as 1-based inclusive coordinates.
		out_path: File to write. Defaults to ``sys.argv[4]``, which is where
			this function has always taken the output path from.

	Raises:
		ValueError: If ``mreads`` or an interval bound is not numeric.
		ZeroDivisionError: If ``mreads`` is zero and a covered column is seen.
	"""
	if out_path is None:
		out_path = sys.argv[4]
	# Convert once, before any file is opened, so a malformed value fails
	# early instead of after an output file has already been created.
	total_reads = float(mreads)

	samfile = pysam.AlignmentFile(bam, "rb")
	try:
		with open(out_path, "w") as of:
			for gene in position_list:
				cov = []
				for chrom, start, end in position_list[gene]:
					ref_start = int(start) - 1
					ref_end = int(end)
					# Pre-filled with zeros: columns the pileup never
					# reports keep a coverage of 0.
					seg = [0] * (ref_end - ref_start)
					# truncate=True is relied upon to limit the reported
					# columns to the requested interval, so the index
					# computed below stays inside seg.
					for pileupcolumn in samfile.pileup(chrom, ref_start, ref_end, truncate=True, max_depth=500000):
						if pileupcolumn.n == 0:
							continue
						depth = _count_distinct_reads(pileupcolumn)
						seg[pileupcolumn.pos - ref_start] = depth * 10000000 / total_reads
					cov.extend(seg)
				print(gene, "\t".join(str(i) for i in cov), sep="\t", file=of)
	finally:
		# Release the BAM handle even if parsing or writing fails.
		samfile.close()


def get_peak_position(file):
	"""Parse a tab-separated interval table into a per-gene interval list.

	Each non-blank line is expected to hold: column 1 a chromosome name,
	column 2 a gene identifier, then alternating start/end values. Every
	(start, end) pair becomes one ``(chrom, start, end)`` tuple appended to
	the list stored under the gene identifier; lines sharing an identifier
	accumulate into the same list. Coordinates are kept as strings.

	Args:
		file: Path of the table to read.

	Returns:
		A ``collections.defaultdict(list)`` mapping gene identifier to a list
		of ``(chrom, start, end)`` string tuples in file order.

	Raises:
		IndexError: If a non-blank line has fewer than two columns.
	"""
	position_list = collections.defaultdict(list)
	# Text mode: the lines are split on a str tab, which fails on the bytes
	# objects a binary-mode handle yields under Python 3.
	with open(file, "r") as handle:
		for line in handle:
			line = line.rstrip()
			if not line:
				# Tolerate blank lines (e.g. a trailing newline at EOF).
				continue
			field = line.split("\t")
			chrom, gene = field[0], field[1]
			sites = field[2:]
			# Even offsets are starts, odd offsets are ends; zip() drops a
			# trailing start that has no matching end.
			for start, end in zip(sites[0::2], sites[1::2]):
				position_list[gene].append((chrom, start, end))
	return position_list


if __name__ == "__main__":
	# Four positional arguments are required. The fourth (the output path) is
	# not passed along here: get_base_reads picks it up from sys.argv itself.
	# Checking the count up front avoids parsing the whole peak table only to
	# die with a bare IndexError afterwards.
	if len(sys.argv) < 5:
		sys.exit("usage: %s <peak_positions> <bam> <mapped_reads> <output>" % sys.argv[0])
	position_list = get_peak_position(sys.argv[1])
	get_base_reads(sys.argv[2], sys.argv[3], position_list)
