From fdfd6180af458942eec1e836a47039e298376ac2 Mon Sep 17 00:00:00 2001 From: Alex Ioannidis Date: Wed, 30 Oct 2024 11:01:19 +0100 Subject: [PATCH] uow: add "on_exception" lifecycle hook - When an exception occurs and the unit of work is being rolled back, we want to add the possibility to perform some clean-up actions that can also be commited to the database. - The new `on_exception` method is added because we want to keep backwards compatibility and also introduce the same "triplet" of methods as we have for the "happy path" (`on_register`, `on_commit`, `on_post_commit`). --- invenio_db/uow.py | 19 +++++-- tests/test_uow.py | 130 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 3 deletions(-) create mode 100644 tests/test_uow.py diff --git a/invenio_db/uow.py b/invenio_db/uow.py index f5830f3..73ec442 100644 --- a/invenio_db/uow.py +++ b/invenio_db/uow.py @@ -108,6 +108,10 @@ def on_post_commit(self, uow): """Called right after the commit phase.""" pass + def on_exception(self, uow, exception): + """Called in case of an exception.""" + pass + def on_rollback(self, uow): """Called in the rollback phase (after the transaction rollback).""" pass @@ -165,10 +169,10 @@ def __enter__(self): """Entering the context.""" return self - def __exit__(self, exc_type, *args): + def __exit__(self, exc_type, exc_value, traceback): """Rollback on exception.""" if exc_type is not None: - self.rollback() + self.rollback(exception=exc_value) self._mark_dirty() @property @@ -193,9 +197,18 @@ def commit(self): op.on_post_commit(self) self._mark_dirty() - def rollback(self): + def rollback(self, exception=None): """Rollback the database session.""" self.session.rollback() + + # Run exception operations + if exception: + for op in self._operations: + op.on_exception(self, exception) + + # Commit exception operations + self.session.commit() + # Run rollback operations for op in self._operations: op.on_rollback(self) diff --git a/tests/test_uow.py b/tests/test_uow.py new file mode 100644 index 0000000..aa94902 --- /dev/null +++ b/tests/test_uow.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2024 CERN. +# +# Invenio is free software; you can redistribute it and/or modify it +# under the terms of the MIT License; see LICENSE file for more details. + +"""Unit of work tests.""" + +from unittest.mock import MagicMock + +from invenio_db import InvenioDB +from invenio_db.uow import ModelCommitOp, Operation, UnitOfWork + + +def test_uow_lifecycle(db, app): + InvenioDB(app, entry_point_group=False, db=db) + + with app.app_context(): + mock_op = MagicMock() + + # Test normal lifecycle + with UnitOfWork(db.session) as uow: + uow.register(mock_op) + uow.commit() + + mock_op.on_register.assert_called_once() + mock_op.on_commit.assert_called_once() + mock_op.on_post_commit.assert_called_once() + + mock_op.on_exception.assert_not_called() + mock_op.on_rollback.assert_not_called() + mock_op.on_post_rollback.assert_not_called() + + # Test rollback lifecycle + mock_op.reset_mock() + with UnitOfWork(db.session) as uow: + uow.register(mock_op) + uow.rollback() + + mock_op.on_register.assert_called_once() + mock_op.on_commit.assert_not_called() + mock_op.on_post_commit.assert_not_called() + + # on_exception is not called (since there was no exception) + mock_op.on_exception.assert_not_called() + + # rest of the rollback lifecycle is called + mock_op.on_rollback.assert_called_once() + mock_op.on_post_rollback.assert_called_once() + + # Test exception lifecycle + mock_op.reset_mock() + try: + with UnitOfWork(db.session) as uow: + uow.register(mock_op) + raise Exception() + except Exception: + pass + + mock_op.on_register.assert_called_once() + mock_op.on_commit.assert_not_called() + mock_op.on_post_commit.assert_not_called() + + # both exception and rollback lifecycle are called + mock_op.on_exception.assert_called_once() + mock_op.on_rollback.assert_called_once() + mock_op.on_post_rollback.assert_called_once() + + +def test_uow_transactions(db, app): + """Test transaction behavior with the Unit of Work.""" + + class Data(db.Model): + value = db.Column(db.String(100), primary_key=True) + + InvenioDB(app, entry_point_group=False, db=db) + + rollback_side_effect = MagicMock() + post_rollback_side_effect = MagicMock() + + class CleanUpOp(Operation): + def on_exception(self, uow, exception): + uow.session.add(Data(value="clean-up")) + + on_rollback = rollback_side_effect + on_post_rollback = post_rollback_side_effect + + with app.app_context(): + db.create_all() + + # Test normal lifecycle + with UnitOfWork(db.session) as uow: + uow.register(ModelCommitOp(Data(value="persisted"))) + uow.commit() + + data = db.session.query(Data).all() + assert len(data) == 1 + assert data[0].value == "persisted" + + # Test rollback lifecycle + with UnitOfWork(db.session) as uow: + uow.register(ModelCommitOp(Data(value="not-persisted"))) + uow.register(CleanUpOp()) + uow.rollback() + + data = db.session.query(Data).all() + assert len(data) == 1 + assert data[0].value == "persisted" + + rollback_side_effect.assert_called_once() + post_rollback_side_effect.assert_called_once() + + # Test exception lifecycle + rollback_side_effect.reset_mock() + post_rollback_side_effect.reset_mock() + try: + with UnitOfWork(db.session) as uow: + uow.register(ModelCommitOp(Data(value="not-persisted"))) + uow.register(CleanUpOp()) + raise Exception() + except Exception: + pass + + data = db.session.query(Data).all() + assert len(data) == 2 + assert set([d.value for d in data]) == {"persisted", "clean-up"} + + rollback_side_effect.assert_called_once() + post_rollback_side_effect.assert_called_once()