🚧 filter: Add --query-sqlite option
This adds a new flag to query the SQLite database natively.
`--query`/`--query-pandas` will still behave as expected.

All Pandas-based query functions are renamed to be Pandas-specific.

To avoid breaking changes, alias `--query` to `--query-pandas`.
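For illustration, a hypothetical pair of equivalent invocations in the style of the functional tests (file names are placeholders; `--query` remains usable as the alias):

  $ augur filter \
  >  --metadata metadata.tsv \
  >  --query "country == 'USA' & division == 'Washington'" \
  >  --output-strains filtered_strains.txt

  $ augur filter \
  >  --metadata metadata.tsv \
  >  --query-sqlite "country = 'USA' AND division = 'Washington'" \
  >  --output-strains filtered_strains.txt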
victorlin committed Feb 3, 2024 · 1 parent 0bbad38 · commit 3cf5670
Showing 11 changed files with 235 additions and 34 deletions.
10 changes: 8 additions & 2 deletions augur/filter/__init__.py
@@ -18,17 +18,23 @@ def register_arguments(parser):
     input_group.add_argument('--metadata', required=True, metavar="FILE", help="sequence metadata")
     input_group.add_argument('--sequences', '-s', help="sequences in FASTA or VCF format")
     input_group.add_argument('--sequence-index', help="sequence composition report generated by augur index. If not provided, an index will be created on the fly.")
-    input_group.add_argument('--metadata-chunk-size', type=int, default=100000, help="maximum number of metadata records to read into memory at a time. Increasing this number can speed up filtering at the cost of more memory used.")
+    input_group.add_argument('--metadata-chunk-size', type=int, default=100000, help="maximum number of metadata records to read into memory at a time. Increasing this number can speed up filtering at the cost of more memory used. NOTE: this only applies to --query/--query-pandas.")
     input_group.add_argument('--metadata-id-columns', default=DEFAULT_ID_COLUMNS, nargs="+", help="names of possible metadata columns containing identifier information, ordered by priority. Only one ID column will be inferred.")
     input_group.add_argument('--metadata-delimiters', default=DEFAULT_DELIMITERS, nargs="+", help="delimiters to accept when reading a metadata file. Only one delimiter will be inferred.")
 
     metadata_filter_group = parser.add_argument_group("metadata filters", "filters to apply to metadata")
     metadata_filter_group.add_argument(
-        '--query',
+        '--query-pandas', '--query',
         help="""Filter samples by attribute.
         Uses Pandas Dataframe querying, see https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#indexing-query for syntax.
         (e.g., --query "country == 'Colombia'" or --query "(country == 'USA' & (division == 'Washington'))")"""
     )
+    metadata_filter_group.add_argument(
+        '--query-sqlite',
+        help="""Filter samples by attribute.
+        Uses SQL WHERE clause querying, see https://www.sqlite.org/lang_expr.html for syntax.
+        (e.g., --query-sqlite "country = 'Colombia'" or --query-sqlite "(country = 'USA' AND division = 'Washington')")"""
+    )
     metadata_filter_group.add_argument('--query-columns', type=column_type_pair, nargs="+", help=f"""
         Use alongside --query to specify columns and data types in the format 'column:type', where type is one of ({','.join(ACCEPTED_TYPES)}).
         Automatic type inference will be attempted on all unspecified columns used in the query.
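A note on `--query-columns`: it pairs with either query flavor (construct_filters and io.py below pass it through on the SQLite path too, though per the FIXME in include_exclude_rules.py the type conversion is not yet applied there). A hypothetical combined invocation:

  $ augur filter \
  >  --metadata metadata.tsv \
  >  --query-sqlite "coverage >= 0.95" \
  >  --query-columns coverage:float \
  >  --output-strains filtered_strains.txt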
144 changes: 133 additions & 11 deletions augur/filter/include_exclude_rules.py
@@ -143,9 +143,66 @@ def filter_by_exclude_where(exclude_where) -> FilterFunctionReturn:
     return expression, parameters
 
 
-def filter_by_query(query: str, chunksize: int, column_types: Optional[Dict[str, str]] = None) -> FilterFunctionReturn:
+def filter_by_sqlite_query(query: str, column_types: Optional[Dict[str, str]] = None) -> FilterFunctionReturn:
+    """Filter by any valid SQLite expression on the metadata.
+
+    Strains that do *not* match the query will be excluded.
+
+    Parameters
+    ----------
+    query : str
+        SQL expression used to exclude strains
+    column_types : dict, optional
+        Dict mapping column names to data types
+    """
+    with Sqlite3Database(constants.RUNTIME_DB_FILE) as db:
+        metadata_id_column = db.get_primary_index(constants.METADATA_TABLE)
+        metadata_columns = set(db.columns(constants.METADATA_TABLE))
+
+    if column_types is None:
+        column_types = {}
+
+    # Set columns for type conversion.
+    variables = extract_potential_sqlite_variables(query)
+    if variables is not None:
+        columns = variables.intersection(metadata_columns)
+    else:
+        # Column extraction failed. Apply type conversion to all columns.
+        columns = metadata_columns
+
+    # If a type is not explicitly provided, try converting the column to numeric.
+    # This should cover most use cases, since one common problem is that the
+    # built-in data type inference when loading the DataFrame does not
+    # support nullable numeric columns, so numeric comparisons won't work on
+    # those columns. pd.to_numeric does proper conversion on those columns,
+    # and will not make any changes to columns with other values.
+    for column in columns:
+        column_types.setdefault(column, 'numeric')
+
+    # FIXME: Apply column_types.
+    # It's not easy to change the type on the table schema.¹
+    # Maybe using CAST?² But that always takes place even if the conversion is
+    # lossy and irreversible (i.e. no error handling options like pd.to_numeric).
+    # ¹ <https://www.sqlite.org/lang_altertable.html#making_other_kinds_of_table_schema_changes>
+    # ² <https://www.sqlite.org/lang_expr.html#castexpr>
+
+    expression = f"""
+        {constants.ID_COLUMN} IN (
+            SELECT {sanitize_identifier(metadata_id_column)}
+            FROM {constants.METADATA_TABLE}
+            WHERE NOT ({query})
+        )
+    """
+    parameters: SqlParameters = {}
+    return expression, parameters
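To make the exclusion expression concrete, a minimal sketch of what the template above produces (the constant values here are hypothetical stand-ins; the real ones live in augur.filter.constants):

    # Hypothetical stand-ins for constants.ID_COLUMN, constants.METADATA_TABLE,
    # and the sanitized metadata ID column; the real values may differ.
    ID_COLUMN = "id"
    METADATA_TABLE = "metadata"
    metadata_id_column = '"strain"'
    query = "country = 'USA' AND division = 'Washington'"

    # Mirrors the f-string template in filter_by_sqlite_query: the inner SELECT
    # collects strains that do *not* match the user's query, and the surrounding
    # filter machinery excludes those IDs.
    expression = f"""
        {ID_COLUMN} IN (
            SELECT {metadata_id_column}
            FROM {METADATA_TABLE}
            WHERE NOT ({query})
        )
    """
    print(expression)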


+def filter_by_pandas_query(query: str, chunksize: int, column_types: Optional[Dict[str, str]] = None) -> FilterFunctionReturn:
     """Filter by a Pandas expression on the metadata.
+
+    Note that this is inefficient compared to native SQLite queries, and is in
+    place for backwards compatibility.
+
     Parameters
     ----------
     query : str
@@ -170,7 +227,7 @@ def filter_by_query(query: str, chunksize: int, column_types: Optional[Dict[str,
         column_types = {}
 
     # Set columns for type conversion.
-    variables = extract_variables(query)
+    variables = extract_pandas_query_variables(query)
     if variables is not None:
         columns = variables.intersection(metadata_columns)
     else:
@@ -525,17 +582,26 @@ def construct_filters(args) -> Tuple[List[FilterOption], List[FilterOption]]:
             {"exclude_where": exclude_where}
         ))
 
-    # Exclude strains by metadata, using pandas querying.
-    if args.query:
+    # Exclude strains by metadata.
+    if args.query_pandas:
         kwargs = {
-            "query": args.query,
+            "query": args.query_pandas,
             "chunksize": args.metadata_chunk_size,
         }
         if args.query_columns:
             kwargs["column_types"] = {column: dtype for column, dtype in args.query_columns}
 
         exclude_by.append((
-            filter_by_query,
+            filter_by_pandas_query,
             kwargs
         ))
+    if args.query_sqlite:
+        kwargs = {"query": args.query_sqlite}
+        if args.query_columns:
+            kwargs["column_types"] = {column: dtype for column, dtype in args.query_columns}
+
+        exclude_by.append((
+            filter_by_sqlite_query,
+            kwargs
+        ))
 
@@ -757,20 +823,20 @@ def filter_kwargs_to_str(kwargs: FilterFunctionKwargs):
     return json.dumps(kwarg_list)
 
 
-def extract_variables(pandas_query: str):
+def extract_pandas_query_variables(pandas_query: str):
     """Try extracting all variable names used in a pandas query string.
 
     If successful, return the variable names as a set. Otherwise, nothing is returned.
 
     Examples
     --------
-    >>> extract_variables("var1 == 'value'")
+    >>> extract_pandas_query_variables("var1 == 'value'")
     {'var1'}
-    >>> sorted(extract_variables("var1 == 'value' & var2 == 10"))
+    >>> sorted(extract_pandas_query_variables("var1 == 'value' & var2 == 10"))
     ['var1', 'var2']
-    >>> extract_variables("var1.str.startswith('prefix')")
+    >>> extract_pandas_query_variables("var1.str.startswith('prefix')")
     {'var1'}
-    >>> extract_variables("this query is invalid")
+    >>> extract_pandas_query_variables("this query is invalid")
     """
     # Since Pandas' query grammar should be a subset of Python's, which uses the
     # ast stdlib under the hood, we can try to parse queries with that as well.
@@ -783,3 +849,59 @@
                    if isinstance(node, ast.Name))
     except:
         return None
+
+
+def extract_potential_sqlite_variables(sqlite_expression: str):
+    """Try extracting all variable names used in a SQLite expression.
+
+    If successful, return the variable names as a set. Otherwise, nothing is returned.
+
+    Examples
+    --------
+    >>> extract_potential_sqlite_variables("var1 = 'value'")
+    {'var1'}
+    >>> sorted(extract_potential_sqlite_variables("var1 = 'value' AND var2 = 10"))
+    ['var1', 'var2']
+    >>> extract_potential_sqlite_variables("var1 LIKE 'prefix%'")
+    {'var1'}
+    >>> sorted(extract_potential_sqlite_variables("this query is invalid"))
+    ['invalid', 'this query']
+    """
+    # This seems to be more difficult than Pandas query parsing.
+    # <https://stackoverflow.com/q/35624662>
+    try:
+        query = f"SELECT * FROM table WHERE {sqlite_expression}"
+        where = [x for x in sqlparse.parse(query)[0] if isinstance(x, sqlparse.sql.Where)][0]
+        variables = set(_get_identifiers(where)) or None
+        return variables
+    except:
+        return None
+
+
+def _get_identifiers(token: sqlparse.sql.Token):
+    """Yield identifiers from a token's children.
+
+    Inspired by ast.walk.
+    """
+    from collections import deque
+    todo = deque([token])
+    while todo:
+        node = todo.popleft()
+
+        # Limit to comparisons to avoid false positives.
+        # I chose not to use this because it also comes with false negatives.
+        #
+        # if isinstance(node, sqlparse.sql.Comparison):
+        #     if isinstance(node.left, sqlparse.sql.Identifier):
+        #         yield str(node.left)
+        #     elif hasattr(node.left, 'tokens'):
+        #         todo.extend(node.left.tokens)
+        #     if isinstance(node.right, sqlparse.sql.Identifier):
+        #         yield str(node.right)
+        #     elif hasattr(node.right, 'tokens'):
+        #         todo.extend(node.right.tokens)
+
+        if isinstance(node, sqlparse.sql.Identifier):
+            yield str(node)
+        elif hasattr(node, 'tokens'):
+            todo.extend(node.tokens)
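As a quick sanity check of the two extraction helpers, a self-contained sketch (illustrative; exact sqlparse token grouping can vary by version):

    # Pandas-flavored extraction parses the query as Python and walks the AST,
    # as extract_pandas_query_variables does above.
    import ast
    tree = ast.parse("coverage >= 0.95 & category == 'B'", mode="eval")
    print(sorted(node.id for node in ast.walk(tree) if isinstance(node, ast.Name)))
    # ['category', 'coverage']

    # SQLite-flavored extraction wraps the expression in a SELECT and walks the
    # WHERE clause tokens, mirroring the breadth-first walk in _get_identifiers.
    from collections import deque
    import sqlparse
    statement = sqlparse.parse("SELECT * FROM t WHERE coverage >= 0.95 AND category = 'B'")[0]
    where = next(tok for tok in statement.tokens if isinstance(tok, sqlparse.sql.Where))
    found = set()
    todo = deque([where])
    while todo:
        node = todo.popleft()
        if isinstance(node, sqlparse.sql.Identifier):
            found.add(str(node))
        elif hasattr(node, "tokens"):
            todo.extend(node.tokens)
    print(sorted(found))  # expected: ['category', 'coverage']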
19 changes: 16 additions & 3 deletions augur/filter/io.py
@@ -17,7 +17,7 @@
 )
 from augur.io.file import open_file
 from augur.io.metadata import Metadata, METADATA_DATE_COLUMN
-from augur.filter.include_exclude_rules import extract_variables, parse_filter_query
+from augur.filter.include_exclude_rules import extract_pandas_query_variables, extract_potential_sqlite_variables, parse_filter_query
 from augur.filter.debug import add_debugging
 from augur.io.print import print_err
 from augur.io.sequences import read_sequences, write_sequences
@@ -67,18 +67,31 @@ def get_useful_metadata_columns(args: Namespace, id_column: str, all_columns: Se
             columns.add(column)
 
     # Add columns used in Pandas queries.
-    if args.query:
+    if args.query_pandas:
         if args.query_columns:
             # Use column names explicitly specified by the user.
             for column, dtype in args.query_columns:
                 columns.add(column)
 
         # Attempt to automatically extract columns from the query.
-        variables = extract_variables(args.query)
+        variables = extract_pandas_query_variables(args.query_pandas)
         if variables is None and not args.query_columns:
             raise AugurError("Could not infer columns from the pandas query. If the query is valid, please specify columns using --query-columns.")
         else:
             columns.update(variables)
 
+    if args.query_sqlite:
+        if args.query_columns:
+            # Use column names explicitly specified by the user.
+            for column, dtype in args.query_columns:
+                columns.add(column)
+
+        # Attempt to automatically extract columns from the query.
+        variables = extract_potential_sqlite_variables(args.query_sqlite)
+        if variables is None and not args.query_columns:
+            raise AugurError("Could not infer columns from the SQLite query. If the query is valid, please specify columns using --query-columns.")
+        else:
+            columns.update(variables)
+
     return list(columns)
3 changes: 2 additions & 1 deletion augur/filter/report.py
@@ -29,7 +29,8 @@ def print_report(args: Namespace, exclude_by: List[FilterOption], include_by: Li
         include_exclude_rules.filter_by_exclude_all.__name__: "{count} of these were dropped by `--exclude-all`",
         include_exclude_rules.filter_by_exclude.__name__: "{count} of these were dropped because they were in {exclude_file}",
         include_exclude_rules.filter_by_exclude_where.__name__: "{count} of these were dropped because of '{exclude_where}'",
-        include_exclude_rules.filter_by_query.__name__: "{count} of these were filtered out by the query: \"{query}\"",
+        include_exclude_rules.filter_by_sqlite_query.__name__: "{count} of these were filtered out by the SQLite query: \"{query}\"",
+        include_exclude_rules.filter_by_pandas_query.__name__: "{count} of these were filtered out by the Pandas query: \"{query}\"",
         include_exclude_rules.filter_by_ambiguous_date.__name__: "{count} of these were dropped because of their ambiguous date in {ambiguity}",
         include_exclude_rules.filter_by_min_date.__name__: "{count} of these were dropped because they were earlier than {min_date} or missing a date",
         include_exclude_rules.filter_by_max_date.__name__: "{count} of these were dropped because they were later than {max_date} or missing a date",
3 changes: 3 additions & 0 deletions mypy.ini
@@ -47,3 +47,6 @@ ignore_missing_imports = True
 
 [mypy-scipy.*]
 ignore_missing_imports = True
+
+[mypy-sqlparse.*]
+ignore_missing_imports = True
1 change: 1 addition & 0 deletions setup.py
@@ -65,6 +65,7 @@
         "phylo-treetime >=0.11.2, <0.12",
         "pyfastx >=1.0.0, <3.0",
         "scipy ==1.*",
+        "sqlparse ==0.4.*",
         "xopen[zstd] >=1.7.0, ==1.*"
     ],
     extras_require = {
@@ -20,13 +20,13 @@ The query initially filters 3 strains from Colombia, one of which is added back
   4 strains were dropped during filtering
   \t1 had no metadata (esc)
   \t1 had no sequence data (esc)
-  \t2 of these were filtered out by the query: "country != 'Colombia'" (esc)
+  \t2 of these were filtered out by the Pandas query: "country != 'Colombia'" (esc)
   \\t1 strains were force-included because they were in .*include\.txt.* (re)
   9 strains passed all filters
 
   $ head -n 1 filtered_log.tsv; tail -n +2 filtered_log.tsv | sort -k 1,1
   strain filter kwargs
   COL/FLR_00008/2015\tforce_include_strains\t"[[""include_file"", ""*/data/include.txt""]]" (esc) (glob)
-  COL/FLR_00024/2015 filter_by_query "[[""query"", ""country != 'Colombia'""]]"
-  Colombia/2016/ZC204Se filter_by_query "[[""query"", ""country != 'Colombia'""]]"
+  COL/FLR_00024/2015 filter_by_pandas_query "[[""query"", ""country != 'Colombia'""]]"
+  Colombia/2016/ZC204Se filter_by_pandas_query "[[""query"", ""country != 'Colombia'""]]"
   HND/2016/HU_ME59 filter_by_sequence_index []
@@ -21,6 +21,6 @@ Confirm that `--exclude-ambiguous-dates-by` works for all year only ambiguous da
   > --empty-output-reporting silent \
   > --output-strains filtered_strains.txt
   4 strains were dropped during filtering
-  \t1 of these were filtered out by the query: "region=="Asia"" (esc)
+  \t1 of these were filtered out by the Pandas query: "region=="Asia"" (esc)
   \t3 of these were dropped because of their ambiguous date in any (esc)
   0 strains passed all filters
6 changes: 3 additions & 3 deletions tests/functional/filter/cram/filter-query-columns.t
@@ -19,7 +19,7 @@ Automatic inference works.
   > --query "coverage >= 0.95 & category == 'B'" \
   > --output-strains filtered_strains.txt
   3 strains were dropped during filtering
-  3 of these were filtered out by the query: "coverage >= 0.95 & category == 'B'"
+  3 of these were filtered out by the Pandas query: "coverage >= 0.95 & category == 'B'"
   1 strains passed all filters
 
 Specifying coverage:float explicitly also works.
@@ -30,7 +30,7 @@ Specifying coverage:float explicitly also works.
   > --query "coverage >= 0.95 & category == 'B'" \
   > --query-columns coverage:float \
   > --output-strains filtered_strains.txt
   3 strains were dropped during filtering
-  3 of these were filtered out by the query: "coverage >= 0.95 & category == 'B'"
+  3 of these were filtered out by the Pandas query: "coverage >= 0.95 & category == 'B'"
   1 strains passed all filters
 
 Specifying coverage:float category:str also works.
@@ -41,7 +41,7 @@ Specifying coverage:float category:str also works.
   > --query "coverage >= 0.95 & category == 'B'" \
   > --query-columns coverage:float category:str \
   > --output-strains filtered_strains.txt
   3 strains were dropped during filtering
-  3 of these were filtered out by the query: "coverage >= 0.95 & category == 'B'"
+  3 of these were filtered out by the Pandas query: "coverage >= 0.95 & category == 'B'"
   1 strains passed all filters
 
 Specifying category:float does not work.