Skip to content

Commit

Permalink
refactor: Factor out SelectionPredicateComposition
Browse files Browse the repository at this point in the history
Utilizes vega#3668
  • Loading branch information
dangotbanned committed Nov 4, 2024
1 parent a8e1ea1 commit f890feb
Showing 1 changed file with 21 additions and 30 deletions.
51 changes: 21 additions & 30 deletions altair/vegalite/v5/_api_rfc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from altair.utils.core import TYPECODE_MAP as _TYPE_CODE
from altair.utils.core import parse_shorthand as _parse
from altair.utils.schemapi import Optional, SchemaBase, Undefined
from altair.vegalite.v5.api import Parameter, SelectionPredicateComposition
from altair.vegalite.v5.api import Parameter
from altair.vegalite.v5.schema import channels
from altair.vegalite.v5.schema._typing import (
BinnedTimeUnit_T,
Expand Down Expand Up @@ -93,10 +93,6 @@ def _parse_aggregate(
raise TypeError(msg)


def _wrap_composition(predicate: Predicate, /) -> SelectionPredicateComposition:
return SelectionPredicateComposition(predicate.to_dict())


def _one_of_flatten(
values: tuple[OneOfType, ...] | tuple[Sequence[OneOfType]] | tuple[Any, ...], /
) -> Sequence[OneOfType]:
Expand Down Expand Up @@ -305,7 +301,10 @@ class field:
{'field': 'Origin', 'type': 'nominal'}
>>> field.one_of("Origin", "Japan", "Europe")
SelectionPredicateComposition({'field': 'Origin', 'oneOf': ['Japan', 'Europe']})
FieldOneOfPredicate({
field: 'Origin',
oneOf: ('Japan', 'Europe')
})
"""

def __new__( # type: ignore[misc]
Expand All @@ -320,60 +319,52 @@ def one_of(
/,
*values: OneOfType | Sequence[OneOfType],
timeUnit: TimeUnitType = Undefined,
) -> SelectionPredicateComposition:
) -> Predicate:
seq = _one_of_flatten(values)
one_of = _one_of_variance(*seq)
p = FieldOneOfPredicate(field=field, oneOf=one_of, timeUnit=timeUnit)
return _wrap_composition(p)
return FieldOneOfPredicate(field=field, oneOf=one_of, timeUnit=timeUnit)

@classmethod
def eq(
cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined
) -> SelectionPredicateComposition:
p = FieldEqualPredicate(field=field, equal=value, timeUnit=timeUnit)
return _wrap_composition(p)
) -> Predicate:
return FieldEqualPredicate(field=field, equal=value, timeUnit=timeUnit)

@classmethod
def lt(
cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined
) -> SelectionPredicateComposition:
p = FieldLTPredicate(field=field, lt=value, timeUnit=timeUnit)
return _wrap_composition(p)
) -> Predicate:
return FieldLTPredicate(field=field, lt=value, timeUnit=timeUnit)

@classmethod
def lte(
cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined
) -> SelectionPredicateComposition:
p = FieldLTEPredicate(field=field, lte=value, timeUnit=timeUnit)
return _wrap_composition(p)
) -> Predicate:
return FieldLTEPredicate(field=field, lte=value, timeUnit=timeUnit)

@classmethod
def gt(
cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined
) -> SelectionPredicateComposition:
p = FieldGTPredicate(field=field, gt=value, timeUnit=timeUnit)
return _wrap_composition(p)
) -> Predicate:
return FieldGTPredicate(field=field, gt=value, timeUnit=timeUnit)

@classmethod
def gte(
cls, field: str, value: ValueType, /, *, timeUnit: TimeUnitType = Undefined
) -> SelectionPredicateComposition:
p = FieldGTEPredicate(field=field, gte=value, timeUnit=timeUnit)
return _wrap_composition(p)
) -> Predicate:
return FieldGTEPredicate(field=field, gte=value, timeUnit=timeUnit)

@classmethod
def valid(
cls, field: str, value: bool, /, *, timeUnit: TimeUnitType = Undefined
) -> SelectionPredicateComposition:
p = FieldValidPredicate(field=field, valid=value, timeUnit=timeUnit)
return _wrap_composition(p)
) -> Predicate:
return FieldValidPredicate(field=field, valid=value, timeUnit=timeUnit)

@classmethod
def range(
cls, field: str, value: RangeType, /, *, timeUnit: TimeUnitType = Undefined
) -> SelectionPredicateComposition:
p = FieldRangePredicate(field=field, range=value, timeUnit=timeUnit)
return _wrap_composition(p)
) -> Predicate:
return FieldRangePredicate(field=field, range=value, timeUnit=timeUnit)


# NOTE: Ignore everything below #
Expand Down

0 comments on commit f890feb

Please sign in to comment.