diff --git a/platformics/api/core/strawberry_extensions.py b/platformics/api/core/strawberry_extensions.py index 78232f7..e64a3d4 100644 --- a/platformics/api/core/strawberry_extensions.py +++ b/platformics/api/core/strawberry_extensions.py @@ -10,6 +10,7 @@ from strawberry.extensions import FieldExtension from strawberry.field import StrawberryField from strawberry.types import Info +from typing import Any, Awaitable, Callable def get_func_with_only_deps(func: typing.Callable[..., typing.Any]) -> typing.Callable[..., typing.Any]: @@ -35,6 +36,52 @@ def get_func_with_only_deps(func: typing.Callable[..., typing.Any]) -> typing.Ca return newfunc +class RegisteredPlatformicsPlugins: + plugins: dict[str, typing.Callable[..., typing.Any]] = {} + + @classmethod + def register(cls, callback_order: str, type: str, action: str, callback: typing.Callable[..., typing.Any]) -> None: + cls.plugins[f"{callback_order}:{type}:{action}"] = callback + + @classmethod + def getCallback(cls, callback_order: str, type: str, action: str) -> typing.Callable[..., typing.Any] | None: + return cls.plugins.get(f"{callback_order}:{type}:{action}") + + +def register_plugin(callback_order: str, type: str, action: str) -> Callable[..., Callable[..., Any]]: + def decorator_register(func: Callable[..., Any]) -> Callable[..., Any]: + RegisteredPlatformicsPlugins.register(callback_order, type, action, func) + return func + + return decorator_register + + +class PlatformicsPluginExtension(FieldExtension): + def __init__(self, type: str, action: str) -> None: + self.type = type + self.action = action + self.strawberry_field_names = ["self"] + + async def resolve_async( + self, + next_: typing.Callable[..., typing.Any], + source: typing.Any, + info: Info, + **kwargs: dict[str, typing.Any], + ) -> typing.Any: + before_callback = RegisteredPlatformicsPlugins.getCallback("before", self.type, self.action) + if before_callback: + before_callback(source, info, **kwargs) + + result = await next_(source, info, **kwargs) + + after_callback = RegisteredPlatformicsPlugins.getCallback("after", self.type, self.action) + if after_callback: + result = after_callback(result, source, info, **kwargs) + + return result + + class DependencyExtension(FieldExtension): def __init__(self) -> None: self.dependency_args: list[typing.Any] = [] diff --git a/platformics/codegen/templates/api/types/class_name.py.j2 b/platformics/codegen/templates/api/types/class_name.py.j2 index 559cd12..d7aaebe 100644 --- a/platformics/codegen/templates/api/types/class_name.py.j2 +++ b/platformics/codegen/templates/api/types/class_name.py.j2 @@ -40,7 +40,7 @@ from fastapi import Depends from platformics.api.core.errors import PlatformicsError from platformics.api.core.deps import get_cerbos_client, get_db_session, require_auth_principal, is_system_user from platformics.api.core.query_input_types import aggregator_map, orderBy, EnumComparators, DatetimeComparators, IntComparators, FloatComparators, StrComparators, UUIDComparators, BoolComparators -from platformics.api.core.strawberry_extensions import DependencyExtension +from platformics.api.core.strawberry_extensions import DependencyExtension, PlatformicsPluginExtension from platformics.security.authorization import CerbosAction, get_resource_query from sqlalchemy import inspect from sqlalchemy.engine.row import RowMapping @@ -500,7 +500,7 @@ async def resolve_{{ cls.plural_snake_name }}_aggregate( return aggregate_output {%- if cls.create_fields %} -@strawberry.mutation(extensions=[DependencyExtension()]) +@strawberry.mutation(extensions=[DependencyExtension(), PlatformicsPluginExtension("{{ cls.snake_name }}", "create")]) async def create_{{ cls.snake_name }}( input: {{ cls.name }}CreateInput, session: AsyncSession = Depends(get_db_session, use_cache=False), @@ -559,7 +559,7 @@ async def create_{{ cls.snake_name }}( {%- if cls.mutable_fields %} -@strawberry.mutation(extensions=[DependencyExtension()]) +@strawberry.mutation(extensions=[DependencyExtension(), PlatformicsPluginExtension("{{ cls.snake_name }}", "update")]) async def update_{{ cls.snake_name }}( input: {{ cls.name }}UpdateInput, where: {{ cls.name }}WhereClauseMutations, @@ -634,7 +634,7 @@ async def update_{{ cls.snake_name }}( {%- endif %} -@strawberry.mutation(extensions=[DependencyExtension()]) +@strawberry.mutation(extensions=[DependencyExtension(), PlatformicsPluginExtension("{{ cls.snake_name }}", "delete")]) async def delete_{{ cls.snake_name }}( where: {{ cls.name }}WhereClauseMutations, session: AsyncSession = Depends(get_db_session, use_cache=False), diff --git a/test_app/conftest.py b/test_app/conftest.py index 2c48927..0c6dfd1 100644 --- a/test_app/conftest.py +++ b/test_app/conftest.py @@ -231,6 +231,7 @@ async def patched_session() -> typing.AsyncGenerator[AsyncSession, None]: def raise_exception() -> str: raise Exception("Unexpected error") + # Subclass Query with an additional field to test Exception handling. @strawberry.type class MyQuery(Query): @@ -239,6 +240,7 @@ def uncaught_exception(self) -> str: # Trigger an AttributeException return self.kaboom # type: ignore + @pytest_asyncio.fixture() async def api_test_schema(async_db: AsyncDB) -> FastAPI: """ diff --git a/test_app/main.py b/test_app/main.py index db4ccd7..9c6fbf6 100644 --- a/test_app/main.py +++ b/test_app/main.py @@ -4,13 +4,24 @@ import strawberry import uvicorn +from api.types.sample import SampleCreateInput from platformics.api.setup import get_app, get_strawberry_config from platformics.api.core.error_handler import HandleErrors from platformics.settings import APISettings from database import models +from platformics.api.core.strawberry_extensions import register_plugin from api.mutations import Mutation from api.queries import Query +from typing import Any +from strawberry.types import Info + + +@register_plugin("before", "sample", "create") +def validate_sample_name(source: Any, info: Info, **kwargs: SampleCreateInput) -> None: + if kwargs["input"].name == "foo": + raise ValueError("Sample name cannot be 'foo'") + settings = APISettings.model_validate({}) # Workaround for https://github.com/pydantic/pydantic/issues/3753 schema = strawberry.Schema(query=Query, mutation=Mutation, config=get_strawberry_config(), extensions=[HandleErrors()])