Skip to content

Commit

Permalink
Merge pull request #28 from mohamadkhalaj/main
Browse files Browse the repository at this point in the history
Fix issue #26
  • Loading branch information
mohamadkhalaj authored Nov 1, 2023
2 parents bc2e29d + 5a47462 commit 164e6e1
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 3 deletions.
48 changes: 46 additions & 2 deletions aggify/aggify.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,48 @@ def __init__(self, base_model: Type[Document]):

@last_out_stage_check
def project(self, **kwargs: QueryParams) -> "Aggify":
"""
Adjusts the base model's fields based on the given keyword arguments.
Fields to be retained are set to 1 in kwargs.
Fields to be deleted are set to 0 in kwargs, except for _id which is controlled by the delete_id flag.
Args:
**kwargs: Fields to be retained or removed.
For example: {"field1": 1, "field2": 0}
_id field behavior: {"_id": 0} means delete _id.
Returns:
Aggify: Returns an instance of the Aggify class for potential method chaining.
"""

# Extract fields to keep and check if _id should be deleted
to_keep_values = ["id"]
delete_id = kwargs.get("_id") == 0

# Add missing fields to the base model
for key, value in kwargs.items():
if value == 1:
to_keep_values.append(key)
elif key not in self.base_model._fields and isinstance( # noqa
kwargs[key], str
): # noqa
to_keep_values.append(key)
self.base_model._fields[key] = fields.IntField() # noqa

# Remove fields from the base model, except the ones in to_keep_values and possibly _id
keys_for_deletion = set(self.base_model._fields.keys()) - set( # noqa
to_keep_values
) # noqa
if delete_id:
keys_for_deletion.add("id")
for key in keys_for_deletion:
del self.base_model._fields[key] # noqa

# Append the projection stage to the pipelines
self.pipelines.append({"$project": kwargs})

# Return the instance for method chaining
return self

@last_out_stage_check
Expand All @@ -87,7 +128,7 @@ def raw(self, raw_query: dict) -> "Aggify":
return self

@last_out_stage_check
def add_fields(self, **fields) -> "Aggify": # noqa
def add_fields(self, **_fields) -> "Aggify": # noqa
"""
Generates a MongoDB addFields pipeline stage.
Expand All @@ -99,7 +140,8 @@ def add_fields(self, **fields) -> "Aggify": # noqa
"""
add_fields_stage = {"$addFields": {}}

for field, expression in fields.items():
for field, expression in _fields.items():
field = field.replace("__", ".")
if isinstance(expression, str):
add_fields_stage["$addFields"][field] = {"$literal": expression}
elif isinstance(expression, F):
Expand All @@ -108,6 +150,8 @@ def add_fields(self, **fields) -> "Aggify": # noqa
add_fields_stage["$addFields"][field] = dict(expression)
else:
raise AggifyValueError([str, F], type(expression))
# TODO: Should be checked if new field is embedded, create embedded field.
self.base_model._fields[field.replace("$", "")] = fields.IntField() # noqa

self.pipelines.append(add_fields_stage)
return self
Expand Down
2 changes: 1 addition & 1 deletion aggify/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __invert__(self):
class F:
def __init__(self, field: str | dict[str, list]):
if isinstance(field, str):
self.field = f"${field}"
self.field = f"${field.replace('__', '.')}"
else:
self.field = field

Expand Down

0 comments on commit 164e6e1

Please sign in to comment.