From 25d10595d2e2d280d9651e5e66ed3e5f74635cb6 Mon Sep 17 00:00:00 2001 From: Jochem van Dooren Date: Wed, 13 Mar 2024 14:33:27 +0100 Subject: [PATCH] Fix decorator and change manifest loading --- src/dbt_score/manifest.py | 125 ------------------- src/dbt_score/models.py | 176 +++++++++++++++++++++++++++ src/dbt_score/rule.py | 60 +++++++-- src/dbt_score/rules/__init__.py | 0 src/dbt_score/rules/example_rules.py | 79 ++++++------ src/dbt_score/utils.py | 20 --- 6 files changed, 271 insertions(+), 189 deletions(-) delete mode 100644 src/dbt_score/manifest.py create mode 100644 src/dbt_score/models.py delete mode 100644 src/dbt_score/rules/__init__.py delete mode 100644 src/dbt_score/utils.py diff --git a/src/dbt_score/manifest.py b/src/dbt_score/manifest.py deleted file mode 100644 index 4494eb2..0000000 --- a/src/dbt_score/manifest.py +++ /dev/null @@ -1,125 +0,0 @@ -from dataclasses import dataclass, field -from typing import Any, List - - -@dataclass -class Constraint: - """Constraint for a column in a model.""" - - type: str - expression: str - name: str - - -@dataclass -class Test: - """Test for a column or model.""" - - name: str - type: str - tags: list[str] = field(default_factory=list) - - -@dataclass -class Column: - """Represents a column in a model.""" - - name: str - description: str - constraints: List[Constraint] - tests: List[Test] = field(default_factory=list) - - -@dataclass -class Model: - """Represents a dbt model.""" - - id: str - name: str - description: str - file_path: str - config: dict[str, Any] - meta: dict[str, Any] - columns: dict[str, Column] - tests: list[Test] = field(default_factory=list) - - @classmethod - def from_node(cls, node_values: dict[str, Any]) -> "Model": - """Create a model object from a node in the manifest.""" - columns = { - name: Column( - name=values.get("name"), - description=values.get("description"), - constraints=[ - Constraint( - name=constraint.get("name"), - type=constraint.get("type"), - expression=constraint.get("expression"), - ) - for constraint in values.get("constraints", []) - ], - ) - for name, values in node_values.get("columns", {}).items() - } - - model = cls( - id=node_values["unique_id"], - file_path=node_values["patch_path"], - config=node_values.get("config", {}), - name=node_values["name"], - description=node_values.get("description", ""), - meta=node_values.get("meta", {}), - columns=columns, - ) - - return model - - -class ManifestLoader: - """Load the models and tests from the manifest.""" - - def __init__(self, raw_manifest: dict[str, Any]): - self.raw_manifest = raw_manifest - self.raw_nodes = raw_manifest.get("nodes", {}) - self.models: dict[str, Model] = {} - self.tests: dict[str, Test] = {} - - # Load models first so the tests can be attached to them later. - self.load_models() - self.load_tests() - - def load_models(self) -> None: - """Load the models from the manifest.""" - for node_values in self.raw_nodes.values(): - if node_values.get("resource_type") == "model": - model = Model.from_node(node_values) - self.models[model.id] = model - - def load_tests(self) -> None: - """Load the tests from the manifest and attach them to the right object.""" - for node_values in self.raw_nodes.values(): - # Only include tests that are attached to a model. - if node_values.get("resource_type") == "test" and node_values.get( - "attached_node" - ): - model = self.models.get(node_values.get("attached_node")) - - if not model: - raise ValueError( - f"Model {node_values.get('attached_node')}" - f"not found, while tests are attached to it." - ) - - test = Test( - name=node_values.get("name"), - type=node_values.get("test_metadata").get("name"), - tags=node_values.get("tags"), - ) - column_name = ( - node_values.get("test_metadata").get("kwargs").get("column_name") - ) - - if column_name: # Test is a column-level test. - model.columns[column_name].tests.append(test) - else: - model.tests.append(test) diff --git a/src/dbt_score/models.py b/src/dbt_score/models.py new file mode 100644 index 0000000..293cc9c --- /dev/null +++ b/src/dbt_score/models.py @@ -0,0 +1,176 @@ +"""Objects related to loading the dbt manifest.""" +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class Constraint: + """Constraint for a column. + + Args: + name: The name of the constraint. + type: The type of the constraint, e.g. `foreign_key`. + expression: The expression of the constraint, e.g. `schema.other_table`. + """ + + name: str + type: str + expression: str + + +@dataclass +class Test: + """Test for a column or model. + + Args: + name: The name of the test. + type: The type of the test, e.g. `unique`. + tags: The list of tags attached to the test. + """ + + name: str + type: str + tags: list[str] = field(default_factory=list) + + +@dataclass +class Column: + """Represents a column in a model. + + Args: + name: The name of the column. + description: The description of the column. + constraints: The list of constraints attached to the column. + tests: The list of tests attached to the column. + """ + + name: str + description: str + constraints: list[Constraint] = field(default_factory=list) + tests: list[Test] = field(default_factory=list) + + +@dataclass +class Model: + """Represents a dbt model. + + Args: + id: The id of the model, e.g. `model.package.model_name`. + name: The name of the model. + description: The full description of the model. + file_path: The `.yml` file path of the model. + config: The config of the model. + meta: The meta of the model. + columns: The list of columns of the model. + tests: The list of tests attached to the model. + """ + + id: str + name: str + description: str + file_path: str + config: dict[str, Any] + meta: dict[str, Any] + columns: list[Column] + tests: list[Test] = field(default_factory=list) + + def get_column(self, column_name: str) -> Column | None: + """Get a column by name.""" + for column in self.columns: + if column.name == column_name: + return column + + return None + + @staticmethod + def _get_columns( + node_values: dict[str, Any], tests_values: list[dict[str, Any]] + ) -> list[Column]: + """Get columns from a node and it's tests in the manifest.""" + columns = [ + Column( + name=values.get("name"), + description=values.get("description"), + constraints=[ + Constraint( + name=constraint.get("name"), + type=constraint.get("type"), + expression=constraint.get("expression"), + ) + for constraint in values.get("constraints", []) + ], + tests=[ + Test( + name=test["name"], + type=test["test_metadata"]["name"], + tags=test.get("tags", []), + ) + for test in tests_values + if test["test_metadata"].get("kwargs", {}).get("column_name") + == values.get("name") + ], + ) + for name, values in node_values.get("columns", {}).items() + ] + return columns + + @classmethod + def from_node( + cls, node_values: dict[str, Any], tests_values: list[dict[str, Any]] + ) -> "Model": + """Create a model object from a node and it's tests in the manifest.""" + model = cls( + id=node_values["unique_id"], + file_path=node_values["patch_path"], + config=node_values.get("config", {}), + name=node_values["name"], + description=node_values.get("description", ""), + meta=node_values.get("meta", {}), + columns=cls._get_columns(node_values, tests_values), + tests=[ + Test( + name=test["name"], + type=test["test_metadata"]["name"], + tags=test.get("tags", []), + ) + for test in tests_values + if not test["test_metadata"].get("kwargs", {}).get("column_name") + ], + ) + + return model + + +class ManifestLoader: + """Load the models and tests from the manifest.""" + + def __init__(self, raw_manifest: dict[str, Any]): + """Initialize the ManifestLoader. + + Args: + raw_manifest: The dictionary representation of the JSON manifest. + """ + self.raw_manifest = raw_manifest + self.raw_nodes = raw_manifest.get("nodes", {}) + self.models: list[Model] = [] + self.tests: dict[str, list[dict[str, Any]]] = defaultdict(list) + + self._reindex_tests() + self._load_models() + + def _load_models(self) -> None: + """Load the models from the manifest.""" + for node_id, node_values in self.raw_nodes.items(): + if node_values.get("resource_type") == "model": + model = Model.from_node(node_values, self.tests.get(node_id, [])) + self.models.append(model) + + def _reindex_tests(self) -> None: + """Index tests based on their model id.""" + for node_values in self.raw_nodes.values(): + # Only include tests that are attached to a model. + if node_values.get("resource_type") == "test" and node_values.get( + "attached_node" + ): + self.tests[node_values["attached_node"]].append(node_values) diff --git a/src/dbt_score/rule.py b/src/dbt_score/rule.py index 5609c8d..50d506b 100644 --- a/src/dbt_score/rule.py +++ b/src/dbt_score/rule.py @@ -1,10 +1,13 @@ +"""Rule definitions.""" + + import functools import logging from dataclasses import dataclass from enum import Enum -from typing import Any, Callable +from typing import Any, Callable, Type -from dbt_score.manifest import Model +from dbt_score.models import Model logging.basicConfig() logger = logging.getLogger(__name__) @@ -27,21 +30,62 @@ class RuleViolation: message: str | None = None +class Rule: + """The rule base class.""" + + description: str + severity: Severity = Severity.MEDIUM + + def __init_subclass__(cls, **kwargs) -> None: # type: ignore + """Initializes the subclass.""" + super().__init_subclass__(**kwargs) + if not hasattr(cls, "description"): + raise TypeError("Subclass must define class attribute `description`.") + + @classmethod + def evaluate(cls, model: Model) -> RuleViolation | None: + """Evaluates the rule.""" + raise NotImplementedError("Subclass must implement class method `evaluate`.") + + def rule( - description: str, - hint: str, + description: str | None = None, severity: Severity = Severity.MEDIUM, -) -> Callable[[Callable[[Model], RuleViolation | None]], Callable[..., None]]: - """Rule decorator.""" +) -> Callable[[Callable[[Model], RuleViolation | None]], Type[Rule]]: + """Rule decorator. + + The rule decorator creates a rule class (subclass of Rule) and returns it. + + Args: + description: The description of the rule. + severity: The severity of the rule. + """ def decorator_rule( func: Callable[[Model], RuleViolation | None], - ) -> Callable[..., None]: + ) -> Type[Rule]: @functools.wraps(func) def wrapper_rule(*args: Any, **kwargs: Any) -> Any: logger.debug("Executing `%s` with severity: %s.", func.__name__, severity) return func(*args, **kwargs) - return wrapper_rule + # Create the rule class + if func.__doc__ is None and description is None: + raise TypeError("Rule must define `description` or `func.__doc__`.") + + rule_description = description or ( + func.__doc__.split("\n")[0] if func.__doc__ else None + ) + rule_class = type( + func.__name__, + (Rule,), + { + "description": rule_description, + "severity": severity, + "evaluate": wrapper_rule, + }, + ) + + return rule_class return decorator_rule diff --git a/src/dbt_score/rules/__init__.py b/src/dbt_score/rules/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/dbt_score/rules/example_rules.py b/src/dbt_score/rules/example_rules.py index 6404d6b..8bc59c6 100644 --- a/src/dbt_score/rules/example_rules.py +++ b/src/dbt_score/rules/example_rules.py @@ -1,28 +1,46 @@ """All general rules.""" -from ..manifest import Model -from ..rule import RuleViolation, Severity, rule +from dbt_score.models import Model +from dbt_score.rule import Rule, RuleViolation, Severity, rule -@rule( - description="A model should have an owner defined.", - hint="Define the owner of the model in the meta section.", - severity=Severity.HIGH, -) +class ComplexRule(Rule): + """Complex rule.""" + + description = "Example of a complex rule." + severity = Severity.CRITICAL + + @classmethod + def preprocess(cls) -> int: + """Preprocessing.""" + return 1 + + @classmethod + def evaluate(cls, model: Model) -> RuleViolation | None: + """Evaluate model.""" + x = cls.preprocess() + + if x: + return RuleViolation(str(x)) + else: + return None + + +@rule() def has_owner(model: Model) -> RuleViolation | None: """A model should have an owner defined.""" if "owner" not in model.meta: - return RuleViolation() + return RuleViolation("Define the owner of the model in the meta section.") return None -@rule(description="A model should have a primary key defined.", hint="Some hint.") +@rule() def has_primary_key(model: Model) -> RuleViolation | None: """A model should have a primary key defined, unless it's a view.""" if not model.config.get("materialized") == "picnic_view": has_pk = False - for column in model.columns.values(): + for column in model.columns: if "primary_key" in [constraint.type for constraint in column.constraints]: has_pk = True break @@ -33,23 +51,16 @@ def has_primary_key(model: Model) -> RuleViolation | None: return None -@rule( - description="Primary key columns should have a uniqueness test defined.", - hint="Some hint.", -) +@rule() def primary_key_has_uniqueness_test(model: Model) -> RuleViolation | None: """Primary key columns should have a uniqueness test defined.""" columns_with_pk = [] - if not model.config.get("materialized") == "picnic_view": - for column_name, column in model.columns.items(): + if model.config.get("materialized") == "view": + for column in model.columns: if "primary_key" in [constraint.type for constraint in column.constraints]: - columns_with_pk.append(column_name) + columns_with_pk.append(column) - tests = ( - model.columns[columns_with_pk[0]].tests - if len(columns_with_pk) == 1 - else model.tests - ) + tests = columns_with_pk[0].tests if len(columns_with_pk) == 1 else model.tests if columns_with_pk and "unique" not in [test.type for test in tests]: return RuleViolation() @@ -57,37 +68,33 @@ def primary_key_has_uniqueness_test(model: Model) -> RuleViolation | None: return None -@rule( - description="All columns of a model should have a description.", hint="Some hint." -) +@rule() def columns_have_description(model: Model) -> RuleViolation | None: """All columns of a model should have a description.""" - invalid_columns = [ - column_name - for column_name, column in model.columns.items() - if not column.description + invalid_column_names = [ + column.name for column in model.columns if not column.description ] - if invalid_columns: + if invalid_column_names: return RuleViolation( message=f"The following columns lack a description: " - f"{', '.join(invalid_columns)}." + f"{', '.join(invalid_column_names)}." ) return None -@rule(description="A model should have at least one test defined.", hint="Some hint.") +@rule(description="A model should have at least one test defined.") def has_test(model: Model) -> RuleViolation | None: - """A model should have at least one model-level and one column-level test. + """A model should have at least one model-level or column-level test defined. This does not include singular tests, which are tests defined in a separate .sql file and not linked to the model in the metadata. """ column_tests = [] - for column in model.columns.values(): + for column in model.columns: column_tests.extend(column.tests) - if len(model.tests) == 0 or len(column_tests) == 0: - return RuleViolation() + if len(model.tests) == 0 and len(column_tests) == 0: + return RuleViolation("Define a test for the model on model- or column-level.") return None diff --git a/src/dbt_score/utils.py b/src/dbt_score/utils.py deleted file mode 100644 index b5ee0c3..0000000 --- a/src/dbt_score/utils.py +++ /dev/null @@ -1,20 +0,0 @@ -"""Utility functions.""" - -import json -from pathlib import Path -from typing import Any - - -class JsonOpenError(RuntimeError): - """Raised when there is an error opening a JSON file.""" - - pass - - -def get_json(json_filename: str) -> Any: - """Get JSON from a file.""" - try: - file_content = Path(json_filename).read_text(encoding="utf-8") - return json.loads(file_content) - except Exception as e: - raise JsonOpenError(f"Error opening {json_filename}.") from e