Skip to content

Commit

Permalink
Merge pull request #24 from mohamadkhalaj/main
Browse files Browse the repository at this point in the history
Support `Cond` in `add_field()`, this function also works with kwargs…
  • Loading branch information
seyed-dev authored Nov 1, 2023
2 parents ee282c2 + dff30e2 commit 5d6850b
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 53 deletions.
16 changes: 10 additions & 6 deletions aggify/aggify.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from mongoengine import Document, EmbeddedDocument, fields
from mongoengine.base import TopLevelDocumentMetaclass

from aggify.compiler import F, Match, Q, Operators # noqa keep
from aggify.compiler import F, Match, Q, Operators, Cond # noqa keep
from aggify.exceptions import (
AggifyValueError,
AnnotationError,
Expand Down Expand Up @@ -70,10 +70,12 @@ def group(self, expression: str | None = "_id") -> "Aggify":
return self

@last_out_stage_check
def order_by(self, field: str) -> "Aggify":
self.pipelines.append(
{"$sort": {f'{field.replace("-", "")}': -1 if field.startswith("-") else 1}}
)
def order_by(self, *fields: str | list[str]) -> "Aggify":
sort_dict = {
field.replace("-", ""): -1 if field.startswith("-") else 1
for field in fields
}
self.pipelines.append({"$sort": sort_dict})
return self

@last_out_stage_check
Expand All @@ -82,7 +84,7 @@ def raw(self, raw_query: dict) -> "Aggify":
return self

@last_out_stage_check
def add_fields(self, fields: dict) -> "Aggify": # noqa
def add_fields(self, **fields) -> "Aggify": # noqa
"""
Generates a MongoDB addFields pipeline stage.
Expand All @@ -99,6 +101,8 @@ def add_fields(self, fields: dict) -> "Aggify": # noqa
add_fields_stage["$addFields"][field] = {"$literal": expression}
elif isinstance(expression, F):
add_fields_stage["$addFields"][field] = expression.to_dict()
elif isinstance(expression, Cond):
add_fields_stage["$addFields"][field] = dict(expression)
else:
raise AggifyValueError([str, F], type(expression))

Expand Down
14 changes: 11 additions & 3 deletions aggify/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,10 @@ def is_base_model_field(self, field) -> bool:
and the base_model is not None, otherwise False.
"""
return self.base_model is not None and (
isinstance(self.base_model._fields.get(field), (EmbeddedDocumentField, TopLevelDocumentMetaclass)) # noqa
isinstance(
self.base_model._fields.get(field),
(EmbeddedDocumentField, TopLevelDocumentMetaclass),
) # noqa
)

def compile(self, pipelines: list) -> dict[str, dict[str, list]]:
Expand All @@ -263,7 +266,10 @@ def compile(self, pipelines: list) -> dict[str, dict[str, list]]:
raise InvalidOperator(key)

field, operator, *_ = key.split("__")
if self.is_base_model_field(field) and operator not in Operators.ALL_OPERATORS:
if (
self.is_base_model_field(field)
and operator not in Operators.ALL_OPERATORS
):
pipelines.append(
Match({key.replace("__", ".", 1): value}, self.base_model).compile(
[]
Expand All @@ -274,6 +280,8 @@ def compile(self, pipelines: list) -> dict[str, dict[str, list]]:
if operator not in Operators.ALL_OPERATORS:
raise InvalidOperator(operator)
db_field = get_db_field(self.base_model, field)
match_query = Operators(match_query).compile_match(operator, value, db_field)
match_query = Operators(match_query).compile_match(
operator, value, db_field
)

return {"$match": match_query}
6 changes: 4 additions & 2 deletions aggify/utilty.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,12 @@ def check_field_exists(model: Type[Document], field: str) -> None:
raise AlreadyExistsField(field=field)


def get_db_field(model: Type[Document], field: str) -> str:
def get_db_field(model: Type[Document], field: str, add_dollar_sign=False) -> str:
"""
Get the database field name for a given field in the model.
Args:
add_dollar_sign: Add a "$" at the start of the field or not
model (Document): The model containing the field.
field (str): The name of the field.
Expand All @@ -143,6 +144,7 @@ def get_db_field(model: Type[Document], field: str) -> str:
"""
try:
db_field = model._fields.get(field).db_field # noqa
return field if db_field is None else db_field
db_field = field if db_field is None else db_field
return f"${db_field}" if add_dollar_sign else db_field
except AttributeError:
return field
6 changes: 3 additions & 3 deletions tests/test_aggify.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,12 @@ def test_add_field_value_error(self):
fields = {
"new_field_1": True,
}
aggify.add_fields(fields)
aggify.add_fields(**fields)

def test_add_fields_string_literal(self):
aggify = Aggify(BaseModel)
fields = {"new_field_1": "some_string", "new_field_2": "another_string"}
add_fields_stage = aggify.add_fields(fields)
add_fields_stage = aggify.add_fields(**fields)

expected_stage = {
"$addFields": {
Expand All @@ -146,7 +146,7 @@ def test_add_fields_with_f_expression(self):
"new_field_1": F("existing_field") + 10,
"new_field_2": F("field_a") * F("field_b"),
}
add_fields_stage = aggify.add_fields(fields)
add_fields_stage = aggify.add_fields(**fields)

expected_stage = {
"$addFields": {
Expand Down
123 changes: 84 additions & 39 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,25 @@ class ParameterTestCase:
compiled_query=Aggify(PostDocument).filter(
caption__contains="hello", owner__deleted_at=None
),
expected_query=[{'$match': {'caption': {'$options': 'i', '$regex': '.*hello.*'}}},
{'$lookup': {'as': 'owner',
'foreignField': '_id',
'from': 'account',
'localField': 'owner_id'}},
{'$unwind': {'includeArrayIndex': None,
'path': '$owner',
'preserveNullAndEmptyArrays': True}},
{'$match': {'owner.deleted_at': None}}],
expected_query=[
{"$match": {"caption": {"$options": "i", "$regex": ".*hello.*"}}},
{
"$lookup": {
"as": "owner",
"foreignField": "_id",
"from": "account",
"localField": "owner_id",
}
},
{
"$unwind": {
"includeArrayIndex": None,
"path": "$owner",
"preserveNullAndEmptyArrays": True,
}
},
{"$match": {"owner.deleted_at": None}},
],
),
ParameterTestCase(
compiled_query=Aggify(PostDocument)
Expand Down Expand Up @@ -116,7 +126,7 @@ class ParameterTestCase:
ParameterTestCase(
compiled_query=(
Aggify(PostDocument).add_fields(
{
**{
"new_field_1": "some_string",
"new_field_2": F("existing_field") + 10,
"new_field_3": F("field_a") * F("field_b"),
Expand Down Expand Up @@ -179,14 +189,28 @@ class ParameterTestCase:
)
.filter(_posts1__ne=[])
),
expected_query=[{'$lookup': {'as': '_posts1',
'from': 'account',
'let': {'owner': '$owner_id'},
'pipeline': [{'$match': {'$expr': {'$and': [{'$ne': ['$_id',
'$$owner']},
{'$ne': ['$username',
'seyed']}]}}}]}},
{'$match': {'_posts1': {'$ne': []}}}],
expected_query=[
{
"$lookup": {
"as": "_posts1",
"from": "account",
"let": {"owner": "$owner_id"},
"pipeline": [
{
"$match": {
"$expr": {
"$and": [
{"$ne": ["$_id", "$$owner"]},
{"$ne": ["$username", "seyed"]},
]
}
}
}
],
}
},
{"$match": {"_posts1": {"$ne": []}}},
],
),
ParameterTestCase(
compiled_query=(
Expand All @@ -199,13 +223,20 @@ class ParameterTestCase:
)
.filter(_posts2__ne=[])
),
expected_query=[{'$lookup': {'as': '_posts2',
'from': 'account',
'let': {'caption': '$caption', 'owner': '$owner_id'},
'pipeline': [{'$match': {'$expr': {'$eq': ['$_id', '$$owner']}}},
{'$match': {'$expr': {'$eq': ['$username',
'$$caption']}}}]}},
{'$match': {'_posts2': {'$ne': []}}}],
expected_query=[
{
"$lookup": {
"as": "_posts2",
"from": "account",
"let": {"caption": "$caption", "owner": "$owner_id"},
"pipeline": [
{"$match": {"$expr": {"$eq": ["$_id", "$$owner"]}}},
{"$match": {"$expr": {"$eq": ["$username", "$$caption"]}}},
],
}
},
{"$match": {"_posts2": {"$ne": []}}},
],
),
ParameterTestCase(
compiled_query=(Aggify(PostDocument).replace_root(embedded_field="stat")),
Expand Down Expand Up @@ -255,32 +286,46 @@ class ParameterTestCase:
),
ParameterTestCase(
compiled_query=(
Aggify(PostDocument)
.lookup(
Aggify(PostDocument).lookup(
AccountDocument,
local_field='owner', foreign_field='id',
local_field="owner",
foreign_field="id",
as_name="_owner",
)
),
expected_query=[{'$lookup': {'as': '_owner',
'foreignField': '_id',
'from': 'account',
'localField': 'owner_id'}}],
expected_query=[
{
"$lookup": {
"as": "_owner",
"foreignField": "_id",
"from": "account",
"localField": "owner_id",
}
}
],
),
ParameterTestCase(
compiled_query=(
Aggify(PostDocument)
.lookup(
AccountDocument,
local_field='owner', foreign_field='id',
local_field="owner",
foreign_field="id",
as_name="_owner1",
).filter(_owner1__username='Aggify')
)
.filter(_owner1__username="Aggify")
),
expected_query=[{'$lookup': {'as': '_owner1',
'foreignField': '_id',
'from': 'account',
'localField': 'owner_id'}},
{'$match': {'_owner1.username': 'Aggify'}}],
expected_query=[
{
"$lookup": {
"as": "_owner1",
"foreignField": "_id",
"from": "account",
"localField": "owner_id",
}
},
{"$match": {"_owner1.username": "Aggify"}},
],
),
]

Expand Down

0 comments on commit 5d6850b

Please sign in to comment.