#!/usr/bin/env python2.7
"""
Generate a VCF from a GAM and XG by splitting into GAM/VG chunks.
Chunks are then called in series, and the VCFs stitched together.
Any step whose expected output exists is skipped unless --overwrite is
specified.
"""

import argparse, sys, os, os.path, random, subprocess, shutil, itertools, glob
import json

def parse_args(args):
    """ build the command-line parser and parse an argv-style list.

    args is expected to look like sys.argv: its first element (the
    program name) is dropped before parsing.  returns the argparse
    namespace holding every option.
    """
    cli = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    # required positionals, in command-line order: (name, type, help)
    positionals = [
        ("xg_path", str, "input xg file"),
        ("gam_path", str, "input alignment"),
        ("path_name", str, "name of reference path in graph (ex chr21)"),
        ("path_size", int, "size of the reference path in graph"),
        ("sample_name", str, "sample name (ex NA12878)"),
        ("out_dir", str, "directory where all output will be written"),
    ]
    for name, kind, text in positionals:
        cli.add_argument(name, type=kind, help=text)

    # optional settings that take a value: (flag, type, default, help)
    valued = [
        ("--chunk", int, 10000000, "chunk size"),
        ("--overlap", int, 2000, "amount of overlap between chunks"),
        ("--filter_opts", str,
         "-r 0.9 -fu -s 2 -o 0 -q 15 --defray-ends 999",
         "options to pass to vg filter. wrap in \"\""),
        ("--pileup_opts", str, "-q 10",
         "options to pass to vg pileup. wrap in \"\""),
        ("--call_opts", str, "",
         "options to pass to vg call. wrap in \"\""),
        ("--threads", int, 20,
         "number of threads to use in vg call and vg pileup"),
    ]
    for flag, kind, fallback, text in valued:
        cli.add_argument(flag, type=kind, default=fallback, help=text)

    # the only boolean switch
    cli.add_argument("--overwrite", action="store_true",
                     help="always overwrite existing files")

    # skip the program name
    return cli.parse_args(args[1:])

def merge_call_opts(contig, offset, length, call_opts, sample_name):
    """ combine user-supplied vg call options with generated ones.

    call_opts is split on whitespace and scanned for known flags:
      -o/--offset : the user's value is increased by offset
      -S/--sample : the user's value is replaced with sample_name
      -c/--contig, -r/--ref, -l/--length : left as the user gave them
    any of these five options that the user did not give is appended
    (offset, contig, ref = contig, sample, length).  returns the merged
    option string.
    """
    # every spelling of a flag we care about, mapped to one key
    flag_keys = {
        "-o": "offset", "--offset": "offset",
        "-c": "contig", "--contig": "contig",
        "-r": "ref", "--ref": "ref",
        "-S": "sample", "--sample": "sample",
        "-l": "length", "--length": "length",
    }
    tokens = call_opts.split()
    given = {}
    # index loop on purpose: values are rewritten in place as we go
    for pos in range(len(tokens)):
        key = flag_keys.get(tokens[pos])
        if key is None:
            continue
        value = tokens[pos + 1]
        if key == "offset":
            value = int(value)
            tokens[pos + 1] = str(value + offset)
        elif key == "sample":
            tokens[pos + 1] = sample_name
        given[key] = value

    merged = " ".join(tokens)
    # generated defaults, appended only for options the user left out
    generated = [
        ("offset", "-o", offset),
        ("contig", "-c", contig),
        ("ref", "-r", contig),
        ("sample", "-S", sample_name),
        ("length", "-l", length),
    ]
    for key, short_flag, default in generated:
        if key not in given:
            merged += " {} {}".format(short_flag, default)
    return merged
    
def run(cmd, proc_stdout = sys.stdout, proc_stderr = sys.stderr,
        check = True):
    """ run command in shell and throw exception if it doesn't work

    cmd is handed verbatim to the shell (shell=True), so pipes and
    redirections work, but callers must make sure any paths they
    interpolate are safe for the shell.

    proc_stdout / proc_stderr are passed to subprocess.Popen; use
    subprocess.PIPE to capture a stream.

    returns the (output, errors) pair from communicate(); an entry is
    None unless the matching stream was subprocess.PIPE.

    raises RuntimeError if check is true and the exit status is non-zero.
    """
    # echo the command; the parenthesised form works as a statement too
    print(cmd)
    # flush so the echo lands before the child's own output when our
    # stdout is a file or pipe (the child writes to the descriptor directly)
    sys.stdout.flush()
    proc = subprocess.Popen(cmd, shell=True, bufsize=-1,
                            stdout=proc_stdout, stderr=proc_stderr)
    # communicate() waits for the process to finish, so returncode is set
    output, errors = proc.communicate()
    sts = proc.returncode
    if check and sts != 0:
        raise RuntimeError("Command: %s exited with non-zero status %i" % (cmd, sts))
    return output, errors

