Skip to content

Commit

Permalink
[replit_river] return cleanup task from client.disconnect()
Browse files Browse the repository at this point in the history
Why
===
* The task created by the websocket wrapper was orphaned.
* Since it takes a while to finish, we don't want to wait for it in
certain cases.

What changed
===
* When the websocket close task is made, return it.
* At every level, return the task and combine it with other cleanup
tasks as appropriate

Test plan
===
* The behavior shouldn't be different unless you await the cleanup
function.
  • Loading branch information
ryantm committed Oct 11, 2024
1 parent 80e6d95 commit 655583e
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 11 deletions.
6 changes: 4 additions & 2 deletions replit_river/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
from typing import Any, Generic, Optional, Union
Expand Down Expand Up @@ -38,10 +39,11 @@ def __init__(
transport_options=transport_options,
)

async def close(self) -> None:
async def close(self) -> asyncio.Task | None:
logger.info(f"river client {self._client_id} start closing")
await self._transport.close()
cleanup_task = await self._transport.close()
logger.info(f"river client {self._client_id} closed")
return cleanup_task

async def ensure_connected(self) -> None:
await self._transport.get_or_create_session()
Expand Down
4 changes: 2 additions & 2 deletions replit_river/client_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def __init__(
# We want to make sure there's only one session creation at a time
self._create_session_lock = asyncio.Lock()

async def close(self) -> None:
async def close(self) -> asyncio.Task:
self._rate_limiter.close()
await self._close_all_sessions()
return await self._close_all_sessions()

async def get_or_create_session(self) -> ClientSession:
async with self._create_session_lock:
Expand Down
15 changes: 11 additions & 4 deletions replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,15 +434,18 @@ async def _send_responses_from_output_stream(

async def close_websocket(
self, ws_wrapper: WebsocketWrapper, should_retry: bool
) -> None:
) -> asyncio.Task | None:
"""Mark the websocket as closed, close the websocket, and retry if needed."""
cleanup_websocket_task: asyncio.Task | None = None
async with self._ws_lock:
# Already closed.
if not await ws_wrapper.is_open():
logger.info("websocket wrapper already closed")
return
await ws_wrapper.close()
cleanup_websocket_task = await ws_wrapper.close()
if should_retry and self._retry_connection_callback:
self._task_manager.create_task(self._retry_connection_callback())
return cleanup_websocket_task

async def _open_stream_and_call_handler(
self,
Expand Down Expand Up @@ -523,8 +526,9 @@ async def _remove_acked_messages_in_buffer(self) -> None:
async def start_serve_responses(self) -> None:
self._task_manager.create_task(self.serve())

async def close(self) -> None:
async def close(self) -> asyncio.Task | None:
"""Close the session and all associated streams."""
cleanup_websocket_task: asyncio.Task | None = None
logger.info(
f"{self._transport_id} closing session "
f"to {self._to_id}, ws: {self._ws_wrapper.id}, "
Expand All @@ -538,7 +542,9 @@ async def close(self) -> None:
self._reset_session_close_countdown()
await self._task_manager.cancel_all_tasks()

await self.close_websocket(self._ws_wrapper, should_retry=False)
cleanup_websocket_task = await self.close_websocket(
self._ws_wrapper, should_retry=False
)

await self._buffer.close()

Expand All @@ -553,3 +559,4 @@ async def close(self) -> None:
self._streams.clear()

self._state = SessionState.CLOSED
return cleanup_websocket_task
9 changes: 7 additions & 2 deletions replit_river/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def __init__(
self._handlers: Dict[Tuple[str, str], Tuple[str, GenericRpcHandler]] = {}
self._session_lock = asyncio.Lock()

async def _close_all_sessions(self) -> None:
async def _close_all_sessions(self) -> asyncio.Task:
cleanup_tasks: list[asyncio.Task] = []
sessions = self._sessions.values()
logger.info(
f"start closing sessions {self._transport_id}, number sessions : "
Expand All @@ -38,10 +39,14 @@ async def _close_all_sessions(self) -> None:
# closing sessions requires access to the session lock, so we need to close
# them one by one to be safe
for session in sessions_to_close:
await session.close()
cleanup_task = await session.close()
if cleanup_task:
cleanup_tasks.append(cleanup_task)

logger.info(f"Transport closed {self._transport_id}")

return asyncio.gather(*cleanup_tasks)

async def _delete_session(self, session: Session) -> None:
async with self._session_lock:
if session._to_id in self._sessions:
Expand Down
3 changes: 2 additions & 1 deletion replit_river/websocket_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def is_open(self) -> bool:
async with self.ws_lock:
return self.ws_state == WsState.OPEN

async def close(self) -> None:
async def close(self) -> asyncio.Task | None:
async with self.ws_lock:
if self.ws_state == WsState.OPEN:
self.ws_state = WsState.CLOSING
Expand All @@ -33,3 +33,4 @@ async def close(self) -> None:
lambda _: logger.debug("old websocket %s closed.", self.ws.id)
)
self.ws_state = WsState.CLOSED
return task

0 comments on commit 655583e

Please sign in to comment.