From 2a3fd2a1191f9a159222102c998fe1df7db16baa Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Tue, 20 Apr 2021 23:17:07 -0400 Subject: [PATCH] Allow the lifetime of the Connection thread to be tied to an event loop --- aiosqlite/core.py | 20 ++++++++++++++++++-- aiosqlite/tests/smoke.py | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/aiosqlite/core.py b/aiosqlite/core.py index dfb98b9..e14df02 100644 --- a/aiosqlite/core.py +++ b/aiosqlite/core.py @@ -47,6 +47,7 @@ def __init__( connector: Callable[[], sqlite3.Connection], iter_chunk_size: int, loop: Optional[asyncio.AbstractEventLoop] = None, + parent_loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: super().__init__() self._running = True @@ -54,6 +55,7 @@ def __init__( self._connector = connector self._tx: Queue = Queue() self._iter_chunk_size = iter_chunk_size + self._parent_loop = parent_loop if loop is not None: warn( @@ -87,7 +89,7 @@ def run(self) -> None: :meta private: """ - while True: + while self._parent_loop is None or not self._parent_loop.is_closed(): # Continues running until all queue items are processed, # even after connection is closed (so we can finalize all # futures) @@ -116,6 +118,19 @@ def set_exception(fut, e): get_loop(future).call_soon_threadsafe(set_exception, future, e) + # Clean up within this thread only if the parent event loop exits ungracefully + if not self._running or self._connection is None or self._parent_loop is None: + return + + try: + self._conn.close() + except Exception: + LOG.info("exception occurred while closing connection") + raise + finally: + self._running = False + self._connection = None + async def _execute(self, fn, *args, **kwargs): """Queue a function with the given arguments for execution.""" if not self._running or not self._connection: @@ -376,6 +391,7 @@ def connect( *, iter_chunk_size=64, loop: Optional[asyncio.AbstractEventLoop] = None, + parent_loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any ) -> Connection: """Create and return a connection proxy to the sqlite database.""" @@ -396,4 +412,4 @@ def connector() -> sqlite3.Connection: return sqlite3.connect(loc, **kwargs) - return Connection(connector, iter_chunk_size) + return Connection(connector, iter_chunk_size, parent_loop=parent_loop) diff --git a/aiosqlite/tests/smoke.py b/aiosqlite/tests/smoke.py index a7ee83f..a3d945c 100644 --- a/aiosqlite/tests/smoke.py +++ b/aiosqlite/tests/smoke.py @@ -3,6 +3,7 @@ import asyncio import sqlite3 import sys +import time from pathlib import Path from sqlite3 import OperationalError from threading import Thread @@ -465,3 +466,20 @@ async def test_backup_py36(self): ) as db2: with self.assertRaisesRegex(RuntimeError, "backup().+3.7"): await db1.backup(db2) + + async def test_no_close_with_parent_event_loop(self): + def runner(): + loop = asyncio.new_event_loop() + db = loop.run_until_complete(aiosqlite.connect(TEST_DB, parent_loop=loop)) + loop.close() + + # Wait long enough for the queue `get` timeout to elapse + time.sleep(0.2) + + # Database has been closed + with self.assertRaises(ValueError): + db.in_transaction + + thread = Thread(target=runner) + thread.start() + thread.join()