Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1616817 Instantiate entity from model instance #1825

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 6 additions & 32 deletions src/snowflake/cli/_plugins/workspace/manager.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
from pathlib import Path
from typing import Dict

from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext
from snowflake.cli.api.cli_global_context import get_cli_context
from snowflake.cli._plugins.workspace.context import ActionContext
from snowflake.cli.api.console import cli_console as cc
from snowflake.cli.api.entities.common import EntityActions, get_sql_executor
from snowflake.cli.api.exceptions import InvalidProjectDefinitionVersionError
from snowflake.cli.api.project.definition import default_role
from snowflake.cli.api.project.schemas.entities.entities import (
Entity,
v2_entity_model_to_entity_map,
from snowflake.cli.api.entities.common import (
EntityActions,
)
from snowflake.cli.api.exceptions import InvalidProjectDefinitionVersionError
from snowflake.cli.api.project.schemas.entities.entities import Entity
from snowflake.cli.api.project.schemas.project_definition import (
DefinitionV20,
ProjectDefinition,
)
from snowflake.cli.api.project.util import to_identifier


class WorkspaceManager:
Expand All @@ -41,15 +37,7 @@ def get_entity(self, entity_id: str):
entity_model = self._project_definition.entities.get(entity_id, None)
if entity_model is None:
raise ValueError(f"No such entity ID: {entity_id}")
entity_model_cls = entity_model.__class__
entity_cls = v2_entity_model_to_entity_map[entity_model_cls]
workspace_ctx = WorkspaceContext(
console=cc,
project_root=self.project_root,
get_default_role=_get_default_role,
get_default_warehouse=_get_default_warehouse,
)
self._entities_cache[entity_id] = entity_cls(entity_model, workspace_ctx)
self._entities_cache[entity_id] = entity_model.get_entity(cc, self.project_root)
return self._entities_cache[entity_id]

def perform_action(self, entity_id: str, action: EntityActions, *args, **kwargs):
Expand All @@ -68,17 +56,3 @@ def perform_action(self, entity_id: str, action: EntityActions, *args, **kwargs)
@property
def project_root(self) -> Path:
return self._project_root


def _get_default_role() -> str:
role = default_role()
if role is None:
role = get_sql_executor().current_role()
return role


def _get_default_warehouse() -> str | None:
warehouse = get_cli_context().connection.warehouse
if warehouse:
warehouse = to_identifier(warehouse)
return warehouse
27 changes: 26 additions & 1 deletion src/snowflake/cli/api/entities/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,32 @@ class EntityActions(str, Enum):
T = TypeVar("T")


class EntityBase(Generic[T]):
class EntityBaseMetaclass(type):
def __new__(mcs, name, bases, attrs): # noqa: N804
cls = super().__new__(mcs, name, bases, attrs)
generic_bases = attrs.get("__orig_bases__", [])
if not generic_bases:
# Subclass is not generic
return cls

target_model_class = get_args(generic_bases[0])[0] # type: ignore[attr-defined]
if target_model_class is T:
# Generic parameter is not filled in
return cls

target_entity_class = getattr(target_model_class, "_entity_class", None)
if target_entity_class is not None:
raise ValueError(
f"Entity model class {target_model_class} is already "
f"associated with entity class {target_entity_class}, "
f"cannot associate with {cls}"
)

setattr(target_model_class, "_entity_class", cls)
return cls


class EntityBase(Generic[T], metaclass=EntityBaseMetaclass):
"""
Base class for the fully-featured entity classes.
"""
Expand Down
41 changes: 41 additions & 0 deletions src/snowflake/cli/api/project/schemas/entities/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
from __future__ import annotations

from abc import ABC
from pathlib import Path
from typing import Dict, Generic, List, Optional, TypeVar, Union

from pydantic import Field, PrivateAttr, field_validator
from snowflake.cli._plugins.workspace.context import WorkspaceContext
from snowflake.cli.api.console.abc import AbstractConsole
from snowflake.cli.api.identifiers import FQN
from snowflake.cli.api.project.schemas.updatable_model import (
IdentifierField,
Expand Down Expand Up @@ -110,6 +113,24 @@ def fqn(self) -> FQN:
if self.entity_id:
return FQN.from_string(self.entity_id)

def get_entity(self, console: AbstractConsole, project_root: Path):
if type(self) is EntityModelBase:
raise NotImplementedError
# Set by EntityBaseMetaclass when creating the
# Entity class that refers to this model
entity_class = getattr(self, "_entity_class", None)
if entity_class is None:
raise ValueError(
f"Entity model class {type(self).__name__} is not associated with an entity class"
)
workspace_ctx = WorkspaceContext(
console=console,
project_root=project_root,
get_default_role=_get_default_role,
get_default_warehouse=_get_default_warehouse,
)
return entity_class(self, workspace_ctx)


TargetType = TypeVar("TargetType")

Expand Down Expand Up @@ -162,3 +183,23 @@ def get_secrets_sql(self) -> str | None:
return None
secrets = ", ".join(f"'{key}'={value}" for key, value in self.secrets.items())
return f"secrets=({secrets})"


def _get_default_role() -> str:
from snowflake.cli.api.entities.common import get_sql_executor
from snowflake.cli.api.project.definition import default_role

role = default_role()
if role is None:
role = get_sql_executor().current_role()
return role


def _get_default_warehouse() -> str | None:
from snowflake.cli.api.cli_global_context import get_cli_context
from snowflake.cli.api.project.util import to_identifier

warehouse = get_cli_context().connection.warehouse
if warehouse:
warehouse = to_identifier(warehouse)
return warehouse
10 changes: 2 additions & 8 deletions tests/nativeapp/test_version_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
AskAlwaysPolicy,
DenyAlwaysPolicy,
)
from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext
from snowflake.cli._plugins.workspace.context import ActionContext
from snowflake.cli.api.console import cli_console as cc
from snowflake.cli.api.project.definition_manager import DefinitionManager
from snowflake.connector.cursor import DictCursor
Expand Down Expand Up @@ -60,13 +60,7 @@ def _version_create(
dm = DefinitionManager()
pd = dm.project_definition
pkg_model: ApplicationPackageEntityModel = pd.entities["app_pkg"]
ctx = WorkspaceContext(
console=cc,
project_root=dm.project_root,
get_default_role=lambda: "mock_role",
get_default_warehouse=lambda: "mock_warehouse",
)
pkg = ApplicationPackageEntity(pkg_model, ctx)
pkg = pkg_model.get_entity(cc, dm.project_root)
return pkg.action_version_create(
action_ctx=mock.Mock(spec=ActionContext),
version=version,
Expand Down
25 changes: 0 additions & 25 deletions tests/project/test_project_definition_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,6 @@
)
from snowflake.cli.api.project.definition_manager import DefinitionManager
from snowflake.cli.api.project.errors import SchemaValidationError
from snowflake.cli.api.project.schemas.entities.entities import (
ALL_ENTITIES,
ALL_ENTITY_MODELS,
v2_entity_model_to_entity_map,
v2_entity_model_types_map,
)
from snowflake.cli.api.project.schemas.project_definition import (
DefinitionV20,
)
Expand Down Expand Up @@ -310,25 +304,6 @@ def test_identifiers():
assert entities["D"].entity_id == "D"


# Verify that each entity model type has the correct "type" field
def test_entity_types():
for entity_type, entity_class in v2_entity_model_types_map.items():
model_entity_type = entity_class.get_type()
assert model_entity_type == entity_type


# Verify that each entity class has a corresponding entity model class, and that all entities are covered
def test_entity_model_to_entity_map():
entities = set(ALL_ENTITIES)
entity_models = set(ALL_ENTITY_MODELS)
assert len(entities) == len(entity_models)
for entity_model_class, entity_class in v2_entity_model_to_entity_map.items():
entities.remove(entity_class)
entity_models.remove(entity_model_class)
assert len(entities) == 0
assert len(entity_models) == 0


@pytest.mark.parametrize(
"project_name",
[
Expand Down
Loading