#!/usr/bin/env python

from __future__ import print_function
import sys
import pysam
import collections
import re
import os


def get_positions(file):
    """Parse a tab-delimited interval file into per-gene interval lists.

    Each non-blank line is expected to hold ``chrom``, ``gene`` and then
    alternating ``start``/``end`` columns.  Consecutive columns are paired
    up as (start, end); an unpaired trailing column is ignored.

    Args:
        file: Path of the tab-delimited file to read.

    Returns:
        A ``collections.defaultdict(list)`` mapping each gene to a list of
        ``(chrom, start, end)`` tuples.  Coordinates are kept as the strings
        read from the file; intervals of a gene that appears on several
        lines are accumulated in file order.

    Raises:
        IndexError: If a non-blank line has fewer than two columns.
    """
    position_list = collections.defaultdict(list)
    # Text mode so the str-based split works on Python 3 as well; the
    # context manager guarantees the handle is closed.
    with open(file, "r") as handle:
        for line in handle:
            if not line.strip():
                # Blank lines (e.g. a trailing newline) carry no record.
                continue
            field = line.rstrip().split("\t")
            chrom = field[0]
            gene = field[1]
            sites = field[2:]
            # Even offsets are starts, odd offsets are the matching ends.
            for start, end in zip(sites[0::2], sites[1::2]):
                position_list[gene].append((chrom, start, end))
    return position_list


# CIGAR operation codes as numbered by pysam / the SAM specification.
_CIGAR_ALIGNED = (0, 7, 8)  # M, =, X: read bases aligned to the reference
_CIGAR_REF_GAP = (2, 3)     # D, N: reference consumed without aligned bases
_CIGAR_SKIP = 3             # N: skipped region (intron)


def _aligned_block_overlaps(read, start0, end0):
    """Return True if an aligned block of *read* overlaps [start0, end0).

    Walks the CIGAR along the reference (0-based, half-open coordinates) and
    tests every M/=/X block against the interval, so a spliced read whose
    intron merely spans the interval is not reported as overlapping.
    """
    ref_pos = read.reference_start
    for op, length in read.cigartuples:
        if op in _CIGAR_ALIGNED:
            if ref_pos < end0 and ref_pos + length > start0:
                return True
            ref_pos += length
        elif op in _CIGAR_REF_GAP:
            ref_pos += length
        # Insertions, clips and padding do not advance the reference.
    return False


def get_feature_reads(bam, position_list):
    """Count reads overlapping each feature's intervals in a BAM file.

    Args:
        bam: Path to a BAM file; ``fetch`` needs it to be indexed.
        position_list: Mapping of feature name to a list of
            ``(chrom, start, end)`` tuples.  ``start``/``end`` are treated
            as 1-based inclusive coordinates (the length is computed as
            ``end - start + 1``) and may be strings.

    Returns:
        Dict mapping each feature name to ``[read_count, total_length]``
        where ``total_length`` is the summed length of its intervals.

    Notes:
        * A read name is counted at most once per interval.
        * Reads containing a skipped region (CIGAR ``N``) are only counted
          when one of their aligned blocks overlaps the interval.
        * Records without a CIGAR are skipped.
        * A ``ValueError`` raised while fetching an interval (pysam raises
          it e.g. for a contig missing from the BAM) skips that interval;
          its length still contributes to ``total_length``.
    """
    samfile = pysam.AlignmentFile(bam, "rb")
    cov = {}
    try:
        for f in position_list:
            read_num = 0
            f_length = 0
            for chrom, s, e in position_list[f]:
                start, end = int(s), int(e)
                f_length += end - start + 1
                # pysam.fetch takes 0-based half-open coordinates, so the
                # 1-based inclusive interval [start, end] is [start-1, end).
                start0 = max(start - 1, 0)
                # NOTE(review): names are tracked per interval, so a read
                # overlapping two intervals of the same feature is counted
                # once for each -- confirm this is intended.
                seen = set()
                try:
                    for read in samfile.fetch(chrom, start0, end):
                        name = read.query_name
                        if name in seen:
                            continue
                        cigar = read.cigartuples
                        if not cigar:
                            # No alignment information to evaluate.
                            continue
                        spliced = any(op == _CIGAR_SKIP for op, _ in cigar)
                        if spliced and not _aligned_block_overlaps(
                                read, start0, end):
                            continue
                        read_num += 1
                        seen.add(name)
                except ValueError:
                    # Best effort: leave this interval's reads uncounted.
                    continue
            cov[f] = [read_num, f_length]
    finally:
        samfile.close()
    return cov


if __name__ == "__main__":
    # Command line: <positions file> <bam file> <total reads> <output file>
    #   total reads is the divisor used for the normalized-read and FPKM
    #   columns (presumably the library's mapped read count -- confirm).
    if len(sys.argv) < 5:
        sys.exit("usage: %s <positions> <bam> <total_reads> <output>"
                 % sys.argv[0])
    positions_path = sys.argv[1]
    bam_path = sys.argv[2]
    # Parsed once, up front, so a bad value fails before the BAM scan.
    total_reads = float(sys.argv[3])
    out_path = sys.argv[4]

    positions = get_positions(positions_path)
    reads = get_feature_reads(bam_path, positions)

    # Column prefix: the first two "_"-separated tokens of the BAM file name
    # (or the whole first token when the name contains no underscore).
    base_name = "_".join(os.path.basename(bam_path).split("_")[:2])

    genes = sorted(reads)

    # Reads per base for every gene; their sum is the TPM denominator.
    rates = {}
    for gene in genes:
        count, length = reads[gene]
        rates[gene] = float(count) / length
    total_rate = sum(rates[gene] for gene in genes)

    with open(out_path, "w") as out:
        # The header keeps its historical trailing tab so the output stays
        # byte-compatible with earlier runs.
        columns = ("_raw_reads", "_normalized_reads", "_fpkm", "_tpm")
        header = "gene\t" + "".join(base_name + c + "\t" for c in columns)
        print(header, file=out)
        for gene in genes:
            count, length = reads[gene]
            fpkm = float(count) * 1000000000 / (length * total_reads)
            # Raw count scaled to reads per 10,000,000 total reads.
            n_read = count / total_reads * 10000000
            # With no reads anywhere the denominator is zero; report 0.0
            # instead of dividing by zero.
            tpm = rates[gene] * 1000000 / total_rate if total_rate else 0.0
            print(gene, str(count), str(n_read), str(fpkm), str(tpm),
                  sep='\t', file=out)
