Skip to content

Commit

Permalink
feat(client): add subscribe/status/result (#326)
Browse files Browse the repository at this point in the history
* feat(client): add status/result

* feat(client): add subscribe

* don't forget top level methods
  • Loading branch information
efiop authored Oct 7, 2024
1 parent fd07469 commit 5e42ae4
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 0 deletions.
6 changes: 6 additions & 0 deletions projects/fal_client/src/fal_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,21 @@

sync_client = SyncClient()
run = sync_client.run
subscribe = sync_client.subscribe
submit = sync_client.submit
status = sync_client.status
result = sync_client.result
stream = sync_client.stream
upload = sync_client.upload
upload_file = sync_client.upload_file
upload_image = sync_client.upload_image

async_client = AsyncClient()
run_async = async_client.run
subscribe_async = async_client.subscribe
submit_async = async_client.submit
status_async = async_client.status
result_async = async_client.result
stream_async = async_client.stream
upload_async = async_client.upload
upload_file_async = async_client.upload_file
Expand Down
64 changes: 64 additions & 0 deletions projects/fal_client/src/fal_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,41 @@ async def submit(
client=self._client,
)

async def subscribe(
self,
application: str,
arguments: AnyJSON,
*,
path: str = "",
hint: str | None = None,
with_logs: bool = False,
on_enqueue: Optional[callable[[Queued], None]] = None,
on_queue_update: Optional[callable[[Status], None]] = None,
) -> AnyJSON:
handle = await self.submit(application, arguments, path=path, hint=hint)

if on_enqueue is not None:
on_enqueue(handle.request_id)

if on_queue_update is not None:
async for event in handle.iter_events(with_logs=with_logs):
on_queue_update(event)

return await handle.get()

def get_handle(self, application: str, request_id: str) -> AsyncRequestHandle:
return AsyncRequestHandle.from_request_id(self._client, application, request_id)

async def status(
self, application: str, request_id: str, *, with_logs: bool = False
) -> Status:
handle = self.get_handle(application, request_id)
return await handle.status(with_logs=with_logs)

async def result(self, application: str, request_id: str) -> AnyJSON:
handle = self.get_handle(application, request_id)
return await handle.get()

async def stream(
self,
application: str,
Expand Down Expand Up @@ -494,9 +526,41 @@ def submit(
client=self._client,
)

def subscribe(
self,
application: str,
arguments: AnyJSON,
*,
path: str = "",
hint: str | None = None,
with_logs: bool = False,
on_enqueue: Optional[callable[[Queued], None]] = None,
on_queue_update: Optional[callable[[Status], None]] = None,
) -> AnyJSON:
handle = self.submit(application, arguments, path=path, hint=hint)

if on_enqueue is not None:
on_enqueue(handle.request_id)

if on_queue_update is not None:
for event in handle.iter_events(with_logs=with_logs):
on_queue_update(event)

return handle.get()

def get_handle(self, application: str, request_id: str) -> SyncRequestHandle:
return SyncRequestHandle.from_request_id(self._client, application, request_id)

def status(
self, application: str, request_id: str, *, with_logs: bool = False
) -> Status:
handle = self.get_handle(application, request_id)
return handle.status(with_logs=with_logs)

def result(self, application: str, request_id: str) -> AnyJSON:
handle = self.get_handle(application, request_id)
return handle.get()

def stream(
self,
application: str,
Expand Down
22 changes: 22 additions & 0 deletions projects/fal_client/tests/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ async def test_fal_client(client: fal_client.AsyncClient):
result = await handle.get()
assert result["seed"] == 42

assert (
await client.result("fal-ai/fast-sdxl/image-to-image", handle.request_id)
== result
)

status = await handle.status(with_logs=False)
assert isinstance(status, fal_client.Completed)
assert status.logs is None
Expand All @@ -42,6 +47,23 @@ async def test_fal_client(client: fal_client.AsyncClient):
assert isinstance(status_w_logs, fal_client.Completed)
assert status_w_logs.logs is not None

assert (
await client.status(
"fal-ai/fast-sdxl/image-to-image",
handle.request_id,
)
== status
)

output = await client.subscribe(
"fal-ai/fast-sdxl",
arguments={
"prompt": "a cat",
},
hint="lora:a",
)
assert len(output["images"]) == 1

output = await client.run(
"fal-ai/fast-sdxl",
arguments={
Expand Down
13 changes: 13 additions & 0 deletions projects/fal_client/tests/test_sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def test_fal_client(client: fal_client.SyncClient):
result = handle.get()
assert result["seed"] == 42

assert client.result("fal-ai/fast-sdxl/image-to-image", handle.request_id) == result

status = handle.status(with_logs=False)
assert isinstance(status, fal_client.Completed)
assert status.logs is None
Expand All @@ -42,6 +44,17 @@ def test_fal_client(client: fal_client.SyncClient):
new_handle = client.get_handle("fal-ai/fast-sdxl/image-to-image", handle.request_id)
assert new_handle == handle

assert client.status("fal-ai/fast-sdxl/image-to-image", handle.request_id) == status

output = client.subscribe(
"fal-ai/fast-sdxl",
arguments={
"prompt": "a cat",
},
hint="lora:a",
)
assert len(output["images"]) == 1

output = client.run(
"fal-ai/fast-sdxl",
arguments={
Expand Down

0 comments on commit 5e42ae4

Please sign in to comment.