Skip to content

Commit

Permalink
feat(DRAFT): Adds infer-based altair.datasets.load
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned committed Nov 17, 2024
1 parent dc4a230 commit 7ddb2a8
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 13 deletions.
35 changes: 23 additions & 12 deletions altair/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from altair.datasets._readers import _Backend
from altair.datasets._typing import Dataset, Extension, Version

__all__ = ["Loader", "data"]
__all__ = ["Loader", "load"]


class Loader(Generic[IntoDataFrameT, IntoFrameT]):
Expand Down Expand Up @@ -320,18 +320,29 @@ def __repr__(self) -> str:
return f"{type(self).__name__}[{self._reader._name}]"


load: Loader[Any, Any]


def __getattr__(name):
    # Module-level __getattr__ (PEP 562): construct the inferred-backend
    # `load` lazily on first access, then cache it as a module global so
    # subsequent lookups skip this hook entirely.
    if name == "load":
        import warnings

        from altair.datasets._readers import infer_backend

        reader = infer_backend()
        global load
        # Bypass any Loader.__init__ side effects; attach the reader directly.
        load = Loader.__new__(Loader)
        load._reader = reader

        # Steer users toward the explicit constructor, which gives the type
        # checker a concrete backend and therefore full IDE completions.
        warnings.warn(
            "For full IDE completions, instead use:\n\n"
            " from altair.datasets import Loader\n"
            " load = Loader.with_backend(...)\n\n"
            "Related: https://github.com/vega/altair/pull/3631#issuecomment-2480832609",
            UserWarning,
            stacklevel=3,
        )
        return load
    else:
        msg = f"module {__name__!r} has no attribute {name!r}"
        raise AttributeError(msg)
32 changes: 31 additions & 1 deletion altair/datasets/_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import os
import urllib.request
from collections.abc import Mapping, Sequence
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from importlib import import_module
from importlib.util import find_spec
Expand Down Expand Up @@ -475,6 +475,36 @@ def is_ext_scan(suffix: Any) -> TypeIs[_ExtensionScan]:
return suffix == ".parquet"


def is_available(pkg_names: str | Iterable[str], *more_pkg_names: str) -> bool:
    """Return ``True`` when every named package is importable.

    Accepts either a single package name or an iterable of names as the
    first argument; further names may follow as positional arguments.
    """
    if isinstance(pkg_names, str):
        first: Iterable[str] = (pkg_names,)
    else:
        first = pkg_names
    # find_spec(...) is None exactly when the package cannot be imported.
    return all(find_spec(name) is not None for name in chain(first, more_pkg_names))


def infer_backend(
    *, priority: Sequence[_Backend] = ("polars", "pandas[pyarrow]", "pandas", "pyarrow")
) -> _Reader[Any, Any]:
    """
    Return the first available reader in order of `priority`.

    Notes
    -----
    - ``"polars"``: can natively load every dataset (including ``(Geo|Topo)JSON``)
    - ``"pandas[pyarrow]"``: can load *most* datasets, guarantees ``.parquet`` support
    - ``"pandas"``: supports ``.parquet``, if `fastparquet`_ is installed
    - ``"pyarrow"``: least reliable

    Raises
    ------
    NotImplementedError
        When none of the backends in ``priority`` is importable.

    .. _fastparquet:
        https://github.com/dask/fastparquet
    """
    # Lazy generator: only the first backend whose requirements are
    # installed is ever constructed.
    it = (backend(name) for name in priority if is_available(_requirements(name)))
    if reader := next(it, None):
        return reader
    msg = f"Found no supported backend, searched:\n{priority!r}"
    raise NotImplementedError(msg)


@overload
def backend(name: _PolarsAny, /) -> _Reader[pl.DataFrame, pl.LazyFrame]: ...

Expand Down
54 changes: 54 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import datetime as dt
import re
import sys
import warnings
from functools import partial
from importlib import import_module
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any, cast, get_args
from urllib.error import URLError
Expand Down Expand Up @@ -127,6 +129,58 @@ def test_loader_url(backend: _Backend) -> None:
assert pattern.match(url) is not None


def test_load(monkeypatch: pytest.MonkeyPatch) -> None:
    """
    Inferring the best backend available.

    Based on the following order:

        priority: Sequence[_Backend] = "polars", "pandas[pyarrow]", "pandas", "pyarrow"
    """
    import altair.datasets

    # First access of `load` runs the module-level __getattr__, which emits
    # a UserWarning steering users toward Loader.with_backend(...).
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        from altair.datasets import load

    # With every backend installed, "polars" wins (highest priority).
    assert load._reader._name == "polars"
    # Drop the cached module attribute so the next import re-runs inference.
    monkeypatch.delattr(altair.datasets, "load")

    # A None entry in sys.modules makes importlib.util.find_spec(...) return
    # None, so the backend appears uninstalled to is_available().
    monkeypatch.setitem(sys.modules, "polars", None)

    from altair.datasets import load

    if find_spec("pyarrow") is None:
        # NOTE: We can end the test early for the CI job that removes `pyarrow`
        assert load._reader._name == "pandas"
        monkeypatch.delattr(altair.datasets, "load")
        monkeypatch.setitem(sys.modules, "pandas", None)
        # No importable backend remains -> inference must fail loudly.
        with pytest.raises(NotImplementedError, match="no.+backend"):
            from altair.datasets import load
    else:
        # polars blocked, pyarrow still importable -> "pandas[pyarrow]".
        assert load._reader._name == "pandas[pyarrow]"
        monkeypatch.delattr(altair.datasets, "load")

        monkeypatch.setitem(sys.modules, "pyarrow", None)

        from altair.datasets import load

        # polars and pyarrow both blocked -> plain "pandas".
        assert load._reader._name == "pandas"
        monkeypatch.delattr(altair.datasets, "load")

        # Block pandas, then restore the *real* pyarrow module (the previous
        # None placeholder must be removed first) so only pyarrow remains.
        monkeypatch.setitem(sys.modules, "pandas", None)
        monkeypatch.delitem(sys.modules, "pyarrow")
        monkeypatch.setitem(sys.modules, "pyarrow", import_module("pyarrow"))
        from altair.datasets import load

        assert load._reader._name == "pyarrow"
        monkeypatch.delattr(altair.datasets, "load")
        monkeypatch.setitem(sys.modules, "pyarrow", None)

        # Everything blocked -> NotImplementedError from infer_backend().
        with pytest.raises(NotImplementedError, match="no.+backend"):
            from altair.datasets import load


@backends
def test_loader_call(backend: _Backend, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delenv(CACHE_ENV_VAR, raising=False)
Expand Down

0 comments on commit 7ddb2a8

Please sign in to comment.