Skip to content

Commit

Permalink
feat: Support Chart.transform_filter(*predicates, **constraints) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned authored Nov 13, 2024
1 parent b292ccf commit 1d576a8
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 47 deletions.
146 changes: 110 additions & 36 deletions altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@
NamedData,
ParameterName,
PointSelectionConfig,
Predicate,
PredicateComposition,
ProjectionType,
RepeatMapping,
Expand Down Expand Up @@ -542,12 +541,19 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool:
"""


_FieldEqualType: TypeAlias = Union[PrimitiveValue_T, Map, Parameter, SchemaBase]
"""Permitted types for equality checks on field values:
_FieldEqualType: TypeAlias = Union["IntoExpression", Parameter, SchemaBase]
"""
Permitted types for equality checks on field values.
Applies to the following context(s):
import altair as alt
- `datum.field == ...`
- `FieldEqualPredicate(equal=...)`
- `when(**constraints=...)`
alt.datum.field == ...
alt.FieldEqualPredicate(field="field", equal=...)
alt.when(field=...)
alt.when().then().when(field=...)
alt.Chart.transform_filter(field=...)
"""


Expand Down Expand Up @@ -2986,45 +2992,113 @@ def transform_extent(
"""
return self._add_transform(core.ExtentTransform(extent=extent, param=param))

# TODO: Update docstring
# # E.g. {'not': alt.FieldRangePredicate(field='year', range=[1950, 1960])}
def transform_filter(
self,
filter: str
| Expr
| Expression
| Predicate
| Parameter
| PredicateComposition
| dict[str, Predicate | str | list | bool],
**kwargs: Any,
predicate: Optional[_PredicateType] = Undefined,
*more_predicates: _ComposablePredicateType,
empty: Optional[bool] = Undefined,
**constraints: _FieldEqualType,
) -> Self:
"""
Add a :class:`FilterTransform` to the schema.
Add a :class:`FilterTransform` to the spec.
The resulting predicate is an ``&`` reduction over ``predicate`` and optional ``*``, ``**``, arguments.
Parameters
----------
filter : a filter expression or :class:`PredicateComposition`
The `filter` property must be one of the predicate definitions:
(1) a string or alt.expr expression
(2) a range predicate
(3) a selection predicate
(4) a logical operand combining (1)-(3)
(5) a Selection object
predicate
A selection or test predicate. ``str`` input will be treated as a test operand.
*more_predicates
Additional predicates, restricted to types supporting ``&``.
empty
For selection parameters, the predicate of empty selections returns ``True`` by default.
Override this behavior with ``empty=False``.
Returns
-------
self : Chart object
returns chart to allow for chaining
.. note::
When ``predicate`` is a ``Parameter`` that is used more than once,
``self.transform_filter(..., empty=...)`` provides granular control for each occurrence.
**constraints
Specify one or more `Field Equal Predicate`_ constraints.
Shortcut for ``alt.datum.field_name == value``, see examples for usage.
Warns
-----
AltairDeprecationWarning
If called using ``filter`` as a keyword argument.
See Also
--------
alt.when : Uses a similar syntax for defining conditional values.
Notes
-----
- Directly inspired by the syntax used in `polars.DataFrame.filter`_.
.. _Field Equal Predicate:
https://vega.github.io/vega-lite/docs/predicate.html#equal-predicate
.. _polars.DataFrame.filter:
https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.filter.html
Examples
--------
Setting up a common chart::
import altair as alt
from altair import datum
from vega_datasets import data
source = data.population.url
chart = (
alt.Chart(source)
.mark_line()
.encode(
x="age:O",
y="sum(people):Q",
color=alt.Color("year:O").legend(symbolType="square"),
)
)
chart
Singular predicates can be expressed via ``datum``::
chart.transform_filter(datum.year <= 1980)
We can also use selection parameters directly::
selection = alt.selection_point(encodings=["color"], bind="legend")
chart.transform_filter(selection).add_params(selection)
Or a field predicate::
between_1950_60 = alt.FieldRangePredicate(field="year", range=[1950, 1960])
chart.transform_filter(between_1950_60) | chart.transform_filter(~between_1950_60)
Predicates can be composed together using logical operands::
chart.transform_filter(between_1950_60 | (datum.year == 1850))
Predicates passed as positional arguments will be reduced with ``&``::
chart.transform_filter(datum.year > 1980, datum.age != 90)
Using keyword-argument ``constraints`` can simplify compositions like::
verbose_composition = chart.transform_filter((datum.year == 2000) & (datum.sex == 1))
chart.transform_filter(year=2000, sex=1)
"""
if isinstance(filter, Parameter):
new_filter: dict[str, Any] = {"param": filter.name}
if "empty" in kwargs:
new_filter["empty"] = kwargs.pop("empty")
elif isinstance(filter.empty, bool):
new_filter["empty"] = filter.empty
filter = new_filter
return self._add_transform(core.FilterTransform(filter=filter, **kwargs))
if depr_filter := t.cast(Any, constraints.pop("filter", None)):
utils.deprecated_warn(
"Passing `filter` as a keyword is ambiguous.\n\n"
"Use a positional argument for `<5.5.0` behavior.\n"
"Or, `alt.datum['filter'] == ...` if referring to a column named 'filter'.",
version="5.5.0",
)
if utils.is_undefined(predicate):
predicate = depr_filter
else:
more_predicates = *more_predicates, depr_filter
cond = _parse_when(predicate, *more_predicates, empty=empty, **constraints)
return self._add_transform(core.FilterTransform(filter=cond.get("test", cond)))

def transform_flatten(
self,
Expand Down
20 changes: 18 additions & 2 deletions doc/user_guide/transform/filter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ expressions and objects:

We'll show a brief example of each of these in the following sections

.. _filter-expression:

Filter Expression
^^^^^^^^^^^^^^^^^
A filter expression uses the `Vega expression`_ language, either specified
Expand Down Expand Up @@ -189,12 +191,26 @@ Then, we can *invert* this selection using ``~``:
chart.transform_filter(~between_1950_60)

We can further refine our filter by *composing* multiple predicates together.
In this case, using ``alt.datum``:
In this case, using ``datum``:

.. altair-plot::

chart.transform_filter(~between_1950_60 & (datum.age <= 70))

When passing multiple predicates, they will be reduced with ``&``:

.. altair-plot::

chart.transform_filter(~between_1950_60 & (alt.datum.age <= 70))
chart.transform_filter(datum.year > 1980, datum.age != 90)

Using keyword-argument ``constraints`` can simplify our first example in :ref:`filter-expression`:

.. altair-plot::

alt.Chart(source).mark_area().encode(
x="age:O",
y="people:Q",
).transform_filter(year=2000, sex=1)

Transform Options
^^^^^^^^^^^^^^^^^
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
columns_sorted = ['Drought', 'Epidemic', 'Earthquake', 'Flood']

alt.Chart(source).transform_filter(
{'and': [
alt.FieldOneOfPredicate(field='Entity', oneOf=columns_sorted), # Filter data to show only disasters in columns_sorted
alt.FieldRangePredicate(field='Year', range=[1900, 2000]) # Filter data to show only 20th century
]}
alt.FieldOneOfPredicate(field='Entity', oneOf=columns_sorted),
alt.FieldRangePredicate(field='Year', range=[1900, 2000])
).transform_window(
cumulative_deaths='sum(Deaths)', groupby=['Entity'] # Calculate cumulative sum of Deaths by Entity
).mark_line().encode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
columns_sorted = ['Drought', 'Epidemic', 'Earthquake', 'Flood']

alt.Chart(source).transform_filter(
{'and': [
alt.FieldOneOfPredicate(field='Entity', oneOf=columns_sorted), # Filter data to show only disasters in columns_sorted
alt.FieldRangePredicate(field='Year', range=[1900, 2000]) # Filter data to show only 20th century
]}
alt.FieldOneOfPredicate(field='Entity', oneOf=columns_sorted),
alt.FieldRangePredicate(field='Year', range=[1900, 2000])
).transform_window(
cumulative_deaths='sum(Deaths)', groupby=['Entity'] # Calculate cumulative sum of Deaths by Entity
).mark_line().encode(
Expand Down
61 changes: 60 additions & 1 deletion tests/vegalite/v5/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import re
import sys
import tempfile
import warnings
from collections.abc import Mapping
from datetime import date, datetime
from importlib.metadata import version as importlib_version
Expand Down Expand Up @@ -85,7 +86,7 @@ def _make_chart_type(chart_type):


@pytest.fixture
def basic_chart():
def basic_chart() -> alt.Chart:
data = pd.DataFrame(
{
"a": ["A", "B", "C", "D", "E", "F", "G", "H", "I"],
Expand Down Expand Up @@ -1247,6 +1248,64 @@ def test_predicate_composition() -> None:
assert actual_multi == expected_multi


def test_filter_transform_predicates(basic_chart) -> None:
    """Positional predicates given to ``transform_filter`` are reduced with ``&``."""
    lower_bound = alt.datum["b"] >= 30
    upper_bound = alt.datum["b"] < 60
    expected = [{"filter": lower_bound & upper_bound}]

    filtered = basic_chart.transform_filter(lower_bound, upper_bound)
    actual = filtered.to_dict()["transform"]

    assert actual == expected


def test_filter_transform_constraints(basic_chart) -> None:
    """Keyword constraints are compared against ``&``-reduced equality predicates."""
    a_equals = alt.datum["a"] == "A"
    b_equals = alt.datum["b"] == 30
    expected = [{"filter": a_equals & b_equals}]

    filtered = basic_chart.transform_filter(a="A", b=30)
    actual = filtered.to_dict()["transform"]

    assert actual == expected


def test_filter_transform_predicates_constraints(basic_chart) -> None:
    """
    Positional predicates and keyword constraints combine into a single filter.

    The expected value is the ``&`` reduction of every positional predicate
    followed by one equality predicate per keyword constraint, in the order
    the constraints are given.
    """
    from functools import reduce
    from operator import and_

    predicates = (
        alt.datum["a"] != "A",
        alt.datum["a"] != "B",
        alt.datum["a"] != "C",
        alt.datum["b"] > 1,
        alt.datum["b"] < 99,
    )
    constraints = {"b": 30, "a": "D"}
    # Each ``field=value`` constraint is shorthand for ``datum.field == value``,
    # so both keyword constraints must appear as ``==`` in the expectation
    # (the previous ``!= "D"`` contradicted the ``a="D"`` constraint above).
    pred_constraints = *predicates, alt.datum["b"] == 30, alt.datum["a"] == "D"
    expected = [{"filter": reduce(and_, pred_constraints)}]
    actual = basic_chart.transform_filter(*predicates, **constraints).to_dict()[
        "transform"
    ]
    assert actual == expected


def test_filter_transform_errors(basic_chart) -> None:
    """
    Exercise the error and deprecation paths of ``transform_filter``.

    - Calling without any predicate (with or without ``empty``) raises ``TypeError``.
    - Passing ``filter`` as a keyword emits ``AltairDeprecationWarning``, and the
      resulting transform is compared against the plain filter dict.
    """
    missing_args_pattern = r"At least one.+Undefined"
    ambiguous_pattern = r"ambiguous"

    legacy_filter = {"field": "year", "oneOf": [1955, 2000]}
    expected = [{"filter": legacy_filter}]

    # ``empty`` alone is not a predicate, so all three calls must fail alike.
    for call_kwds in ({}, {"empty": True}, {"empty": False}):
        with pytest.raises(TypeError, match=missing_args_pattern):
            basic_chart.transform_filter(**call_kwds)

    with pytest.warns(alt.AltairDeprecationWarning, match=ambiguous_pattern):
        basic_chart.transform_filter(filter=legacy_filter)

    # Silence the deprecation warning to inspect the produced transform.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=alt.AltairDeprecationWarning)
        spec = basic_chart.transform_filter(filter=legacy_filter).to_dict()
        actual = spec["transform"]

    assert actual == expected


def test_resolve_methods():
chart = alt.LayerChart().resolve_axis(x="shared", y="independent")
assert chart.resolve == alt.Resolve(
Expand Down

0 comments on commit 1d576a8

Please sign in to comment.