def make_chunks(path_name, path_size, chunk_size, overlap):
    """ compute chunks as BED (0-based) 3-tuples: ie
    (chr1, 0, 10) is the range from 0-9 inclusive of chr1

    every chunk after the first starts overlap bases before the end of
    the previous one, and the last chunk is truncated at path_size.
    returns a list of (path_name, start, end) tuples, empty when
    path_size is not positive.

    raises ValueError if chunk_size is not larger than overlap, since
    the chunks would then never advance along the path.
    """
    # an explicit check rather than an assert: asserts are stripped by
    # python -O, which would turn a bad setting into an endless loop
    if chunk_size <= overlap:
        raise ValueError(
            "chunk size ({}) must be greater than overlap ({})".format(
                chunk_size, overlap))
    covered = 0
    chunks = []
    while covered < path_size:
        # back up by the overlap, except for the very first chunk
        start = max(0, covered - overlap)
        end = min(path_size, start + chunk_size)
        chunks.append((path_name, start, end))
        covered = end
    return chunks

def chunk_base_name(path_name, out_dir, chunk_i = None, tag= ""):
    """ single place where chunk-related output paths are built:
    <out_dir>/<path_name>-chunk[-<chunk_i>]<tag> """
    stem = "{}-chunk".format(path_name)
    # the index is optional; note that 0 is a valid index
    if chunk_i is not None:
        stem = "{}-{}".format(stem, chunk_i)
    return "{}{}".format(os.path.join(out_dir, stem), tag)

def chunk_gam(gam_path, xg_path, path_name, out_dir, chunks, filter_opts, threads, overwrite):
    """ split the gam into per-chunk gams with a single vg filter run.

    the chunk regions are always (re)written to <path_name>_chunks.bed
    in out_dir.  vg filter itself is skipped when overwrite is off and
    at least one chunk gam is already present.
    """
    # write the regions as BED
    bed_path = os.path.join(out_dir, path_name + "_chunks.bed")
    with open(bed_path, "w") as bed_file:
        bed_file.writelines(
            "{}\t{}\t{}\n".format(region[0], region[1], region[2])
            for region in chunks)

    # any existing chunk gam counts as "already done" (a chunk with
    # nothing aligned to it may legitimately have no gam at all)
    if not overwrite:
        for index in range(len(chunks)):
            if os.path.isfile(chunk_base_name(path_name, out_dir, index, ".gam")):
                return

    # run vg filter on the gam
    output_prefix = os.path.join(out_dir, path_name + "-chunk")
    run("vg filter {} -x {} -R {} -B {} {} -t {}".format(
        gam_path, xg_path, bed_path, output_prefix, filter_opts, threads))

def xg_path_node_id(xg_path, path_name, offset):
    """ look up, via vg find, the id of the node at a path position.
    the shell pipeline's output is parsed as an integer and returned """
    # NOTE: vg find -p range offsets are 0-based inclusive, so a
    # one-base range is written offset-offset.
    pipeline = "vg find -x {} -p {}:{}-{} | vg mod -o - | vg view -j - | jq .node[0].id".format(
        xg_path, path_name, offset, offset)
    captured, _ = run(pipeline, proc_stdout=subprocess.PIPE)
    return int(captured)

def xg_path_predecessors(xg_path, path_name, node_id, context = 1):
    """ list the ids of the nodes that the named path visits before
    node_id, looking only at the context-sized neighbourhood of node_id
    that vg find extracts.  returned in path order. """
    neighbourhood_json, _ = run("vg find -x {} -n {} -c {} | vg view -j -".format(
        xg_path, node_id, context), proc_stdout=subprocess.PIPE)

    # pull the requested path out of the json graph
    graph = json.loads(neighbourhood_json)
    named_paths = [p for p in graph["path"] if p["name"] == path_name]
    steps = named_paths[0]["mapping"]
    assert len(steps) > 0

    # node ids in the order the path visits them
    visited = [step["position"]["node_id"] for step in steps]
    # our node has to be on the path exactly once
    assert visited.count(node_id) == 1

    # everything up to, but not including, our node
    return list(itertools.takewhile(lambda nid: nid != node_id, visited))

