#!/usr/bin/env python
# -*- coding:utf-8 -*-

from __future__ import print_function
import sys
import collections


def assign_peak_to_gene_sections(p_summits, gene_sections, tmp_path=None):
    """Count how many peak summits fall into each gene section.

    Each key of *gene_sections* is split on '_': the first part is the gene id,
    and the second part (joined with the third by '_' when a third part exists)
    is the section name.  A summit of that gene is counted for the section when
    its position appears in the section's position list.

    Because the total number of summits once did not match the total number of
    assignments, every assigned summit is also written to a debug file as a
    tab-separated "gene, section, summit" line so the mismatch can be inspected.

    :param p_summits: mapping gene id -> list of summit positions.
    :param gene_sections: mapping "<gene>_<section>" -> list of positions.
    :param tmp_path: path of the debug file; defaults to ``sys.argv[4]``
        (the original command-line behaviour).
    :return: ``collections.defaultdict(int)`` mapping section -> summit count.
    """
    if tmp_path is None:
        tmp_path = sys.argv[4]
    count = collections.defaultdict(int)
    # ``with`` guarantees the debug file is flushed and closed, even on error.
    with open(tmp_path, "w") as tmpf:
        for k in sorted(gene_sections):
            fields = k.split('_')
            if len(fields) > 2:
                section = fields[1] + '_' + fields[2]
            else:
                section = fields[1]
            gene = fields[0]
            # ``in dict`` is O(1); ``in dict.keys()`` is a list scan on py2.
            if gene not in p_summits:
                continue
            # A set gives O(1) membership instead of scanning the position list
            # once per summit; the membership result is identical.
            positions = set(gene_sections[k])
            for s in p_summits[gene]:
                if s in positions:
                    count[section] += 1
                    print(gene, section, s, sep="\t", file=tmpf)
    return count


def summit(infile):
    """Read peak summits from *infile* and group them by gene.

    Each non-blank line is whitespace-separated.  The first column is split on
    '_' and its second part is used as the gene id.  The second column is the
    summit position, parsed as an integer.

    :param infile: path of the summit file.
    :return: ``collections.defaultdict(list)`` mapping gene id -> summit
        positions, in file order.
    """
    peak_summits = collections.defaultdict(list)
    with open(infile, "r") as handle:
        for line in handle:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # crashing with IndexError.
                continue
            gene_id = fields[0].split('_')[1]
            peak_summits[gene_id].append(int(fields[1]))
    return peak_summits


def section_site(infile):
    """Collect the positions covered by each gene's utr5, cds and utr3 sections.

    Each non-blank line of *infile* is whitespace-separated.  Field 1 (0-based)
    is the gene id.  Fields 5, 6 and 7 hold comma-separated border lists (a
    trailing comma is allowed) for utr5, cds and utr3 respectively.  Borders
    are read as consecutive (start, end) pairs, and every position from start
    to end inclusive is recorded.

    Some genes define only a single base as utr5/utr3, so any border list with
    fewer than two values is skipped.  An odd-length list of three or more
    values still raises IndexError, as before.

    :param infile: path of the gene-structure file.
    :return: ``collections.defaultdict(list)`` mapping "<gene>_utr5",
        "<gene>_cds" and "<gene>_utr3" to lists of positions.
    """
    sites = collections.defaultdict(list)
    with open(infile, "r") as handle:
        for line in handle:
            fields = line.split()
            if not fields:
                # Skip blank lines, consistent with summit().
                continue
            gene = fields[1]
            for suffix, col in zip(('_utr5', '_cds', '_utr3'), (5, 6, 7)):
                borders = [int(i) for i in fields[col].rstrip(',').split(',')]
                if len(borders) < 2:
                    continue
                key = gene + suffix
                for j in range(0, len(borders), 2):
                    sites[key].extend(range(borders[j], borders[j + 1] + 1))
    return sites


if __name__ == '__main__':
    # Command line: argv[1] summit file, argv[2] gene-structure file,
    # argv[3] output table; argv[4] (the debug file) is read inside
    # assign_peak_to_gene_sections.
    summits = summit(sys.argv[1])
    section_sites = section_site(sys.argv[2])
    peak_distribution = assign_peak_to_gene_sections(summits, section_sites)
    # The grand total is the same for every section, so compute it once.
    # Every present key has a count >= 1, so total > 0 whenever the loop runs.
    total = sum(peak_distribution.values())
    with open(sys.argv[3], "w") as outfile:
        print('sections', 'num', 'percentage', sep="\t", file=outfile)
        for key, num in peak_distribution.items():
            pert = num / float(total)
            print(key, num, pert, sep='\t', file=outfile)
