Skip to content

Commit

Permalink
Type hint schemapi.py
Browse files Browse the repository at this point in the history
  • Loading branch information
binste committed Aug 6, 2023
1 parent 0534853 commit 52c96a4
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 66 deletions.
143 changes: 110 additions & 33 deletions altair/utils/schemapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import contextlib
import inspect
import json
import sys
import textwrap
from typing import (
Any,
Expand All @@ -15,6 +16,11 @@
Tuple,
Iterable,
Type,
Generator,
Union,
overload,
Literal,
TypeVar,
)
from itertools import zip_longest

Expand All @@ -26,6 +32,13 @@

from altair import vegalite

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

_TSchemaBase = TypeVar("_TSchemaBase", bound="SchemaBase")

ValidationErrorList = List[jsonschema.exceptions.ValidationError]
GroupedValidationErrors = Dict[str, ValidationErrorList]

Expand All @@ -35,21 +48,21 @@
# larger specs, but leads to much more useful tracebacks for the user.
# Individual schema classes can override this by setting the
# class-level _class_is_valid_at_instantiation attribute to False
DEBUG_MODE = True
DEBUG_MODE: bool = True


def enable_debug_mode():
def enable_debug_mode() -> None:
global DEBUG_MODE
DEBUG_MODE = True


def disable_debug_mode():
def disable_debug_mode() -> None:
global DEBUG_MODE
DEBUG_MODE = False


@contextlib.contextmanager
def debug_mode(arg):
def debug_mode(arg: bool) -> Generator[None, None, None]:
global DEBUG_MODE
original = DEBUG_MODE
DEBUG_MODE = arg
Expand All @@ -59,12 +72,35 @@ def debug_mode(arg):
DEBUG_MODE = original


@overload
def validate_jsonschema(
spec: Dict[str, Any],
schema: Dict[str, Any],
rootschema: Optional[Dict[str, Any]] = None,
raise_error: bool = True,
rootschema: Optional[Dict[str, Any]] = ...,
*,
raise_error: Literal[True] = ...,
) -> None:
...


@overload
def validate_jsonschema(
spec: Dict[str, Any],
schema: Dict[str, Any],
rootschema: Optional[Dict[str, Any]] = ...,
*,
raise_error: Literal[False],
) -> Optional[jsonschema.exceptions.ValidationError]:
...


def validate_jsonschema(
spec,
schema,
rootschema=None,
*,
raise_error=True,
):
"""Validates the passed in spec against the schema in the context of the
rootschema. If any errors are found, they are deduplicated and prioritized
and only the most relevant errors are kept. Errors are then either raised
Expand All @@ -85,7 +121,7 @@ def validate_jsonschema(
# error message. Setting a new attribute like this is not ideal as
# it then no longer matches the type ValidationError. It would be better
# to refactor this function to never raise but only return errors.
main_error._all_errors = grouped_errors # type: ignore[attr-defined]
main_error._all_errors = grouped_errors
if raise_error:
raise main_error
else:
Expand Down Expand Up @@ -319,7 +355,7 @@ def _deduplicate_by_message(errors: ValidationErrorList) -> ValidationErrorList:
return list({e.message: e for e in errors}.values())


def _subclasses(cls):
def _subclasses(cls: type) -> Generator[type, None, None]:
"""Breadth-first sequence of all classes which inherit from cls."""
seen = set()
current_set = {cls}
Expand All @@ -330,7 +366,7 @@ def _subclasses(cls):
yield cls


def _todict(obj, context):
def _todict(obj: Any, context: Optional[Dict[str, Any]]) -> Any:
"""Convert an object to a dict representation."""
if isinstance(obj, SchemaBase):
return obj.to_dict(validate=False, context=context)
Expand All @@ -348,7 +384,7 @@ def _todict(obj, context):
return obj


def _resolve_references(schema, root=None):
def _resolve_references(schema: dict, root: Optional[dict] = None) -> dict:
"""Resolve schema references."""
resolver = jsonschema.RefResolver.from_schema(root or schema)
while "$ref" in schema:
Expand Down Expand Up @@ -597,9 +633,9 @@ class SchemaBase:

_schema: Optional[Dict[str, Any]] = None
_rootschema: Optional[Dict[str, Any]] = None
_class_is_valid_at_instantiation = True
_class_is_valid_at_instantiation: bool = True

def __init__(self, *args, **kwds):
def __init__(self, *args: Any, **kwds: Any) -> None:
# Two valid options for initialization, which should be handled by
# derived classes:
# - a single arg with no kwds, for, e.g. {'type': 'string'}
Expand All @@ -623,7 +659,9 @@ def __init__(self, *args, **kwds):
if DEBUG_MODE and self._class_is_valid_at_instantiation:
self.to_dict(validate=True)

def copy(self, deep=True, ignore=()):
def copy(
self, deep: Union[bool, Iterable] = True, ignore: Optional[list] = None
) -> Self:
"""Return a copy of the object
Parameters
Expand All @@ -648,7 +686,9 @@ def _shallow_copy(obj):
else:
return obj