def chunk_vg(xg_path, path_name, out_dir, chunks, chunk_i, overwrite):
    """ extract the subgraph for chunk number chunk_i with vg find and
    write it to the chunk's .vg file.  nothing is done if that file is
    already there and overwrite is off. """
    region = chunks[chunk_i]
    out_vg = chunk_base_name(region[0], out_dir, chunk_i, ".vg")
    if not overwrite and os.path.isfile(out_vg):
        return

    head_node = xg_path_node_id(xg_path, region[0], int(region[1]))
    # the path query wants 0-based inclusive coordinates while the BED
    # chunk end is exclusive, hence the - 1
    tail_node = xg_path_node_id(xg_path, region[0], region[2] - 1)
    assert head_node > 0 and tail_node >= head_node

    # todo: would be cleaner to not have to pad context here
    run("vg find -x {} -r {}:{} -c 1 > {}".format(
        xg_path, head_node, tail_node, out_vg))

    # the context padding may have pulled in path nodes that come
    # before head_node; remove them one at a time so that the path
    # starts at head_node
    scratch_vg = out_vg + ".destroy"
    extra_ids = xg_path_predecessors(xg_path, path_name, head_node,
                                     context = 1)
    for extra_id in extra_ids:
        # destroy should take node list
        run("vg mod -y {} {} | vg mod -o - > {}".format(
            extra_id, out_vg, scratch_vg))
        run("mv {} {}".format(scratch_vg, out_vg))
                
def xg_path_node_offset(xg_path, path_name, offset):
    """ return the path offset at which the node containing the given
    path position begins
    """
    # which node holds this position?
    holder_id = xg_path_node_id(xg_path, path_name, offset)

    # ask vg find where that node sits on the path
    listing, _ = run("vg find -x {} -P {} -n {}".format(
        xg_path, path_name, holder_id),
                     proc_stdout=subprocess.PIPE)
    fields = listing.split()
    # more than one (node, offset) pair would mean a cyclic path, which
    # we are assuming we never get
    assert len(fields) == 2
    assert fields[0] == str(holder_id)
    node_start = int(fields[1])
    # the node cannot begin after the position it contains
    assert node_start <= offset
    # sanity check (should really use node size instead of 1000 here)
    assert offset - node_start < 1000

    return node_start
    
def sort_vcf(vcf_path, sorted_vcf_path):
    """ sort a vcf with shell tools (recipe from vcflib): header lines
    found in the first 10000 lines are written first, then all records
    sorted by contig and numerically by position """
    header_step = "head -10000 {} | grep \"^#\" > {}"
    records_step = "grep -v \"^#\" {} | sort -k1,1d -k2,2n >> {}"
    # order matters: the first step creates the file, the second appends
    for step in (header_step, records_step):
        run(step.format(vcf_path, sorted_vcf_path))
    
