Use filename if possible when extracting reads
cjw85 committed Feb 28, 2019
1 parent 179f6a4 commit 4a8fc87
Showing 2 changed files with 102 additions and 27 deletions.
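
In outline, filter_multi_reads can now build its file-to-read-id index directly from the filter TSV whenever that file carries a 'filename' column, and only falls back to opening and scanning every source fast5 when it does not. A minimal sketch of the fast path, assuming an already-loaded numpy structured array such as readtsv returns (the helper name build_read_index is hypothetical; group_vector is the utility this commit adds to fast5_research.util):

import os

from fast5_research.util import group_vector  # added by this commit


def build_read_index(read_table, src_files, read_id_field='read_id'):
    """Map each source file path to the read ids it contributes, using the
    filter's 'filename' column; return None if that column is absent so the
    caller can fall back to scanning every source file."""
    if 'filename' not in read_table.dtype.names:
        return None
    by_name = {os.path.basename(path): path for path in src_files}
    if len(by_name) != len(src_files):
        raise ValueError('Found non-uniquely named source files')
    return {
        by_name[fname]: read_table[read_id_field][indices]
        for fname, indices in group_vector(read_table['filename']).items()}

When the fast path cannot be used, the behaviour is unchanged: reads_in_multi is mapped over every source file to discover which of the requested reads each one contains.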
110 changes: 83 additions & 27 deletions fast5_research/extract.py
@@ -2,14 +2,15 @@
from concurrent.futures import ProcessPoolExecutor, as_completed
import functools
import logging
from timeit import default_timer as now
import os

import h5py
import numpy as np

from fast5_research.fast5 import Fast5, iterate_fast5
from fast5_research.fast5_bulk import BulkFast5
from fast5_research.util import _sanitize_data_for_writing, readtsv
from fast5_research.util import _sanitize_data_for_writing, readtsv, group_vector


def extract_reads():
@@ -158,46 +159,101 @@ def filter_multi_reads():
else:
raise IOError('The output directory must not exist.')

#TODO: could attempt to discover filenames from args.
logger.info("Reading filter file.")
read_table = readtsv(args.filter, fields=[args.tsv_field])
required_reads = set(read_table[args.tsv_field])
logger.info("Found {} reads in filter.".format(len(required_reads)))

# grab list of source files
logger.info("Searching for input files.")
src_files = list(iterate_fast5(args.input, paths=True, recursive=args.recursive))
n_files = len(src_files)

logger.info("Finding reads within {} source files.".format(n_files))
index_worker = functools.partial(reads_in_multi, filt=required_reads)
read_index = dict()
n_reads = 0
with ProcessPoolExecutor(args.workers) as executor:
i = 0
for src_file, read_ids in zip(src_files, executor.map(index_worker, src_files, chunksize=10)):
i += 1
n_reads += len(read_ids)
read_index[src_file] = read_ids
if i % 10 == 0:
logger.info("Indexed {}/{} files. {}/{} reads".format(i, n_files, n_reads, len(required_reads)))

logger.info("Reading filter file.")
read_table = readtsv(args.filter, fields=[args.tsv_field])
logger.info("Found {} reads in filter.".format(len(read_table)))
try:
# try to build index from a file with 'filename' column
if 'filename' not in read_table.dtype.names:
raise ValueError("'filename' column not present in filter.")
logger.info("Attempting to build read index from input filter.")
src_path_files = {
os.path.basename(x):x for x in src_files
}
if len(src_path_files) != len(src_files):
raise ValueError('Found non-uniquely named source files')
read_index = dict()
for fname, indices in group_vector(read_table['filename']).items():
fpath = src_path_files[fname]
read_index[fpath] = read_table[args.tsv_field][indices]
logger.info("Successfully build read index from input filter.")
except Exception as e:
logger.info("Failed to build read index from summary: {}".format(e))
read_index = None
required_reads = set(read_table[args.tsv_field])
logger.info("Finding reads within {} source files.".format(n_files))
index_worker = functools.partial(reads_in_multi, filt=required_reads)
read_index = dict()
n_reads = 0
with ProcessPoolExecutor(args.workers) as executor:
i = 0
for src_file, read_ids in zip(src_files, executor.map(index_worker, src_files, chunksize=10)):
i += 1
n_reads += len(read_ids)
read_index[src_file] = read_ids
if i % 10 == 0:
logger.info("Indexed {}/{} files. {}/{} reads".format(i, n_files, n_reads, len(required_reads)))

n_reads = sum(len(x) for x in read_index.values())
# We don't go via creating Read objects; copying the data verbatim is
# likely quicker and nothing should need the verification that the APIs
# provide (garbage in, garbage out).
logger.info("Extracting {} reads.".format(n_reads))
if args.prefix != '':
args.prefix = '{}_'.format(args.prefix)
with MultiWriter(args.output, None, prefix=args.prefix) as writer:

with ProcessPoolExecutor(args.workers) as executor:
reads_per_process = np.ceil(n_reads / args.workers)
proc_n_reads = 0
proc_reads = dict()
job = 0
futures = list()
for src in read_index.keys():
proc_reads[src] = read_index[src]
proc_n_reads += len(proc_reads[src])
if proc_n_reads > reads_per_process:
proc_prefix = "{}{}_".format(args.prefix, job)
futures.append(executor.submit(_subset_reads_to_file, proc_reads, args.output, proc_prefix, worker_id=job))
job += 1
proc_n_reads = 0
proc_reads = dict()
for fut in as_completed(futures):
try:
reads_written, prefix = fut.result()
logger.info("Written {} reads to {}.".format(reads_written, prefix))
except Exception as e:
logger.warning("Error: {}".format(e))
logger.info("Done.")


def _subset_reads_to_file(read_index, output, prefix, worker_id=0):
logger = logging.getLogger('Worker-{}'.format(worker_id))
n_reads = sum(len(x) for x in read_index.values())
reads_written = 0
t0 = now()
with MultiWriter(output, None, prefix=prefix) as writer:
for src_file, read_ids in read_index.items():
logger.info("Copying {} reads from {}.".format(
len(read_ids), os.path.basename(src_file)
))
reads_written += len(read_ids)
t1 = now()
if t1 - t0 > 30: # log update every 30 seconds
logger.info("Written {}/{} reads ({:.0f}% done)".format(
reads_written, n_reads, 100 * reads_written / n_reads
))
t0 = t1
with h5py.File(src_file, 'r') as src_fh:
for read_id in read_ids:
writer.write_read(src_fh["read_{}".format(read_id)])
logger.info("Done.")

try:
read_grp = src_fh["read_{}".format(read_id)]
except:
logger.warning("Did not find {} in {}.".format(read_id, src_fh.filename))
else:
writer.write_read(read_grp)
return reads_written, prefix

def reads_in_multi(src, filt=None):
"""Get list of read IDs contained within a multi-read file.
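That fallback scan relies on reads_in_multi; a small usage sketch (file name and read ids are hypothetical):

from fast5_research.extract import reads_in_multi

wanted = {'read-id-0001', 'read-id-0002'}             # hypothetical read ids
found = reads_in_multi('batch_0.fast5', filt=wanted)  # hypothetical file
print('{}/{} requested reads present'.format(len(found), len(wanted)))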
19 changes: 19 additions & 0 deletions fast5_research/util.py
@@ -637,3 +637,22 @@ def dtype_descr(arr):
        return arr.dtype.descr
    except ValueError:
        return tuple([(n, arr.dtype[n].descr[0][1]) for n in arr.dtype.names])


def group_vector(arr):
    """Group a vector by unique values.
    :param arr: input vector to be grouped.
    :returns: a dictionary mapping unique values to arrays of indices of the
        input vector.
    """
    groups, keys_as_int = np.unique(arr, return_inverse=True)
    n_keys = max(keys_as_int)
    indices = [[] for i in range(n_keys + 1)]
    for i, k in enumerate(keys_as_int):
        indices[k].append(i)
    indices = [np.array(elt) for elt in indices]
    return dict(zip(groups, indices))
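
For reference, a quick illustration of what group_vector produces (values hypothetical):

import numpy as np

from fast5_research.util import group_vector

filenames = np.array(['batch_0.fast5', 'batch_1.fast5', 'batch_0.fast5'])
print(group_vector(filenames))
# {'batch_0.fast5': array([0, 2]), 'batch_1.fast5': array([1])}

filter_multi_reads uses exactly this to collect, for each source file named in the filter, the row indices of the filter table that belong to it.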
