Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

uow: add "on_exception" lifecycle hook #176

Merged
merged 1 commit into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions invenio_db/uow.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def on_post_commit(self, uow):
"""Called right after the commit phase."""
pass

def on_exception(self, uow, exception):
"""Called in case of an exception."""
pass

def on_rollback(self, uow):
"""Called in the rollback phase (after the transaction rollback)."""
pass
Expand Down Expand Up @@ -165,10 +169,10 @@ def __enter__(self):
"""Entering the context."""
return self

def __exit__(self, exc_type, *args):
def __exit__(self, exc_type, exc_value, traceback):
"""Rollback on exception."""
if exc_type is not None:
self.rollback()
self.rollback(exception=exc_value)
self._mark_dirty()

@property
Expand All @@ -193,9 +197,18 @@ def commit(self):
op.on_post_commit(self)
self._mark_dirty()

def rollback(self):
def rollback(self, exception=None):
"""Rollback the database session."""
self.session.rollback()

# Run exception operations
if exception:
for op in self._operations:
op.on_exception(self, exception)

# Commit exception operations
self.session.commit()

# Run rollback operations
for op in self._operations:
op.on_rollback(self)
Expand Down
130 changes: 130 additions & 0 deletions tests/test_uow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Unit of work tests."""

from unittest.mock import MagicMock

from invenio_db import InvenioDB
from invenio_db.uow import ModelCommitOp, Operation, UnitOfWork


def test_uow_lifecycle(db, app):
InvenioDB(app, entry_point_group=False, db=db)

with app.app_context():
mock_op = MagicMock()

# Test normal lifecycle
with UnitOfWork(db.session) as uow:
uow.register(mock_op)
uow.commit()

mock_op.on_register.assert_called_once()
mock_op.on_commit.assert_called_once()
mock_op.on_post_commit.assert_called_once()

mock_op.on_exception.assert_not_called()
mock_op.on_rollback.assert_not_called()
mock_op.on_post_rollback.assert_not_called()

# Test rollback lifecycle
mock_op.reset_mock()
with UnitOfWork(db.session) as uow:
uow.register(mock_op)
uow.rollback()

mock_op.on_register.assert_called_once()
mock_op.on_commit.assert_not_called()
mock_op.on_post_commit.assert_not_called()

# on_exception is not called (since there was no exception)
mock_op.on_exception.assert_not_called()

# rest of the rollback lifecycle is called
mock_op.on_rollback.assert_called_once()
mock_op.on_post_rollback.assert_called_once()

# Test exception lifecycle
mock_op.reset_mock()
try:
with UnitOfWork(db.session) as uow:
uow.register(mock_op)
raise Exception()
except Exception:
pass

mock_op.on_register.assert_called_once()
mock_op.on_commit.assert_not_called()
mock_op.on_post_commit.assert_not_called()

# both exception and rollback lifecycle are called
mock_op.on_exception.assert_called_once()
mock_op.on_rollback.assert_called_once()
mock_op.on_post_rollback.assert_called_once()


def test_uow_transactions(db, app):
"""Test transaction behavior with the Unit of Work."""

class Data(db.Model):
value = db.Column(db.String(100), primary_key=True)

InvenioDB(app, entry_point_group=False, db=db)

rollback_side_effect = MagicMock()
post_rollback_side_effect = MagicMock()

class CleanUpOp(Operation):
def on_exception(self, uow, exception):
uow.session.add(Data(value="clean-up"))

on_rollback = rollback_side_effect
on_post_rollback = post_rollback_side_effect

with app.app_context():
db.create_all()

# Test normal lifecycle
with UnitOfWork(db.session) as uow:
uow.register(ModelCommitOp(Data(value="persisted")))
uow.commit()

data = db.session.query(Data).all()
assert len(data) == 1
assert data[0].value == "persisted"

# Test rollback lifecycle
with UnitOfWork(db.session) as uow:
uow.register(ModelCommitOp(Data(value="not-persisted")))
uow.register(CleanUpOp())
uow.rollback()

data = db.session.query(Data).all()
assert len(data) == 1
assert data[0].value == "persisted"

rollback_side_effect.assert_called_once()
post_rollback_side_effect.assert_called_once()

# Test exception lifecycle
rollback_side_effect.reset_mock()
post_rollback_side_effect.reset_mock()
try:
with UnitOfWork(db.session) as uow:
uow.register(ModelCommitOp(Data(value="not-persisted")))
uow.register(CleanUpOp())
raise Exception()
except Exception:
pass

data = db.session.query(Data).all()
assert len(data) == 2
assert set([d.value for d in data]) == {"persisted", "clean-up"}

rollback_side_effect.assert_called_once()
post_rollback_side_effect.assert_called_once()