def call_chunk(xg_path, path_name, out_dir, chunks, chunk_i, path_size, overlap,
               pileup_opts, call_options, sample_name, threads,
               overwrite):
    """ create VCF from a given chunk

    steps, each skipped when its output already exists unless overwrite
    is set:
      1. extract the chunk's graph (chunk_vg)
      2. vg augment the graph with the chunk's gam
      3. vg call, then sort, bgzip and tabix the result
      4. clip half the overlap off each interior end with bcftools view

    returns nothing.  if the chunk has no gam file a warning is written
    to stderr and the chunk is skipped after step 1.
    """
    # make the graph chunk
    chunk_vg(xg_path, path_name, out_dir, chunks, chunk_i, overwrite)

    chunk = chunks[chunk_i]
    path_name = chunk[0]
    vg_path = chunk_base_name(path_name, out_dir, chunk_i, ".vg")
    gam_path = chunk_base_name(path_name, out_dir, chunk_i, ".gam")

    # a chunk can be empty if nothing aligns there.
    if not os.path.isfile(gam_path):
        sys.stderr.write("Warning: chunk not found: {}\n".format(gam_path))
        return

    # do the augmentation via pileup.  this is the most resource intensive step,
    # especially in terms of memory used.
    aug_graph_path = chunk_base_name(path_name, out_dir, chunk_i, ".aug")
    support_path = chunk_base_name(path_name, out_dir, chunk_i, ".sup")
    translation_path = chunk_base_name(path_name, out_dir, chunk_i, ".trans")

    if overwrite or not all(os.path.isfile(f) for f in [
            aug_graph_path, support_path, translation_path]):
        run("vg augment -a pileup {} {} -t {} {} -Z {} -S {}  > {}".format(
            vg_path, gam_path, threads, pileup_opts, translation_path,
            support_path, aug_graph_path))

    # do the calling.
    vcf_path = chunk_base_name(path_name, out_dir, chunk_i, ".vcf")
    if overwrite or not os.path.isfile(vcf_path + ".gz"):
        # vg call is given the path offset of the chunk's first node
        # (plus any user offset, see merge_call_opts)
        offset = xg_path_node_offset(xg_path, chunk[0], chunk[1])
        merged_call_opts = merge_call_opts(chunk[0], offset, path_size,
                                           call_options, sample_name)
        run("vg call {} -b {} -s {} -z {} -t {} {} > {}".format(
            aug_graph_path, vg_path, support_path, translation_path,
            threads, merged_call_opts, vcf_path + ".us"))
        sort_vcf(vcf_path + ".us", vcf_path)
        run("rm {}".format(vcf_path + ".us"))
        # -f: when overwriting, a .gz from an earlier run may exist and
        # bgzip will not replace it without being forced
        run("bgzip -f {}".format(vcf_path))
        run("tabix -f -p vcf {}".format(vcf_path + ".gz"))

    # do the vcf clip: adjacent chunks share overlap bases, and each
    # gives up half of them (floor division keeps these integers).
    left_clip = 0 if chunk_i == 0 else overlap // 2
    right_clip = 0 if chunk_i == len(chunks) - 1 else overlap // 2
    clip_path = chunk_base_name(path_name, out_dir, chunk_i, "_clip.vcf")
    if overwrite or not os.path.isfile(clip_path):
        # a user-supplied offset shifts the called coordinates, so the
        # clip window has to shift with it.  accept both spellings that
        # merge_call_opts accepts; the last one given is used.
        call_toks = call_options.split()
        offset = 0
        for tok_i, tok in enumerate(call_toks):
            if tok in ("-o", "--offset"):
                offset = int(call_toks[tok_i + 1])
        # bcftools regions are 1-based inclusive, the chunk is 0-based
        # with an exclusive end: hence + 1 on the start only
        run("bcftools view -r {}:{}-{} {} > {}".format(
            path_name, offset + chunk[1] + left_clip + 1,
            offset + chunk[2] - right_clip, vcf_path + ".gz", clip_path))

            
def merge_vcf_chunks(out_dir, path_name, path_size, chunks, overwrite):
    """ concatenate the clipped per-chunk vcfs, in chunk order, into
    <out_dir>/<path_name>.vcf, keeping the header of the first one only.
    the clips are expected to be sorted already.  a bgzipped, tabix
    indexed copy is written as well. """
    merged_path = os.path.join(out_dir, path_name + ".vcf")
    if overwrite or not os.path.isfile(merged_path):
        header_written = False
        for index in range(len(chunks)):
            clip = chunk_base_name(path_name, out_dir, index, "_clip.vcf")
            # chunks that were never called have no clip file
            if not os.path.isfile(clip):
                continue
            if header_written:
                # records only; grep exits non-zero when nothing
                # matches, which is not an error here
                run("grep -v \"^#\" {} >> {}".format(clip, merged_path), check=False)
            else:
                # first clip found: take it whole, header included
                run("cat {} > {}".format(clip, merged_path))
                header_written = True

    # add a compressed indexed version
    compressed_path = merged_path + ".gz"
    if overwrite or not os.path.isfile(compressed_path):
        run("bgzip -c {} > {}".format(merged_path, compressed_path))
        run("tabix -f -p vcf {}".format(compressed_path))

def main(args):
    """ entry point: parse the argv-style list args, then chunk the
    gam, call each chunk in turn and stitch the chunk vcfs together """
    opts = parse_args(args)

    out_dir = opts.out_dir
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    # the overlap is split evenly between neighbouring chunks, which
    # is simpler when it is even
    assert opts.overlap % 2 == 0

    # compute overlapping chunks
    chunks = make_chunks(opts.path_name, opts.path_size,
                         opts.chunk, opts.overlap)

    # split the gam in one go
    chunk_gam(opts.gam_path, opts.xg_path, opts.path_name, out_dir,
              chunks, opts.filter_opts, opts.threads, opts.overwrite)

    # call every chunk in series
    for chunk_i in range(len(chunks)):
        call_chunk(opts.xg_path, opts.path_name, out_dir, chunks, chunk_i,
                   opts.path_size, opts.overlap,
                   opts.pileup_opts, opts.call_opts,
                   opts.sample_name, opts.threads, opts.overwrite)

    # stitch together the vcf
    merge_vcf_chunks(out_dir, opts.path_name, opts.path_size,
                     chunks, opts.overwrite)
    
# run only when executed as a script; main's return value becomes the
# process exit status
if __name__ == "__main__":
    exit_status = main(sys.argv)
    sys.exit(exit_status)
        
        
