diff --git a/projects/fal_client/src/fal_client/__init__.py b/projects/fal_client/src/fal_client/__init__.py index 1d0802bb..4e4f253b 100644 --- a/projects/fal_client/src/fal_client/__init__.py +++ b/projects/fal_client/src/fal_client/__init__.py @@ -34,7 +34,10 @@ sync_client = SyncClient() run = sync_client.run +subscribe = sync_client.subscribe submit = sync_client.submit +status = sync_client.status +result = sync_client.result stream = sync_client.stream upload = sync_client.upload upload_file = sync_client.upload_file @@ -42,7 +45,10 @@ async_client = AsyncClient() run_async = async_client.run +subscribe_async = async_client.subscribe submit_async = async_client.submit +status_async = async_client.status +result_async = async_client.result stream_async = async_client.stream upload_async = async_client.upload upload_file_async = async_client.upload_file diff --git a/projects/fal_client/src/fal_client/client.py b/projects/fal_client/src/fal_client/client.py index d0dcca28..d1af783f 100644 --- a/projects/fal_client/src/fal_client/client.py +++ b/projects/fal_client/src/fal_client/client.py @@ -345,9 +345,41 @@ async def submit( client=self._client, ) + async def subscribe( + self, + application: str, + arguments: AnyJSON, + *, + path: str = "", + hint: str | None = None, + with_logs: bool = False, + on_enqueue: Optional[callable[[Queued], None]] = None, + on_queue_update: Optional[callable[[Status], None]] = None, + ) -> AnyJSON: + handle = await self.submit(application, arguments, path=path, hint=hint) + + if on_enqueue is not None: + on_enqueue(handle.request_id) + + if on_queue_update is not None: + async for event in handle.iter_events(with_logs=with_logs): + on_queue_update(event) + + return await handle.get() + def get_handle(self, application: str, request_id: str) -> AsyncRequestHandle: return AsyncRequestHandle.from_request_id(self._client, application, request_id) + async def status( + self, application: str, request_id: str, *, with_logs: bool = False + ) -> Status: + handle = self.get_handle(application, request_id) + return await handle.status(with_logs=with_logs) + + async def result(self, application: str, request_id: str) -> AnyJSON: + handle = self.get_handle(application, request_id) + return await handle.get() + async def stream( self, application: str, @@ -494,9 +526,41 @@ def submit( client=self._client, ) + def subscribe( + self, + application: str, + arguments: AnyJSON, + *, + path: str = "", + hint: str | None = None, + with_logs: bool = False, + on_enqueue: Optional[callable[[Queued], None]] = None, + on_queue_update: Optional[callable[[Status], None]] = None, + ) -> AnyJSON: + handle = self.submit(application, arguments, path=path, hint=hint) + + if on_enqueue is not None: + on_enqueue(handle.request_id) + + if on_queue_update is not None: + for event in handle.iter_events(with_logs=with_logs): + on_queue_update(event) + + return handle.get() + def get_handle(self, application: str, request_id: str) -> SyncRequestHandle: return SyncRequestHandle.from_request_id(self._client, application, request_id) + def status( + self, application: str, request_id: str, *, with_logs: bool = False + ) -> Status: + handle = self.get_handle(application, request_id) + return handle.status(with_logs=with_logs) + + def result(self, application: str, request_id: str) -> AnyJSON: + handle = self.get_handle(application, request_id) + return handle.get() + def stream( self, application: str, diff --git a/projects/fal_client/tests/test_async_client.py b/projects/fal_client/tests/test_async_client.py index 70fda2cb..a5a6bb30 100644 --- a/projects/fal_client/tests/test_async_client.py +++ b/projects/fal_client/tests/test_async_client.py @@ -31,6 +31,11 @@ async def test_fal_client(client: fal_client.AsyncClient): result = await handle.get() assert result["seed"] == 42 + assert ( + await client.result("fal-ai/fast-sdxl/image-to-image", handle.request_id) + == result + ) + status = await handle.status(with_logs=False) assert isinstance(status, fal_client.Completed) assert status.logs is None @@ -42,6 +47,23 @@ async def test_fal_client(client: fal_client.AsyncClient): assert isinstance(status_w_logs, fal_client.Completed) assert status_w_logs.logs is not None + assert ( + await client.status( + "fal-ai/fast-sdxl/image-to-image", + handle.request_id, + ) + == status + ) + + output = await client.subscribe( + "fal-ai/fast-sdxl", + arguments={ + "prompt": "a cat", + }, + hint="lora:a", + ) + assert len(output["images"]) == 1 + output = await client.run( "fal-ai/fast-sdxl", arguments={ diff --git a/projects/fal_client/tests/test_sync_client.py b/projects/fal_client/tests/test_sync_client.py index e89c4fdc..1f9f0bd4 100644 --- a/projects/fal_client/tests/test_sync_client.py +++ b/projects/fal_client/tests/test_sync_client.py @@ -31,6 +31,8 @@ def test_fal_client(client: fal_client.SyncClient): result = handle.get() assert result["seed"] == 42 + assert client.result("fal-ai/fast-sdxl/image-to-image", handle.request_id) == result + status = handle.status(with_logs=False) assert isinstance(status, fal_client.Completed) assert status.logs is None @@ -42,6 +44,17 @@ def test_fal_client(client: fal_client.SyncClient): new_handle = client.get_handle("fal-ai/fast-sdxl/image-to-image", handle.request_id) assert new_handle == handle + assert client.status("fal-ai/fast-sdxl/image-to-image", handle.request_id) == status + + output = client.subscribe( + "fal-ai/fast-sdxl", + arguments={ + "prompt": "a cat", + }, + hint="lora:a", + ) + assert len(output["images"]) == 1 + output = client.run( "fal-ai/fast-sdxl", arguments={