Skip to content

Commit

Permalink
feat(fal_client): use cdn v3 (#344)
Browse files Browse the repository at this point in the history
  • Loading branch information
efiop authored Oct 23, 2024
1 parent 033d19f commit 2492e4b
Showing 1 changed file with 155 additions and 18 deletions.
173 changes: 155 additions & 18 deletions projects/fal_client/src/fal_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import asyncio
import time
import base64
import threading
from datetime import datetime, timezone
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any, AsyncIterator, Iterator, TYPE_CHECKING, Optional, Literal
Expand All @@ -23,10 +25,96 @@
RUN_URL_FORMAT = f"https://{FAL_RUN_HOST}/"
QUEUE_URL_FORMAT = f"https://queue.{FAL_RUN_HOST}/"
REALTIME_URL_FORMAT = f"wss://{FAL_RUN_HOST}/"
CDN_URL = "https://fal.media"
REST_URL = "https://rest.alpha.fal.ai"
CDN_URL = "https://v3.fal.media"
USER_AGENT = "fal-client/0.2.2 (python)"


@dataclass
class CDNToken:
token: str
token_type: str
base_upload_url: str
expires_at: datetime

def is_expired(self) -> bool:
return datetime.now(timezone.utc) >= self.expires_at


class CDNTokenManager:
def __init__(self, key: str) -> None:
self._key = key
self._token: CDNToken = CDNToken(
token="",
token_type="",
base_upload_url="",
expires_at=datetime.min.replace(tzinfo=timezone.utc),
)
self._lock: threading.Lock = threading.Lock()
self._url = f"{REST_URL}/storage/auth/token?storage_type=fal-cdn-v3"
self._headers = {
"Authorization": f"Key {self._key}",
"Accept": "application/json",
"Content-Type": "application/json",
}

def _refresh_token(self) -> CDNToken:
with httpx.Client() as client:
response = client.post(self._url, headers=self._headers, data=b"{}")
response.raise_for_status()
data = response.json()

return CDNToken(
token=data["token"],
token_type=data["token_type"],
base_upload_url=data["base_url"],
expires_at=datetime.fromisoformat(data["expires_at"]),
)

def get_token(self) -> CDNToken:
with self._lock:
if self._token.is_expired():
self._token = self._refresh_token()
return self._token


class AsyncCDNTokenManager:
def __init__(self, key: str) -> None:
self._key = key
self._token: CDNToken = CDNToken(
token="",
token_type="",
base_upload_url="",
expires_at=datetime.min.replace(tzinfo=timezone.utc),
)
self._lock: threading.Lock = threading.Lock()
self._url = f"{REST_URL}/storage/auth/token?storage_type=fal-cdn-v3"
self._headers = {
"Authorization": f"Key {self._key}",
"Accept": "application/json",
"Content-Type": "application/json",
}

async def _refresh_token(self) -> None:
async with httpx.AsyncClient() as client:
response = await client.post(self._url, headers=self._headers, data=b"{}")
response.raise_for_status()
data = response.json()

return CDNToken(
token=data["token"],
token_type=data["token_type"],
base_upload_url=data["base_url"],
expires_at=datetime.fromisoformat(data["expires_at"]),
)

async def get_token(self) -> CDNToken:
with self._lock:
if self._token.is_expired():
self._token = await self._refresh_token()
return self._token


class FalClientError(Exception):
pass

Expand Down Expand Up @@ -265,13 +353,18 @@ class AsyncClient:
key: str | None = field(default=None, repr=False)
default_timeout: float = 120.0

@cached_property
def _client(self) -> httpx.AsyncClient:
def _get_key(self) -> str:
if self.key is None:
key = fetch_credentials()
else:
key = self.key
return fetch_credentials()
return self.key

@cached_property
def _token_manager(self) -> AsyncCDNTokenManager:
return AsyncCDNTokenManager(self._get_key())

@cached_property
def _client(self) -> httpx.AsyncClient:
key = self._get_key()
return httpx.AsyncClient(
headers={
"Authorization": f"Key {key}",
Expand All @@ -280,6 +373,16 @@ def _client(self) -> httpx.AsyncClient:
timeout=self.default_timeout,
)

async def _get_cdn_client(self) -> httpx.AsyncClient:
token = await self._token_manager.get_token()
return httpx.AsyncClient(
headers={
"Authorization": f"{token.token_type} {token.token}",
"User-Agent": USER_AGENT,
},
timeout=self.default_timeout,
)

async def run(
self,
application: str,
Expand Down Expand Up @@ -425,14 +528,22 @@ async def stream(
async for event in events.aiter_sse():
yield event.json()

async def upload(self, data: str | bytes, content_type: str) -> str:
async def upload(
self, data: str | bytes, content_type: str, file_name: str | None = None
) -> str:
"""Upload the given data blob to the CDN and return the access URL. The content type should be specified
as the second argument. Use upload_file or upload_image for convenience."""

response = await self._client.post(
client = await self._get_cdn_client()

headers = {"Content-Type": content_type}
if file_name is not None:
headers["X-Fal-File-Name"] = file_name

response = await client.post(
CDN_URL + "/files/upload",
data=data,
headers={"Content-Type": content_type},
headers=headers,
)
_raise_for_status(response)

Expand All @@ -446,7 +557,9 @@ async def upload_file(self, path: os.PathLike) -> str:
mime_type = "application/octet-stream"

with open(path, "rb") as file:
return await self.upload(file.read(), mime_type)
return await self.upload(
file.read(), mime_type, file_name=os.path.basename(path)
)

async def upload_image(self, image: Image.Image, format: str = "jpeg") -> str:
"""Upload a pillow image object to the CDN and return the access URL."""
Expand All @@ -461,12 +574,14 @@ class SyncClient:
key: str | None = field(default=None, repr=False)
default_timeout: float = 120.0

def _get_key(self) -> str:
if self.key is None:
return fetch_credentials()
return self.key

@cached_property
def _client(self) -> httpx.Client:
if self.key is None:
key = fetch_credentials()
else:
key = self.key
key = self._get_key()
return httpx.Client(
headers={
"Authorization": f"Key {key}",
Expand All @@ -475,6 +590,20 @@ def _client(self) -> httpx.Client:
timeout=self.default_timeout,
)

@cached_property
def _token_manager(self) -> CDNTokenManager:
return CDNTokenManager(self._get_key())

def _get_cdn_client(self) -> httpx.Client:
token = self._token_manager.get_token()
return httpx.Client(
headers={
"Authorization": f"{token.token_type} {token.token}",
"User-Agent": USER_AGENT,
},
timeout=self.default_timeout,
)

def run(
self,
application: str,
Expand Down Expand Up @@ -617,14 +746,22 @@ def stream(
for event in events.iter_sse():
yield event.json()

def upload(self, data: str | bytes, content_type: str) -> str:
def upload(
self, data: str | bytes, content_type: str, file_name: str | None = None
) -> str:
"""Upload the given data blob to the CDN and return the access URL. The content type should be specified
as the second argument. Use upload_file or upload_image for convenience."""

response = self._client.post(
client = self._get_cdn_client()

headers = {"Content-Type": content_type}
if file_name is not None:
headers["X-Fal-File-Name"] = file_name

response = client.post(
CDN_URL + "/files/upload",
data=data,
headers={"Content-Type": content_type},
headers=headers,
)
_raise_for_status(response)

Expand All @@ -638,7 +775,7 @@ def upload_file(self, path: os.PathLike) -> str:
mime_type = "application/octet-stream"

with open(path, "rb") as file:
return self.upload(file.read(), mime_type)
return self.upload(file.read(), mime_type, file_name=os.path.basename(path))

def upload_image(self, image: Image.Image, format: str = "jpeg") -> str:
"""Upload a pillow image object to the CDN and return the access URL."""
Expand Down

0 comments on commit 2492e4b

Please sign in to comment.