Skip to content

Commit

Permalink
Fixed issue when Databricks SDK config objects were overridden for in…
Browse files Browse the repository at this point in the history
…stallation config files (#170)

closes #169

---------

Co-authored-by: Serge Smertin <[email protected]>
  • Loading branch information
FastLee and nfx authored Nov 14, 2024
1 parent cb1b8f0 commit ab4edc6
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 40 deletions.
82 changes: 43 additions & 39 deletions src/databricks/labs/blueprint/installation.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,66 +469,62 @@ def _get_list_type_ref(inst: T) -> type[list[T]]:
item_type = type(from_list[0]) # type: ignore[misc]
return list[item_type] # type: ignore[valid-type]

@classmethod
def _marshal(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]:
def _marshal(self, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]:
"""The `_marshal` method is a private method that is used to serialize an object of type `type_ref` to
a dictionary. This method is called by the `save` method."""
if inst is None:
return None, False
if isinstance(inst, databricks.sdk.core.Config):
return self._marshal_databricks_config(inst)
if hasattr(inst, "as_dict"):
return inst.as_dict(), True
if dataclasses.is_dataclass(type_ref):
return cls._marshal_dataclass(type_ref, path, inst)
if isinstance(inst, databricks.sdk.core.Config):
return inst.as_dict(), True
return self._marshal_dataclass(type_ref, path, inst)
if type_ref == list:
return cls._marshal_list(type_ref, path, inst)
return self._marshal_list(type_ref, path, inst)
if isinstance(type_ref, enum.EnumMeta):
return cls._marshal_enum(inst)
return self._marshal_enum(inst)
if type_ref == types.NoneType:
return inst, inst is None
if type_ref == databricks.sdk.core.Config:
return cls._marshal_databricks_config(inst)
if type_ref in cls._PRIMITIVES:
if type_ref in self._PRIMITIVES:
return inst, True
return cls._marshal_generic_types(type_ref, path, inst)
return self._marshal_generic_types(type_ref, path, inst)

@classmethod
def _marshal_generic_types(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]:
def _marshal_generic_types(self, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]:
# pylint: disable-next=import-outside-toplevel,import-private-name
from typing import ( # type: ignore[attr-defined]
_GenericAlias,
_UnionGenericAlias,
)

if isinstance(type_ref, (types.UnionType, _UnionGenericAlias)): # type: ignore[attr-defined]
return cls._marshal_union(type_ref, path, inst)
return self._marshal_union(type_ref, path, inst)
if isinstance(type_ref, (_GenericAlias, types.GenericAlias)): # type: ignore[attr-defined]
if type_ref.__origin__ in (dict, list) or isinstance(type_ref, types.GenericAlias):
return cls._marshal_generic(type_ref, path, inst)
return cls._marshal_generic_alias(type_ref, inst)
return self._marshal_generic(type_ref, path, inst)
return self._marshal_generic_alias(type_ref, inst)
raise SerdeError(f'{".".join(path)}: unknown: {inst}')

@classmethod
def _marshal_union(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]:
def _marshal_union(self, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]:
"""The `_marshal_union` method is a private method that is used to serialize an object of type `type_ref` to
a dictionary. This method is called by the `save` method."""
combo = []
for variant in get_args(type_ref):
value, ok = cls._marshal(variant, [*path, f"(as {variant})"], inst)
value, ok = self._marshal(variant, [*path, f"(as {variant})"], inst)
if ok:
return value, True
combo.append(cls._explain_why(variant, [*path, f"(as {variant})"], inst))
combo.append(self._explain_why(variant, [*path, f"(as {variant})"], inst))
raise SerdeError(f'{".".join(path)}: union: {" or ".join(combo)}')

@classmethod
def _marshal_generic(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]:
def _marshal_generic(self, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]:
"""The `_marshal_generic` method is a private method that is used to serialize an object of type `type_ref`
to a dictionary. This method is called by the `save` method."""
type_args = get_args(type_ref)
if not type_args:
raise SerdeError(f"Missing type arguments: {type_args}")
if len(type_args) == 2:
return cls._marshal_dict(type_args[1], path, inst)
return cls._marshal_list(type_args[0], path, inst)
return self._marshal_dict(type_args[1], path, inst)
return self._marshal_list(type_args[0], path, inst)

@staticmethod
def _marshal_generic_alias(type_ref, inst):
Expand All @@ -538,35 +534,32 @@ def _marshal_generic_alias(type_ref, inst):
return None, False
return inst, isinstance(inst, type_ref.__origin__) # type: ignore[attr-defined]

@classmethod
def _marshal_list(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]:
def _marshal_list(self, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]:
"""The `_marshal_list` method is a private method that is used to serialize an object of type `type_ref` to
a dictionary. This method is called by the `save` method."""
as_list = []
if not isinstance(inst, list):
return None, False
for i, v in enumerate(inst):
value, ok = cls._marshal(type_ref, [*path, f"{i}"], v)
value, ok = self._marshal(type_ref, [*path, f"{i}"], v)
if not ok:
raise SerdeError(cls._explain_why(type_ref, [*path, f"{i}"], v))
raise SerdeError(self._explain_why(type_ref, [*path, f"{i}"], v))
as_list.append(value)
return as_list, True

@classmethod
def _marshal_dict(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]:
def _marshal_dict(self, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]:
"""The `_marshal_dict` method is a private method that is used to serialize an object of type `type_ref` to
a dictionary. This method is called by the `save` method."""
if not isinstance(inst, dict):
return None, False
as_dict = {}
for k, v in inst.items():
as_dict[k], ok = cls._marshal(type_ref, [*path, k], v)
as_dict[k], ok = self._marshal(type_ref, [*path, k], v)
if not ok:
raise SerdeError(cls._explain_why(type_ref, [*path, k], v))
raise SerdeError(self._explain_why(type_ref, [*path, k], v))
return as_dict, True

@classmethod
def _marshal_dataclass(cls, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]:
def _marshal_dataclass(self, type_ref: type, path: list[str], inst: Any) -> tuple[Any, bool]:
"""The `_marshal_dataclass` method is a private method that is used to serialize an object of type `type_ref`
to a dictionary. This method is called by the `save` method."""
if inst is None:
Expand All @@ -577,21 +570,29 @@ def _marshal_dataclass(cls, type_ref: type, path: list[str], inst: Any) -> tuple
if origin is typing.ClassVar:
continue
raw = getattr(inst, field)
value, ok = cls._marshal(hint, [*path, field], raw)
if not raw:
continue
value, ok = self._marshal(hint, [*path, field], raw)
if not ok:
raise SerdeError(cls._explain_why(hint, [*path, field], raw))
raise SerdeError(self._explain_why(hint, [*path, field], raw))
if not value:
continue
as_dict[field] = value
return as_dict, True

@staticmethod
def _marshal_databricks_config(inst):
def _marshal_databricks_config(self, inst):
"""The `_marshal_databricks_config` method is a private method that is used to serialize an object of type
`databricks.sdk.core.Config` to a dictionary. This method is called by the `save` method."""
if not inst:
return None, False
return inst.as_dict(), True
current_client_config = self._current_client_config()
remote_file_config = inst.as_dict()
if current_client_config == remote_file_config:
return None, True
return remote_file_config, True

def _current_client_config(self) -> dict:
return self._ws.config.as_dict()

@staticmethod
def _marshal_enum(inst):
Expand Down Expand Up @@ -886,6 +887,9 @@ def files(self) -> list[workspace.ObjectInfo]:
def remove(self):
self._removed = True

def _current_client_config(self) -> dict:
return {}

def _overwrite_content(self, filename: str, as_dict: Json, type_ref: type):
self._overwrites[filename] = as_dict

Expand Down
1 change: 0 additions & 1 deletion src/databricks/labs/blueprint/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

from databricks.labs.blueprint import _posixpath


logger = logging.getLogger(__name__)


Expand Down

0 comments on commit ab4edc6

Please sign in to comment.