Skip to content

Commit

Permalink
Merge pull request #83 from mts-ai/sqla-data-layer-refactor
Browse files Browse the repository at this point in the history
Sqla dl refactor
  • Loading branch information
mahenzon authored Apr 12, 2024
2 parents a5dfc0d + edcec1e commit bb270ae
Showing 1 changed file with 91 additions and 42 deletions.
133 changes: 91 additions & 42 deletions fastapi_jsonapi/data_layers/sqla_orm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""This module is a CRUD interface between resource managers and the sqlalchemy ORM"""
import logging
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Tuple, Type, Union

from sqlalchemy import delete, func, select
from sqlalchemy.exc import DBAPIError, IntegrityError, MissingGreenlet, NoResultFound
Expand Down Expand Up @@ -44,6 +44,9 @@

log = logging.getLogger(__name__)

ModelTypeOneOrMany = Union[TypeModel, list[TypeModel]]
ActionTrigger = Literal["create", "update"]


class SqlalchemyDataLayer(BaseDataLayer):
"""Sqlalchemy data layer"""
Expand Down Expand Up @@ -134,12 +137,88 @@ def prepare_id_value(self, col: InstrumentedAttribute, value: Any) -> Any:

return value

async def apply_relationships(self, obj: TypeModel, data_create: BaseJSONAPIItemInSchema) -> None:
async def link_relationship_object(
self,
obj: TypeModel,
relation_name: str,
related_data: Optional[ModelTypeOneOrMany],
action_trigger: ActionTrigger,
):
"""
Links target object with relationship object or objects
:param obj:
:param relation_name:
:param related_data:
:param action_trigger: indicates which one operation triggered relationships applying
"""
# todo: relation name may be different?
setattr(obj, relation_name, related_data)

async def check_object_has_relationship_or_raise(self, obj: TypeModel, relation_name: str):
"""
TODO: move generic code to another method
Checks that there is relationship with relation_name in obj
:param obj:
:param relation_name:
"""
try:
hasattr(obj, relation_name)
except MissingGreenlet:
raise InternalServerError(
detail=(
f"Error of loading the {relation_name!r} relationship. "
f"Please add this relationship to include query parameter explicitly."
),
parameter="include",
)

async def get_related_data_to_link(
self,
related_model: TypeModel,
relationship_info: RelationshipInfo,
relationship_in: Union[
BaseJSONAPIRelationshipDataToOneSchema,
BaseJSONAPIRelationshipDataToManySchema,
],
) -> Optional[ModelTypeOneOrMany]:
"""
Retrieves object or objects to link from database
:param related_model:
:param relationship_info:
:param relationship_in:
"""
if not relationship_in.data:
return [] if relationship_info.many else None

if relationship_info.many:
assert isinstance(relationship_in, BaseJSONAPIRelationshipDataToManySchema)
return await self.get_related_objects_list(
related_model=related_model,
related_id_field=relationship_info.id_field_name,
ids=[r.id for r in relationship_in.data],
)

assert isinstance(relationship_in, BaseJSONAPIRelationshipDataToOneSchema)
return await self.get_related_object(
related_model=related_model,
related_id_field=relationship_info.id_field_name,
id_value=relationship_in.data.id,
)

async def apply_relationships(
self,
obj: TypeModel,
data_create: BaseJSONAPIItemInSchema,
action_trigger: ActionTrigger,
) -> None:
"""
Handles relationships passed in request
:param obj:
:param data_create:
:param action_trigger: indicates which one operation triggered relationships applying
:return:
"""
relationships: "PydanticBaseModel" = data_create.relationships
Expand Down Expand Up @@ -167,45 +246,15 @@ async def apply_relationships(self, obj: TypeModel, data_create: BaseJSONAPIItem
continue

relationship_info: RelationshipInfo = field.field_info.extra["relationship"]

# ...
related_model = get_related_model_cls(type(obj), relation_name)
related_data = await self.get_related_data_to_link(
related_model=related_model,
relationship_info=relationship_info,
relationship_in=relationship_in,
)

if relationship_info.many:
assert isinstance(relationship_in, BaseJSONAPIRelationshipDataToManySchema)

related_data = []
if relationship_in.data:
related_data = await self.get_related_objects_list(
related_model=related_model,
related_id_field=relationship_info.id_field_name,
ids=[r.id for r in relationship_in.data],
)
else:
assert isinstance(relationship_in, BaseJSONAPIRelationshipDataToOneSchema)

if relationship_in.data:
related_data = await self.get_related_object(
related_model=related_model,
related_id_field=relationship_info.id_field_name,
id_value=relationship_in.data.id,
)
else:
setattr(obj, relation_name, None)
continue
try:
hasattr(obj, relation_name)
except MissingGreenlet:
raise InternalServerError(
detail=(
f"Error of loading the {relation_name!r} relationship. "
f"Please add this relationship to include query parameter explicitly."
),
parameter="include",
)

# todo: relation name may be different?
setattr(obj, relation_name, related_data)
await self.check_object_has_relationship_or_raise(obj, relation_name)
await self.link_relationship_object(obj, relation_name, related_data, action_trigger)

async def create_object(self, data_create: BaseJSONAPIItemInSchema, view_kwargs: dict) -> TypeModel:
"""
Expand All @@ -222,7 +271,7 @@ async def create_object(self, data_create: BaseJSONAPIItemInSchema, view_kwargs:
await self.before_create_object(model_kwargs=model_kwargs, view_kwargs=view_kwargs)

obj = self.model(**model_kwargs)
await self.apply_relationships(obj, data_create)
await self.apply_relationships(obj, data_create, action_trigger="create")

self.session.add(obj)
try:
Expand Down Expand Up @@ -348,7 +397,7 @@ async def update_object(
"""
new_data = data_update.attributes.dict(exclude_unset=True)

await self.apply_relationships(obj, data_update)
await self.apply_relationships(obj, data_update, action_trigger="update")

await self.before_update_object(obj, model_kwargs=new_data, view_kwargs=view_kwargs)

Expand Down

0 comments on commit bb270ae

Please sign in to comment.