diff --git a/src/data_common/charting/chart.py b/src/data_common/charting/chart.py index d6f7d5a..425409d 100644 --- a/src/data_common/charting/chart.py +++ b/src/data_common/charting/chart.py @@ -1,7 +1,6 @@ from __future__ import annotations from functools import wraps -from pathlib import Path from typing import TYPE_CHECKING, Any, List, Optional, TypedDict, Union import altair as alt @@ -101,18 +100,28 @@ def display_options( self._display_options["caption"] = caption return self - def save(self, dest: Path): - pil_image = self.get_pil_image() - pil_image.save(dest) + def display(self, *args, **kwargs): + """ + Display the chart + """ + + custom = {k: getattr(self, k) for k in self.__class__.ignore_properties} + kwargs["custom"] = custom + super().display(*args, **kwargs) + + def save(self, *args, **kwargs): + """ + Save the chart + """ + custom = {k: getattr(self, k) for k in self.__class__.ignore_properties} + kwargs["custom"] = custom + super().save(*args, **kwargs) def to_dict(self, *args, ignore: Optional[List] = None, **kwargs) -> dict: if ignore is None: ignore = [] ignore += self.__class__.ignore_properties - value = super().to_dict(*args, ignore=ignore, **kwargs) - for k in ignore: - value["custom"] = {k: getattr(self, k)} - return value + return super().to_dict(*args, ignore=ignore, **kwargs) # Layering and stacking def __add__(self, other): @@ -130,7 +139,6 @@ def __or__(self, other): raise ValueError("Only Chart objects can be concatenated.") return hconcat(self, other) - @wraps(alt.Chart.properties) def raw_properties(self, *args, **kwargs): return super().properties(*args, **kwargs) diff --git a/src/data_common/charting/renderer.py b/src/data_common/charting/renderer.py index 821bf9c..bac85ad 100644 --- a/src/data_common/charting/renderer.py +++ b/src/data_common/charting/renderer.py @@ -58,7 +58,7 @@ def pil_image_to_mimebundle(img: Image.Image) -> MimeBundle: def render(spec: dict, embed_options: dict[str, Any]) -> MimeBundle: - display = spec.get("custom", {}).get( + display = embed_options.get("custom", {}).get( "_display_options", {"scale_factor": 1, "logo": "", "caption": ""} ) scale_factor = display["scale_factor"]