def _deep_copy(obj, ignore=()):
def _deep_copy(obj, ignore: Optional[list] = None):
if ignore is None:
ignore = []
if isinstance(obj, SchemaBase):
args = tuple(_deep_copy(arg) for arg in obj._args)
kwds = {
Expand All @@ -668,7 +708,7 @@ def _deep_copy(obj, ignore=()):
return obj

try:
deep = list(deep)
deep = list(deep) # type: ignore[arg-type]
except TypeError:
deep_is_list = False
else:
Expand All @@ -680,6 +720,8 @@ def _deep_copy(obj, ignore=()):
with debug_mode(False):
copy = self.__class__(*self._args, **self._kwds)
if deep_is_list:
# Assert statement is for the benefit of Mypy
assert isinstance(deep, list)
for attr in deep:
copy[attr] = _shallow_copy(copy._get(attr))
return copy
Expand Down Expand Up @@ -873,12 +915,19 @@ def to_json(
return json.dumps(dct, indent=indent, sort_keys=sort_keys, **kwargs)

@classmethod
def _default_wrapper_classes(cls):
def _default_wrapper_classes(cls) -> Generator[Type["SchemaBase"], None, None]:
"""Return the set of classes used within cls.from_dict()"""
return _subclasses(SchemaBase)

@classmethod
def from_dict(cls, dct, validate=True, _wrapper_classes=None):
def from_dict(
cls,
dct: dict,
validate: bool = True,
_wrapper_classes: Optional[Iterable[Type["SchemaBase"]]] = None,
# Type hints for this method would get rather complicated
# if we want to provide a more specific return type
) -> Any:
"""Construct class from a dictionary representation
Parameters
Expand All @@ -887,7 +936,7 @@ def from_dict(cls, dct, validate=True, _wrapper_classes=None):
The dict from which to construct the class
validate : boolean
If True (default), then validate the input against the schema.
_wrapper_classes : list (optional)
_wrapper_classes : iterable (optional)
The set of SchemaBase classes to use when constructing wrappers
of the dict inputs. If not specified, the result of
cls._default_wrapper_classes will be used.
Expand All @@ -910,7 +959,11 @@ def from_dict(cls, dct, validate=True, _wrapper_classes=None):
return converter.from_dict(dct, cls)

@classmethod
def from_json(cls, json_string, validate=True, **kwargs):
def from_json(
cls, json_string: str, validate: bool = True, **kwargs: Any
# Type hints for this method would get rather complicated
# if we want to provide a more specific return type
) -> Any:
"""Instantiate the object from a valid JSON string
Parameters
Expand All @@ -931,27 +984,36 @@ def from_json(cls, json_string, validate=True, **kwargs):
return cls.from_dict(dct, validate=validate)

@classmethod
def validate(cls, instance, schema=None):
def validate(
cls, instance: Dict[str, Any], schema: Optional[Dict[str, Any]] = None
) -> None:
"""
Validate the instance against the class schema in the context of the
rootschema.
"""
if schema is None:
schema = cls._schema
# For the benefit of mypy
assert schema is not None
return validate_jsonschema(
instance, schema, rootschema=cls._rootschema or cls._schema
)

@classmethod
def resolve_references(cls, schema=None):
def resolve_references(cls, schema: Optional[dict] = None) -> dict:
"""Resolve references in the context of this object's schema or root schema."""
schema_to_pass = schema or cls._schema
# For the benefit of mypy
assert schema_to_pass is not None
return _resolve_references(
schema=(schema or cls._schema),
schema=schema_to_pass,
root=(cls._rootschema or cls._schema or schema),
)

@classmethod
def validate_property(cls, name, value, schema=None):
def validate_property(
cls, name: str, value: Any, schema: Optional[dict] = None
) -> None:
"""
Validate a property against property schema in the context of the
rootschema
Expand All @@ -962,8 +1024,8 @@ def validate_property(cls, name, value, schema=None):
value, props.get(name, {}), rootschema=cls._rootschema or cls._schema
)

def __dir__(self):
return sorted(super().__dir__() + list(self._kwds.keys()))
def __dir__(self) -> list:
return sorted(list(super().__dir__()) + list(self._kwds.keys()))


def _passthrough(*args, **kwds):
Expand All @@ -980,7 +1042,7 @@ class _FromDict:

_hash_exclude_keys = ("definitions", "title", "description", "$schema", "id")

def __init__(self, class_list):
def __init__(self, class_list: Iterable[Type[SchemaBase]]) -> None:
# Create a mapping of a schema hash to a list of matching classes
# This lets us quickly determine the correct class to construct
self.class_dict = collections.defaultdict(list)
Expand All @@ -989,7 +1051,7 @@ def __init__(self, class_list):
self.class_dict[self.hash_schema(cls._schema)].append(cls)

@classmethod
def hash_schema(cls, schema, use_json=True):
def hash_schema(cls, schema: dict, use_json: bool = True) -> int:
"""
Compute a python hash for a nested dictionary which
properly handles dicts, lists, sets, and tuples.
Expand Down Expand Up @@ -1025,14 +1087,29 @@ def _freeze(val):
return hash(_freeze(schema))

def from_dict(
self, dct, cls=None, schema=None, rootschema=None, default_class=_passthrough
):
self,
dct: dict,
cls: Optional[Type[SchemaBase]] = None,
schema: Optional[dict] = None,
rootschema: Optional[dict] = None,
default_class=_passthrough,
# Type hints for this method would get rather complicated
# if we want to provide a more specific return type
) -> Any:
"""Construct an object from a dict representation"""
if (schema is None) == (cls is None):
raise ValueError("Must provide either cls or schema, but not both.")
if schema is None:
schema = schema or cls._schema
rootschema = rootschema or cls._rootschema
# Can ignore type errors as cls is not None in case schema is
schema = cls._schema # type: ignore[union-attr]
# For the benefit of mypy
assert schema is not None
if rootschema:
rootschema = rootschema
elif cls is not None and cls._rootschema is not None:
rootschema = cls._rootschema
else:
rootschema = None
rootschema = rootschema or schema

if isinstance(dct, SchemaBase):
Expand Down Expand Up @@ -1086,7 +1163,7 @@ def from_dict(


class _PropertySetter:
def __init__(self, prop, schema):
def __init__(self, prop: str, schema: dict) -> None:
self.prop = prop
self.schema = schema

Expand Down Expand Up @@ -1133,7 +1210,7 @@ def __call__(self, *args, **kwargs):
return obj


def with_property_setters(cls):
def with_property_setters(cls: _TSchemaBase) -> _TSchemaBase:
"""
Decorator to add property setters to a Schema class.
"""
Expand Down
Loading

0 comments on commit 52c96a4

Please sign in to comment.