From eabd970b84c1f37d8c5734404f208535bd6da42f Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Mon, 22 Apr 2024 19:02:57 -0600 Subject: [PATCH] Finalize async generators in the correct context --- newsfragments/92.feature.rst | 4 ++ tests/conftest.py | 5 +- tests/test_trio_asyncio.py | 94 ++++++++++++++++++++++++++++++++++++ trio_asyncio/_base.py | 74 ++++++++++++++++++++++++++++ trio_asyncio/_loop.py | 56 ++++++++++++++++++++- 5 files changed, 228 insertions(+), 5 deletions(-) create mode 100644 newsfragments/92.feature.rst diff --git a/newsfragments/92.feature.rst b/newsfragments/92.feature.rst new file mode 100644 index 0000000..16104a3 --- /dev/null +++ b/newsfragments/92.feature.rst @@ -0,0 +1,4 @@ +trio-asyncio now properly finalizes asyncio-flavored async generators +upon closure of the event loop. Previously, Trio's async generator finalizers +would try to finalize all async generators in Trio mode, regardless of their +flavor, which could lead to spurious errors. diff --git a/tests/conftest.py b/tests/conftest.py index f72bb76..2fcd4c1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,10 +12,7 @@ @pytest.fixture async def loop(): async with trio_asyncio.open_loop() as loop: - try: - yield loop - finally: - await loop.stop().wait() + yield loop # auto-trio-ize all async functions diff --git a/tests/test_trio_asyncio.py b/tests/test_trio_asyncio.py index 9e5c809..7b04d18 100644 --- a/tests/test_trio_asyncio.py +++ b/tests/test_trio_asyncio.py @@ -3,8 +3,10 @@ import types import asyncio import trio +import trio.testing import trio_asyncio import contextlib +import gc async def use_asyncio(): @@ -203,3 +205,95 @@ async def main(): asyncio.run(main()) assert scope.value.code == 42 + + +@pytest.mark.trio +@pytest.mark.parametrize("alive_on_exit", (False, True)) +@pytest.mark.parametrize("slow_finalizer", (False, True)) +@pytest.mark.parametrize("loop_timeout", (0, 1, 20)) +async def test_asyncgens(alive_on_exit, slow_finalizer, loop_timeout, autojump_clock): + import sniffio + + record = set() + holder = [] + + async def agen(label, extra): + assert sniffio.current_async_library() == label + if label == "asyncio": + loop = asyncio.get_running_loop() + try: + yield 1 + finally: + library = sniffio.current_async_library() + if label == "asyncio": + assert loop is asyncio.get_running_loop() + try: + await sys.modules[library].sleep(5 if slow_finalizer else 0) + except (trio.Cancelled, asyncio.CancelledError): + pass + record.add((label + extra, library)) + + async def iterate_one(label, extra=""): + ag = agen(label, extra) + await ag.asend(None) + if alive_on_exit: + holder.append(ag) + else: + del ag + + sys.unraisablehook, prev_hook = sys.__unraisablehook__, sys.unraisablehook + try: + start_time = trio.current_time() + with trio.move_on_after(loop_timeout) as scope: + if loop_timeout == 0: + scope.cancel() + async with trio_asyncio.open_loop() as loop: + async with trio_asyncio.open_loop() as loop2: + async with trio.open_nursery() as nursery: + # Make sure the iterate_one aio tasks don't get + # cancelled before they start: + nursery.cancel_scope.shield = True + try: + nursery.start_soon(iterate_one, "trio") + nursery.start_soon( + loop.run_aio_coroutine, iterate_one("asyncio") + ) + nursery.start_soon( + loop2.run_aio_coroutine, iterate_one("asyncio", "2") + ) + await loop.synchronize() + await loop2.synchronize() + finally: + nursery.cancel_scope.shield = False + + # asyncio agens should be finalized as soon as asyncio loop ends, + # regardless of liveness + assert ("asyncio", "asyncio") in record + assert ("asyncio2", "asyncio") in record + + # asyncio agen finalizers should be able to take a cancel + if (slow_finalizer or loop_timeout == 0) and alive_on_exit: + # Each loop finalizes in series, and takes 5 seconds + # if slow_finalizer is true. + assert trio.current_time() == start_time + min(loop_timeout, 10) + assert scope.cancelled_caught == (loop_timeout < 10) + else: + # `not alive_on_exit` implies that the asyncio agen aclose() tasks + # are started before loop shutdown, which means they'll be + # cancelled during loop shutdown; this matches regular asyncio. + # + # `not slow_finalizer and loop_timeout > 0` implies that the agens + # have time to complete before we cancel them. + assert trio.current_time() == start_time + assert not scope.cancelled_caught + + # trio asyncgen should eventually be finalized in trio mode + del holder[:] + for _ in range(5): + gc.collect() + await trio.testing.wait_all_tasks_blocked() + assert record == { + ("trio", "trio"), ("asyncio", "asyncio"), ("asyncio2", "asyncio") + } + finally: + sys.unraisablehook = prev_hook diff --git a/trio_asyncio/_base.py b/trio_asyncio/_base.py index a30112a..1bd4ddb 100644 --- a/trio_asyncio/_base.py +++ b/trio_asyncio/_base.py @@ -91,6 +91,71 @@ def shutdown(self, wait=None): self._running = False +class AsyncGeneratorDispatcher: + """Helper object providing async generator hooks that route + finalization to either the correct trio-asyncio event loop or the + outer Trio run, depending on where the generator was first iterated. + """ + + def __init__(self, prev_hooks): + self.prev_hooks = prev_hooks + self.refcnt = 1 + + @classmethod + def install(cls): + current_hooks = sys.get_asyncgen_hooks() + + # These hooks should either be our own AsyncGeneratorDispatcher + # (for another trio-asyncio loop) or Trio's hooks. Both of those + # provide both hooks. + assert current_hooks.firstiter is not None + assert current_hooks.finalizer is not None + + matches = ( + getattr(current_hooks.firstiter, "__func__", None) is cls.firstiter + ) + (getattr(current_hooks.finalizer, "__func__", None) is cls.finalizer) + if matches == 0: + # Create a new dispatcher that forwards non-trio-asyncio asyncgens + # to the current_hooks + dispatcher = cls(prev_hooks=current_hooks) + sys.set_asyncgen_hooks( + firstiter=dispatcher.firstiter, finalizer=dispatcher.finalizer + ) + else: + # Take a new reference to the dispatcher that the current_hooks + # refer to + assert matches == 2 + dispatcher = current_hooks.firstiter.__self__ + assert dispatcher is current_hooks.finalizer.__self__ + assert isinstance(dispatcher, cls) + dispatcher.refcnt += 1 + return dispatcher + + def uninstall(self): + self.refcnt -= 1 + if self.refcnt <= 0: + sys.set_asyncgen_hooks(*self.prev_hooks) + assert self.refcnt == 0 + + def firstiter(self, agen): + if sniffio_library.name == "asyncio": + loop = asyncio.get_running_loop() + agen.ag_frame.f_locals["@trio_asyncio_loop"] = loop + return loop._asyncgen_firstiter_hook(agen) + else: + return self.prev_hooks.firstiter(agen) + + def finalizer(self, agen): + try: + loop = agen.ag_frame.f_locals.get("@trio_asyncio_loop") + except AttributeError: # pragma: no cover + loop = None + if loop is not None: + return loop._asyncgen_finalizer_hook(agen) + else: + return self.prev_hooks.finalizer(agen) + + class BaseTrioEventLoop(asyncio.SelectorEventLoop): """An asyncio event loop that runs on top of Trio. @@ -135,6 +200,10 @@ class BaseTrioEventLoop(asyncio.SelectorEventLoop): # (threading) Thread this loop is running in _thread = None + # An instance of AsyncGeneratorDispatcher for handling asyncio async + # generators; it may be shared by multiple running trio-asyncio loops + _asyncgen_dispatcher = None + def __init__(self, queue_len=None): if queue_len is None: queue_len = math.inf @@ -629,6 +698,7 @@ async def _main_loop_init(self, nursery): self._nursery = nursery self._task = trio.lowlevel.current_task() self._token = trio.lowlevel.current_trio_token() + self._asyncgen_dispatcher = AsyncGeneratorDispatcher.install() async def _main_loop(self, task_status=trio.TASK_STATUS_IGNORED): """Run the loop by processing its event queue. @@ -738,6 +808,10 @@ async def _main_loop_exit(self): except TrioAsyncioExit: pass + # Restore previous async generator hooks + self._asyncgen_dispatcher.uninstall() + self._asyncgen_dispatcher = None + # Kill off unprocessed work self._cancel_fds() self._cancel_timers() diff --git a/trio_asyncio/_loop.py b/trio_asyncio/_loop.py index 5b353e2..a48e9f4 100644 --- a/trio_asyncio/_loop.py +++ b/trio_asyncio/_loop.py @@ -5,6 +5,7 @@ import sys import trio import asyncio +import warnings import threading from contextvars import ContextVar from contextlib import asynccontextmanager @@ -560,6 +561,49 @@ async def wait_for_sync(): tasks_nursery.cancel_scope.cancel() finally: + # If we have any async generators left, finalize them before + # closing the event loop. Make sure that the finalizers have a + # chance to actually start before they're exposed to any + # external cancellation, since asyncio doesn't guarantee that + # cancelled tasks have a chance to start first. + + asyncgens_done = trio.Event() + should_warn = False + if len(loop._asyncgens) == 0: + asyncgens_done.set() + elif not loop.is_running(): + asyncgens_done.set() + should_warn = True + else: + shield_asyncgen_finalizers = trio.CancelScope(shield=True) + + async def sentinel(): + try: + yield + finally: + try: + # Open-coded asyncio version of loop.synchronize(); + # since we closed the tasks_nursery, we can't do + # any more asyncio-to-trio-mode conversions + w = asyncio.Event() + loop.call_soon(w.set) + await w.wait() + finally: + shield_asyncgen_finalizers.shield = False + + async def shutdown_asyncgens_from_aio(): + agen = sentinel() + await agen.asend(None) + try: + await loop.shutdown_asyncgens() + finally: + asyncgens_done.set() + + @loop_nursery.start_soon + async def shutdown_asyncgens_from_trio(): + with shield_asyncgen_finalizers: + await loop.run_aio_coroutine(shutdown_asyncgens_from_aio()) + if forwarded_cancellation is not None: # Now that we're outside the shielded tasks_nursery, we can # add this cancellation to the set of errors propagating out @@ -570,7 +614,17 @@ async def forward_cancellation(): raise forwarded_cancellation try: - await loop._main_loop_exit() + try: + if should_warn: + warnings.warn( + "trio-asyncio loop was stopped before its async " + "generators were finalized; weird stuff might happen", + RuntimeWarning, + ) + finally: + with trio.CancelScope(shield=True): + await asyncgens_done.wait() + await loop._main_loop_exit() finally: loop.close() current_loop.reset(old_loop)