Merge pull request #144 from marbl/overhaul-coordinates
Overhaul coordinates
bkille authored Jan 19, 2024
2 parents 1b5b48b + 2f81014 commit 2eeaa25
Showing 5 changed files with 135 additions and 168 deletions.
91 changes: 53 additions & 38 deletions extend.py
@@ -2,20 +2,15 @@
from Bio.Seq import Seq
from Bio.SeqIO import SeqRecord
from Bio.Align import MultipleSeqAlignment
from glob import glob
import tempfile
import logging
from pathlib import Path
import re
import subprocess
from collections import namedtuple, defaultdict, Counter
import os
from Bio.Align import substitution_matrices
from itertools import product, combinations
import bisect
import numpy as np
from Bio.AlignIO.MafIO import MafWriter, MafIterator
from Bio.AlignIO.MauveIO import MauveWriter, MauveIterator
from logger import logger
import time
from tqdm import tqdm
from logger import logger, TqdmToLogger, MIN_TQDM_INTERVAL
import spoa
#%%


@@ -46,29 +41,38 @@ def parse_xmfa_header(xmfa_file):
return index_to_gid, gid_to_index
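
For reference, the Segment and IdPair values built by the functions below are defined elsewhere in extend.py, outside this diff; presumably they are namedtuples along these lines (a hypothetical reconstruction, not the committed definitions):

from collections import namedtuple

# gid = genome id (FASTA basename), cid = contig id within that genome.
IdPair = namedtuple("IdPair", ["gid", "cid"])
# An aligned interval on one contig, with strand +1 or -1.
Segment = namedtuple("Segment", ["idp", "start", "stop", "strand"])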


def index_input_sequences(xmfa_file, input_dir):
def index_input_sequences(xmfa_file, file_list):
basename_to_path = {}
for f in file_list:
basename = str(Path(f).stem)
basename_to_path[basename] = f
gid_to_records = {}
gid_to_cid_to_index = {}
gid_to_index_to_cid = {}
with open(xmfa_file) as parsnp_fd:
for line in (line.strip() for line in parsnp_fd):
if line[:2] == "##":
if line.startswith("##SequenceFile"):
p = Path(os.path.join(input_dir + line.split(' ')[1]))
gid_to_records[p.stem] = {record.id: record for record in SeqIO.parse(str(p), "fasta")}
gid_to_cid_to_index[p.stem] = {idx+1: rec.id for (idx, rec) in enumerate(SeqIO.parse(str(p), "fasta"))}
return gid_to_records, gid_to_cid_to_index
basename = Path(line.split(' ')[1]).stem
p = Path(basename_to_path[basename])
gid_to_records[p.stem] = {}
gid_to_cid_to_index[p.stem] = {}
gid_to_index_to_cid[p.stem] = {}
for idx, rec in enumerate(SeqIO.parse(str(p), "fasta")):
gid_to_records[p.stem][rec.id] = rec
gid_to_cid_to_index[p.stem][rec.id] = idx + 1
gid_to_index_to_cid[p.stem][idx + 1] = rec.id
return gid_to_records, gid_to_cid_to_index, gid_to_index_to_cid
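
For illustration, a minimal sketch of how the reworked indexing is meant to be called; the paths and basenames are hypothetical:

file_list = ["data/sampleA.fasta", "data/sampleB.fasta"]
gid_to_records, gid_to_cid_to_index, gid_to_index_to_cid = \
    index_input_sequences("parsnp.xmfa", file_list)

# Contig ids and their 1-based indices now form a two-way mapping:
cid = gid_to_index_to_cid["sampleA"][1]          # FASTA id of the first contig
assert gid_to_cid_to_index["sampleA"][cid] == 1  # ...and back again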



def xmfa_to_covered(xmfa_file, index_to_gid, gid_to_cid_to_index):
def xmfa_to_covered(xmfa_file, index_to_gid, gid_to_index_to_cid):
seqid_parser = re.compile(r'^cluster(\d+) s(\d+):p(\d+)/.*')
idpair_to_segments = defaultdict(list)
idpair_to_tree = defaultdict(IntervalTree)
cluster_to_named_segments = defaultdict(list)
for aln in tqdm(AlignIO.parse(xmfa_file, "mauve")):
for aln in AlignIO.parse(xmfa_file, "mauve"):
for seq in aln:
# Skip reference for now...
aln_len = seq.annotations["end"] - seq.annotations["start"] + 1
aln_len = seq.annotations["end"] - seq.annotations["start"]
cluster_idx, contig_idx, startpos = [int(x) for x in seqid_parser.match(seq.id).groups()]

gid = index_to_gid[seq.name]
@@ -78,29 +82,29 @@ def xmfa_to_covered(xmfa_file, index_to_gid, gid_to_cid_to_index):
else:
endpos = startpos + aln_len

idp = IdPair(gid, gid_to_cid_to_index[gid][contig_idx])
idp = IdPair(gid, gid_to_index_to_cid[gid][contig_idx])
seg = Segment(idp, startpos, startpos + aln_len, seq.annotations["strand"])
idpair_to_segments[idp].append(seg)
idpair_to_tree[idp].addi(seg.start, seg.stop)
cluster_to_named_segments[cluster_idx].append(seg)

for idp in idpair_to_segments:
idpair_to_segments[idp] = sorted(idpair_to_segments[idp])
idpair_to_tree[idp].merge_overlaps()
return idpair_to_segments, idpair_to_tree, cluster_to_named_segments
return idpair_to_segments, cluster_to_named_segments
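
The seqid regex above does the heavy lifting of the coordinate overhaul: every XMFA sequence id encodes its cluster index, contig index, and start position. An illustrative parse (the id string is made up):

import re

seqid_parser = re.compile(r'^cluster(\d+) s(\d+):p(\d+)/.*')
m = seqid_parser.match("cluster7 s2:p1500/1500-2210")
cluster_idx, contig_idx, startpos = (int(x) for x in m.groups())
# cluster_idx == 7, contig_idx == 2, startpos == 1500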


def run_msa(downstream_segs_to_align, gid_to_records):
keep_extending = True
iteration = 0
seq_len_desc = stats.describe([seg.stop - seg.start for seg in downstream_segs_to_align])
longest_seq = seq_len_desc.minmax[1]
if sum(
seq_len_desc.mean*(1 - length_window) <= (seg.stop - seg.start) <= seq_len_desc.mean*(1 + length_window) for seg in downstream_segs_to_align) > len(downstream_segs_to_align)*window_prop:
base_length = int(seq_len_desc.mean*(1 + length_window))
else:
base_length = BASE_LENGTH
seq_lens = [seg.stop - seg.start for seg in downstream_segs_to_align]
longest_seq = max(seq_lens)
mean_seq_len = np.mean(seq_lens)
# if sum(
# mean_seq_len*(1 - length_window) <= (seg.stop - seg.start) <= mean_seq_len*(1 + length_window) for seg in downstream_segs_to_align) > len(downstream_segs_to_align)*window_prop:
# base_length = int(mean_seq_len*(1 + length_window))
# else:
# base_length = BASE_LENGTH

base_length = BASE_LENGTH
while keep_extending:
seqs_to_align = ["A" + (str(
gid_to_records[seg.idp.gid][seg.idp.cid].seq[seg.start:seg.stop] if seg.strand == 1
@@ -131,11 +135,15 @@ def run_msa(downstream_segs_to_align, gid_to_records):
return aligned_msa_seqs
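
The alignment inside run_msa is delegated to spoa (partial-order alignment). A toy call, assuming the pyspoa binding with its default scoring; the sequences are made up:

import spoa

seqs = ["AACGTGATT", "AACGGATT", "ATCGTGAT"]
consensus, msa = spoa.poa(seqs)  # msa: list of gapped strings, one per input
print(consensus)
for row in msa:
    print(row)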


def extend_clusters(xmfa_file, index_to_gid, gid_to_cid_to_index, idpair_to_segments, idpair_to_tree, cluster_to_named_segments, gid_to_records):
def extend_clusters(xmfa_file, gid_to_index, gid_to_cid_to_index, idpair_to_segments, cluster_to_named_segments, gid_to_records):
ret_lcbs = []
seqid_parser = re.compile(r'^cluster(\d+) s(\d+):p(\d+)/.*')

for aln_idx, aln in enumerate(tqdm(AlignIO.parse(xmfa_file, "mauve"), total=len(cluster_to_named_segments))):
for aln_idx, aln in enumerate(tqdm(
AlignIO.parse(xmfa_file, "mauve"),
total=len(cluster_to_named_segments),
file=TqdmToLogger(logger, level=logging.INFO),
mininterval=MIN_TQDM_INTERVAL)):
# validate_lcb(aln, gid_to_records, parsnp_header=True)
seq = aln[0]
cluster_idx, contig_idx, startpos = [int(x) for x in seqid_parser.match(seq.id).groups()]
@@ -167,29 +175,36 @@ def extend_clusters(xmfa_file, index_to_gid, gid_to_cid_to_index, idpair_to_segm
new_lcb = MultipleSeqAlignment([])
# Assumes alignments are always in the same order
new_bp = []
for seq_idx, (covered_seg, uncovered_seg, aln_str) in enumerate(zip(segs, downstream_segs_to_align, aligned_msa_seqs)):
for seg_idx, (covered_seg, uncovered_seg, aln_str) in enumerate(zip(segs, downstream_segs_to_align, aligned_msa_seqs)):
# Update segment in idpair_to_segments
if len(aln_str) < MIN_LEN:
continue
new_bp_covered = len(aln_str) - aln_str.count("-")
# print(f"Extending {covered_seg} by {new_bp_covered}")
new_bp.append(new_bp_covered)
new_seq = aln_str
if covered_seg.strand == 1:
new_seg = Segment(covered_seg.idp, uncovered_seg.start, uncovered_seg.start + new_bp_covered, covered_seg.strand)
if new_bp_covered > 0:
segs[seg_idx] = Segment(covered_seg.idp, covered_seg.start, new_seg.stop, covered_seg.strand)
else:
aln_str = Seq(aln_str).reverse_complement()
new_seg = Segment(covered_seg.idp, covered_seg.start - new_bp_covered, covered_seg.start, covered_seg.strand)
if new_bp_covered > 0:
segs[seg_idx] = Segment(covered_seg.idp, new_seg.start, covered_seg.stop, covered_seg.strand)

new_record = SeqRecord(
seq=new_seq,
id=f"{covered_seg.idp.gid}#{covered_seg.idp.cid}",
id=f"cluster{cluster_idx} s{gid_to_cid_to_index[covered_seg.idp.gid][covered_seg.idp.cid]}:p{new_seg.start if new_seg.strand == 1 else new_seg.stop}",
name=gid_to_index[covered_seg.idp.gid],
annotations={"start": new_seg.start, "end": new_seg.stop, "strand": new_seg.strand}
)

# if covered_seg.strand == 1:
new_lcb.append(new_record)
if new_bp_covered > 0:
idpair_to_tree[covered_seg.idp].addi(new_seg.start, new_seg.stop)

ret_lcbs.append(new_lcb)
if len(new_lcb) > 0:
ret_lcbs.append(new_lcb)
return ret_lcbs


23 changes: 23 additions & 0 deletions logger.py
@@ -1,4 +1,5 @@
import logging
import io
############################################# Logging ##############################################
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
# These are the sequences needed to get colored output
@@ -14,6 +15,28 @@
COLOR_SEQ = "\033[1;%dm"
BOLD_SEQ = "\033[1m"

MIN_TQDM_INTERVAL=30


# Logging redirect copied from https://stackoverflow.com/questions/14897756/python-progress-bar-through-logging-module
class TqdmToLogger(io.StringIO):
"""
Output stream for TQDM which will output to logger module instead of
the StdOut.
"""
logger = None
level = None
buf = ''
def __init__(self,logger,level=None):
super(TqdmToLogger, self).__init__()
self.logger = logger
self.level = level or logging.INFO
def write(self,buf):
self.buf = buf.strip('\r\n\t ')
def flush(self):
self.logger.log(self.level, self.buf)
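
Usage sketch for the class above: hand an instance to tqdm's file= argument so progress lines go to the logger instead of stderr (the loop is a stand-in):

import logging
from tqdm import tqdm

tqdm_out = TqdmToLogger(logger, level=logging.INFO)
for _ in tqdm(range(1000), file=tqdm_out, mininterval=MIN_TQDM_INTERVAL):
    pass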


def formatter_message(message, use_color = True):
if use_color:
message = message.replace("$RESET", RESET_SEQ).replace("$BOLD", BOLD_SEQ)
96 changes: 42 additions & 54 deletions parsnp
@@ -5,19 +5,16 @@
'''

import os, sys, string, getopt, random,subprocess, time,operator, math, datetime,numpy #pysam
import os, sys, string, random, subprocess, time, operator, math, datetime, numpy #pysam
from collections import defaultdict
import csv
import shutil
import shlex
from tempfile import TemporaryDirectory
import re
import logging
from logger import logger
import multiprocessing
from logger import logger, TqdmToLogger, MIN_TQDM_INTERVAL
import argparse
import signal
import inspect
from multiprocessing import Pool
from Bio import SeqIO
from glob import glob
@@ -27,7 +24,7 @@ from pathlib import Path
import extend as ext
from tqdm import tqdm

__version__ = "2.0.1"
__version__ = "2.0.2"
reroot_tree = True #use --midpoint-reroot
random_seeded = random.Random(42)

@@ -149,7 +146,7 @@ def run_phipack(query,seqlen,workingdir):
currdir = os.getcwd()
os.chdir(workingdir)
command = "Profile -o -v -n %d -w 100 -m 100 -f %s > %s.out"%(seqlen,query,query)
run_command(command,1, prepend_time=True)
run_command(command, 1)
os.chdir(currdir)

def run_fasttree(query,workingdir,recombination_sites):
@@ -685,15 +682,20 @@ def create_output_directory(output_dir):

if os.path.exists(output_dir):
logger.warning(f"Output directory {output_dir} exists, all results will be overwritten")
shutil.rmtree(output_dir)
if os.path.exists(output_dir + "/partition"):
shutil.rmtree(output_dir + "/partition/")
if os.path.exists(output_dir + "/config/"):
shutil.rmtree(output_dir + "/config/")
if os.path.exists(output_dir + "/log/"):
shutil.rmtree(output_dir + "/log/")
elif output_dir == "[P_CURRDATE_CURRTIME]":
today = datetime.datetime.now()
timestamp = "P_" + today.isoformat().replace("-", "_").replace(".", "").replace(":", "").replace("T", "_")
output_dir = os.getcwd() + os.sep + timestamp
os.makedirs(output_dir)
os.makedirs(os.path.join(output_dir, "tmp"))
os.makedirs(os.path.join(output_dir, "log"))
os.makedirs(os.path.join(output_dir, "config"))
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.join(output_dir, "tmp"), exist_ok=True)
os.makedirs(os.path.join(output_dir, "log"), exist_ok=True)
os.makedirs(os.path.join(output_dir, "config"), exist_ok=True)
return output_dir


@@ -1645,7 +1647,11 @@ SETTINGS:
logger.info("Running partitions...")
good_chunks = set(chunk_labels)
with Pool(args.threads) as pool:
return_codes = tqdm(pool.imap(run_parsnp_aligner, chunk_output_dirs, chunksize=1), total=len(chunk_output_dirs))
return_codes = tqdm(
pool.imap(run_parsnp_aligner, chunk_output_dirs, chunksize=1),
total=len(chunk_output_dirs),
file=TqdmToLogger(logger,level=logging.INFO),
mininterval=MIN_TQDM_INTERVAL)
for cl, rc in zip(chunk_labels, return_codes):
if rc != 0:
logger.error(f"Partition {cl} failed...")
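
Note that pool.imap returns a lazy iterator, so the tqdm bar only advances as the zip loop consumes each partition's return code. A condensed, runnable sketch of the pattern, with a hypothetical worker standing in for run_parsnp_aligner:

from multiprocessing import Pool
from tqdm import tqdm

def work(job):  # hypothetical stand-in for run_parsnp_aligner
    return 0    # 0 = success

jobs = list(range(8))
with Pool(2) as pool:
    return_codes = tqdm(pool.imap(work, jobs, chunksize=1), total=len(jobs))
    for job, rc in zip(jobs, return_codes):
        if rc != 0:
            print(f"job {job} failed")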
@@ -1666,51 +1672,33 @@
partition.merge_xmfas(partition_output_dir, chunk_labels, xmfa_out_f, num_clusters, args.threads)



run_lcb_trees = 0
parsnp_output = f"{outputDir}/parsnp.xmfa"

# This is the stuff for LCB extension:
annotation_dict = {}
#TODO always using xtrafast?
parsnp_output = f"{outputDir}/parsnp.xmfa"
if args.extend_lcbs:
xmfa_file = f"{outputDir}/parsnp.xmfa"
with TemporaryDirectory() as temp_directory:
original_maf_file = f"{outputDir}/parsnp-original.maf"
extended_xmfa_file = f"{outputDir}/parsnp-extended.xmfa"
fname_contigid_to_length, fname_contigidx_to_header, fname_to_seqrecord = ext.get_sequence_data(
ref,
finalfiles,
index_files=False)
fname_to_contigid_to_coords, fname_header_to_gcontigidx = ext.xmfa_to_maf(
xmfa_file,
original_maf_file,
fname_contigidx_to_header,
fname_contigid_to_length)
packed_write_result = ext.write_intercluster_regions(finalfiles + [ref], temp_directory, fname_to_contigid_to_coords)
fname_contigid_to_cluster_dir_to_length, fname_contigid_to_cluster_dir_to_adjacent_cluster = packed_write_result
cluster_files = glob(f"{temp_directory}/*.fasta")
clusterdir_expand, clusterdir_len = ext.get_new_extensions(
cluster_files,
args.match_score,
args.mismatch_penalty,
args.gap_penalty)
ext.write_extended_xmfa(
original_maf_file,
extended_xmfa_file,
temp_directory,
clusterdir_expand,
clusterdir_len,
fname_contigid_to_cluster_dir_to_length,
fname_contigid_to_cluster_dir_to_adjacent_cluster,
fname_header_to_gcontigidx,
fname_contigid_to_length,
args.extend_ani_cutoff,
args.extend_indel_cutoff,
threads)
parsnp_output = extended_xmfa_file
os.remove(original_maf_file)
logger.warning("The LCB extension module is experimental. Runtime may be significantly increased and extended alignments may not be as high quality as the original core-genome. Extensions off of existing LCBs are in a separate xmfa file.")
import partition
import extend as ext

orig_parsnp_xmfa = parsnp_output
extended_parsnp_xmfa = orig_parsnp_xmfa + ".extended"

# Index input fasta files and original xmfa
index_to_gid, gid_to_index = ext.parse_xmfa_header(orig_parsnp_xmfa)
gid_to_records, gid_to_cid_to_index, gid_to_index_to_cid = ext.index_input_sequences(orig_parsnp_xmfa, finalfiles + [ref])

# Get covered regions of xmfa file
idpair_to_segments, cluster_to_named_segments = ext.xmfa_to_covered(orig_parsnp_xmfa, index_to_gid, gid_to_index_to_cid)

# Extend clusters
logger.info(f"Extending LCBs with SPOA...")
new_lcbs = ext.extend_clusters(orig_parsnp_xmfa, gid_to_index, gid_to_cid_to_index, idpair_to_segments, cluster_to_named_segments, gid_to_records)

# Write output
partition.copy_header(orig_parsnp_xmfa, extended_parsnp_xmfa)
with open(extended_parsnp_xmfa, 'a') as out_f:
for lcb in new_lcbs:
partition.write_aln_to_xmfa(lcb, out_f)

#add genbank here, if present
if len(genbank_ref) != 0: