Skip to content

Commit

Permalink
Merge branch 'main' into refac-alt-theme
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned authored Nov 5, 2024
2 parents fd4a139 + 085a6b6 commit cf4a043
Show file tree
Hide file tree
Showing 20 changed files with 212 additions and 163 deletions.
35 changes: 17 additions & 18 deletions altair/_magics.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
"""Magic functions for rendering vega-lite specifications."""

__all__ = ["vegalite"]
from __future__ import annotations

import json
import warnings
from importlib.util import find_spec
from typing import Any

import IPython
from IPython.core import magic_arguments
from narwhals.dependencies import is_pandas_dataframe as _is_pandas_dataframe

from altair.vegalite import v5 as vegalite_v5

try:
import yaml

YAML_AVAILABLE = True
except ImportError:
YAML_AVAILABLE = False

__all__ = ["vegalite"]

RENDERERS = {
"vega-lite": {
Expand Down Expand Up @@ -48,19 +43,21 @@ def _prepare_data(data, data_transformers):
return data


def _get_variable(name):
def _get_variable(name: str) -> Any:
"""Get a variable from the notebook namespace."""
ip = IPython.get_ipython()
if ip is None:
from IPython.core.getipython import get_ipython

if ip := get_ipython():
if name not in ip.user_ns:
msg = f"argument '{name}' does not match the name of any defined variable"
raise NameError(msg)
return ip.user_ns[name]
else:
msg = (
"Magic command must be run within an IPython "
"environment, in which get_ipython() is defined."
)
raise ValueError(msg)
if name not in ip.user_ns:
msg = f"argument '{name}' does not match the name of any defined variable"
raise NameError(msg)
return ip.user_ns[name]


@magic_arguments.magic_arguments()
Expand All @@ -71,7 +68,7 @@ def _get_variable(name):
)
@magic_arguments.argument("-v", "--version", dest="version", default="v5")
@magic_arguments.argument("-j", "--json", dest="json", action="store_true")
def vegalite(line, cell):
def vegalite(line, cell) -> vegalite_v5.VegaLite:
"""
Cell magic for displaying vega-lite visualizations in CoLab.
Expand All @@ -91,7 +88,7 @@ def vegalite(line, cell):

if args.json:
spec = json.loads(cell)
elif not YAML_AVAILABLE:
elif not find_spec("yaml"):
try:
spec = json.loads(cell)
except json.JSONDecodeError as err:
Expand All @@ -101,6 +98,8 @@ def vegalite(line, cell):
)
raise ValueError(msg) from err
else:
import yaml

spec = yaml.load(cell, Loader=yaml.SafeLoader)

if args.data is not None:
Expand Down
6 changes: 6 additions & 0 deletions altair/utils/_dfi_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def dtype(self) -> tuple[Any, int, str, str]:
- Data types not included: complex, Arrow-style null, binary, decimal,
and nested (list, struct, map, union) dtypes.
"""
...

# Have to use a generic Any return type as not all libraries who implement
# the dataframe interchange protocol implement the TypedDict that is usually
Expand Down Expand Up @@ -106,6 +107,7 @@ def describe_categorical(self) -> Any:
TBD: are there any other in-memory representations that are needed?
"""
...


class DataFrame(Protocol):
Expand Down Expand Up @@ -137,12 +139,15 @@ def __dataframe__(
necessary if a library supports strided buffers, given that this protocol
specifies contiguous buffers.
"""
...

def column_names(self) -> Iterable[str]:
"""Return an iterator yielding the column names."""
...

def get_column_by_name(self, name: str) -> Column:
"""Return the column whose name is the indicated name."""
...

def get_chunks(self, n_chunks: int | None = None) -> Iterable[DataFrame]:
"""
Expand All @@ -156,3 +161,4 @@ def get_chunks(self, n_chunks: int | None = None) -> Iterable[DataFrame]:
Note that the producer must ensure that all columns are chunked the
same way.
"""
...
10 changes: 5 additions & 5 deletions altair/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def __dataframe__(


def infer_vegalite_type_for_pandas(
data: object,
data: Any,
) -> InferredVegaLiteType | tuple[InferredVegaLiteType, list[Any]]:
"""
From an array-like input, infer the correct vega typecode.
Expand All @@ -231,7 +231,7 @@ def infer_vegalite_type_for_pandas(
Parameters
----------
data: object
data: Any
"""
# This is safe to import here, as this function is only called on pandas input.
from pandas.api.types import infer_dtype
Expand Down Expand Up @@ -738,10 +738,10 @@ def use_signature(tp: Callable[P, Any], /):
"""

@overload
def decorate(cb: WrapsMethod[T, R], /) -> WrappedMethod[T, P, R]: ...
def decorate(cb: WrapsMethod[T, R], /) -> WrappedMethod[T, P, R]: ... # pyright: ignore[reportOverlappingOverload]

@overload
def decorate(cb: WrapsFunc[R], /) -> WrappedFunc[P, R]: ...
def decorate(cb: WrapsFunc[R], /) -> WrappedFunc[P, R]: ... # pyright: ignore[reportOverlappingOverload]

def decorate(cb: WrapsFunc[R], /) -> WrappedMethod[T, P, R] | WrappedFunc[P, R]:
"""
Expand Down Expand Up @@ -857,7 +857,7 @@ def from_cache(cls) -> _ChannelCache:
cached = _CHANNEL_CACHE
except NameError:
cached = cls.__new__(cls)
cached.channel_to_name = _init_channel_to_name()
cached.channel_to_name = _init_channel_to_name() # pyright: ignore[reportAttributeAccessIssue]
cached.name_to_channel = _invert_group_channels(cached.channel_to_name)
_CHANNEL_CACHE = cached
return _CHANNEL_CACHE
Expand Down
8 changes: 3 additions & 5 deletions altair/vegalite/v5/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,6 @@ def browser_renderer(
vegalite_version=VEGALITE_VERSION,
**metadata,
)

if isinstance(mimebundle, tuple):
mimebundle = mimebundle[0]

html = mimebundle["text/html"]
open_html_in_browser(html, using=using, port=port)
return {}
Expand Down Expand Up @@ -162,7 +158,9 @@ def browser_renderer(
renderers.register("json", json_renderer)
renderers.register("png", png_renderer)
renderers.register("svg", svg_renderer)
renderers.register("jupyter", jupyter_renderer)
# FIXME: Caused by upstream # type: ignore[unreachable]
# https://github.com/manzt/anywidget/blob/b7961305a7304f4d3def1fafef0df65db56cf41e/anywidget/widget.py#L80-L81
renderers.register("jupyter", jupyter_renderer) # pyright: ignore[reportArgumentType]
renderers.register("browser", browser_renderer)
renderers.register("olli", olli_renderer)
renderers.enable("default")
Expand Down
18 changes: 6 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -457,18 +457,12 @@ include=[
"./doc/*.py",
"./tests/**/*.py",
"./tools/**/*.py",
"./sphinxext/**/*.py",
]
ignore=[
"./altair/vegalite/v5/display.py",
"./altair/vegalite/v5/schema/",
"./altair/utils/core.py",
"./altair/utils/_dfi_types.py",
"./altair/_magics.py",
"./altair/jupyter/",
"./sphinxext/",
"./tests/test_jupyter_chart.py",
"./tests/utils/",
"./tests/test_magics.py",
"./tests/vegalite/v5/test_geo_interface.py",
"../../../**/Lib", # stdlib
"./altair/vegalite/v5/schema/channels.py", # 716 warns
"./altair/vegalite/v5/schema/mixins.py", # 1001 warns
"./altair/jupyter/", # Mostly untyped
"./tests/test_jupyter_chart.py", # Based on untyped module
"../../../**/Lib", # stdlib
]
6 changes: 3 additions & 3 deletions sphinxext/altairgallery.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from docutils import nodes
from docutils.parsers.rst import Directive
from docutils.parsers.rst.directives import flag
from docutils.statemachine import ViewList
from docutils.statemachine import StringList
from sphinx.util.nodes import nested_parse_with_titles

from altair.utils.execeval import eval_block
Expand Down Expand Up @@ -184,7 +184,7 @@ def save_example_pngs(
else:
# the file changed or the image file does not exist. Generate it.
print(f"-> saving {image_file!s}")
chart = eval_block(code)
chart = eval_block(code, strict=True)
try:
chart.save(image_file)
hashes[filename] = example_hash
Expand Down Expand Up @@ -303,7 +303,7 @@ def run(self) -> list[Node]:
)

# parse and return documentation
result = ViewList()
result = StringList()
for line in include.split("\n"):
result.append(line, "<altair-minigallery>")
node = nodes.paragraph()
Expand Down
22 changes: 12 additions & 10 deletions sphinxext/code_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def gen() -> Iterator[nodes.Node]:


def theme_names() -> tuple[Sequence[str], Sequence[str]]:
names: set[VegaThemes] = set(get_args(VegaThemes))
names: set[str] = set(get_args(VegaThemes))
carbon = {nm for nm in names if nm.startswith("carbon")}
return ["default", *sorted(names - carbon)], sorted(carbon)

Expand Down Expand Up @@ -180,8 +180,8 @@ class ThemeDirective(SphinxDirective):
https://pyscript.net/
"""

has_content: ClassVar[Literal[False]] = False
required_arguments: ClassVar[Literal[1]] = 1
has_content: ClassVar[bool] = False
required_arguments: ClassVar[int] = 1
option_spec = {
"packages": validate_packages,
"dropdown-label": directives.unchanged,
Expand Down Expand Up @@ -226,14 +226,16 @@ def run(self) -> Sequence[nodes.Node]:
)
results.append(raw_html("</div></p>\n"))
return maybe_details(
results, self.options, default_summary="Show Vega-Altair Theme Test"
results,
self.options, # pyright: ignore[reportArgumentType]
default_summary="Show Vega-Altair Theme Test",
)


class PyScriptDirective(SphinxDirective):
"""Placeholder for non-theme related directive."""

has_content: ClassVar[Literal[False]] = False
has_content: ClassVar[bool] = False
option_spec = {"packages": directives.unchanged}

def run(self) -> Sequence[nodes.Node]:
Expand Down Expand Up @@ -282,9 +284,9 @@ class CodeRefDirective(SphinxDirective):
https://github.com/vega/sphinxext-altair
"""

has_content: ClassVar[Literal[False]] = False
required_arguments: ClassVar[Literal[1]] = 1
option_spec: ClassVar[dict[_Option, Callable[[str], Any]]] = {
has_content: ClassVar[bool] = False
required_arguments: ClassVar[int] = 1
option_spec: ClassVar[dict[_Option, Callable[[str], Any]]] = { # pyright: ignore[reportIncompatibleVariableOverride]
"output": validate_output,
"fold": directives.flag,
"summary": directives.unchanged_required,
Expand All @@ -302,8 +304,8 @@ def __init__(
state: RSTState,
state_machine: RSTStateMachine,
) -> None:
super().__init__(name, arguments, options, content, lineno, content_offset, block_text, state, state_machine) # fmt: skip
self.options: dict[_Option, Any]
super().__init__(name, arguments, options, content, lineno, content_offset, block_text, state, state_machine) # fmt: skip # pyright: ignore[reportArgumentType]
self.options: dict[_Option, Any] # pyright: ignore[reportIncompatibleVariableOverride]

def run(self) -> Sequence[nodes.Node]:
qual_name = self.arguments[0]
Expand Down
15 changes: 7 additions & 8 deletions sphinxext/schematable.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from docutils import frontend, nodes, utils
from docutils.parsers.rst import Directive
from docutils.parsers.rst.directives import flag
from myst_parser.docutils_ import Parser
from myst_parser.parsers.docutils_ import Parser
from sphinx import addnodes

from tools.schemapi.utils import SchemaInfo, fix_docstring_issues
Expand Down Expand Up @@ -95,7 +95,7 @@ def add_text(node: nodes.paragraph, text: str) -> nodes.paragraph:
for part in reCode.split(text):
if part:
if is_text:
node += nodes.Text(part, part)
node += nodes.Text(part, part) # pyright: ignore[reportCallIssue]
else:
node += nodes.literal(part, part)

Expand All @@ -108,7 +108,7 @@ def build_row(
item: tuple[str, dict[str, Any]], rootschema: dict[str, Any] | None
) -> nodes.row:
"""Return nodes.row with property description."""
prop, propschema, _ = item
prop, propschema = item
row = nodes.row()

# Property
Expand Down Expand Up @@ -165,17 +165,16 @@ def build_schema_table(

def select_items_from_schema(
schema: dict[str, Any], props: list[str] | None = None
) -> Iterator[tuple[Any, Any, bool] | tuple[str, Any, bool]]:
"""Return iterator (prop, schema.item, required) on prop, return all in None."""
) -> Iterator[tuple[Any, Any] | tuple[str, Any]]:
"""Return iterator (prop, schema.item) on prop, return all in None."""
properties = schema.get("properties", {})
required = schema.get("required", [])
if not props:
for prop, item in properties.items():
yield prop, item, prop in required
yield prop, item
else:
for prop in props:
try:
yield prop, properties[prop], prop in required
yield prop, properties[prop]
except KeyError as err:
msg = f"Can't find property: {prop}"
raise Exception(msg) from err
Expand Down
12 changes: 6 additions & 6 deletions sphinxext/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def create_generic_image(
arr = np.zeros((shape[0], shape[1], 3))
if gradient:
# gradient from gray to white
arr += np.linspace(128, 255, shape[1])[:, None]
arr += np.linspace(128, 255, shape[1])[:, None] # pyright: ignore[reportCallIssue,reportArgumentType]
im = Image.fromarray(arr.astype("uint8"))
im.save(filename)

Expand Down Expand Up @@ -138,12 +138,12 @@ def get_docstring_and_rest(filename: str) -> tuple[str, str | None, str, int]:
try:
# In python 3.7 module knows its docstring.
# Everything else will raise an attribute error
docstring = node.docstring
docstring = node.docstring # pyright: ignore[reportAttributeAccessIssue]

import tokenize
from io import BytesIO

ts = tokenize.tokenize(BytesIO(content).readline)
ts = tokenize.tokenize(BytesIO(content).readline) # pyright: ignore[reportArgumentType]
ds_lines = 0
# find the first string according to the tokenizer and get
# it's end row
Expand All @@ -163,7 +163,7 @@ def get_docstring_and_rest(filename: str) -> tuple[str, str | None, str, int]:
and isinstance(node.body[0].value, (ast.Str, ast.Constant))
):
docstring_node = node.body[0]
docstring = docstring_node.value.s
docstring = docstring_node.value.s # pyright: ignore[reportAttributeAccessIssue]
# python2.7: Code was read in bytes needs decoding to utf-8
# unless future unicode_literals is imported in source which
# make ast output unicode strings
Expand Down Expand Up @@ -203,8 +203,8 @@ def dict_hash(dct: dict[Any, Any]) -> Any:
serialized = json.dumps(dct, sort_keys=True)

try:
m = hashlib.sha256(serialized)[:32]
m = hashlib.sha256(serialized)[:32] # pyright: ignore[reportArgumentType,reportIndexIssue]
except TypeError:
m = hashlib.sha256(serialized.encode())[:32]
m = hashlib.sha256(serialized.encode())[:32] # pyright: ignore[reportIndexIssue]

return m.hexdigest()
Loading

0 comments on commit cf4a043

Please sign in to comment.