From f8d71e1d3130c84b900d3db464bef42b1796d65b Mon Sep 17 00:00:00 2001 From: turbo1921 Date: Mon, 29 Apr 2024 19:01:23 -0700 Subject: [PATCH] feat: automatically pass X-Fal-Object-Lifecycle-Preference as a header --- projects/fal/src/fal/app.py | 19 +++++++++++++++ .../fal/src/fal/toolkit/file/providers/fal.py | 24 +++++++++++++++++-- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/projects/fal/src/fal/app.py b/projects/fal/src/fal/app.py index f2ee4e03..14c0be89 100644 --- a/projects/fal/src/fal/app.py +++ b/projects/fal/src/fal/app.py @@ -13,6 +13,7 @@ from fal.api import RouteSignature from fal.logging import get_logger from fal._serialization import include_modules_from + REALTIME_APP_REQUIREMENTS = ["websockets", "msgpack"] EndpointT = TypeVar("EndpointT", bound=Callable[..., Any]) @@ -123,6 +124,24 @@ async def provide_hints_headers(request, call_next): ) return response + @app.middleware("http") + async def set_global_object_preference(request, call_next): + response = await call_next(request) + try: + from fal.toolkit.file.providers import fal + + fal.GLOBAL_LIFECYCLE_PREFERENCE = request.headers[ + "X-Fal-Object-Lifecycle-Preference" + ] + except Exception: + from fastapi.logger import logger + + logger.exception( + "Failed set a global lifecycle preference %s", + self.__class__.__name__, + ) + return response + def provide_hints(self) -> list[str]: """Provide hints for routing the application.""" raise NotImplementedError diff --git a/projects/fal/src/fal/toolkit/file/providers/fal.py b/projects/fal/src/fal/toolkit/file/providers/fal.py index 86a6a067..9761c936 100644 --- a/projects/fal/src/fal/toolkit/file/providers/fal.py +++ b/projects/fal/src/fal/toolkit/file/providers/fal.py @@ -1,5 +1,6 @@ from __future__ import annotations +import datetime import json import os from base64 import b64encode @@ -14,6 +15,16 @@ _FAL_CDN = "https://fal.media" +@dataclass +class ObjectLifecyclePreference: + expriation_days: datetime.timedelta + + +GLOBAL_LIFECYCLE_PREFERENCE = ObjectLifecyclePreference( + expriation_days=datetime.timedelta(days=2) +) + + @dataclass class FalFileRepository(FileRepository): def save(self, file: FileData) -> str: @@ -70,17 +81,26 @@ def _upload_file(self, upload_url: str, file: FileData): @dataclass class InMemoryRepository(FileRepository): - def save(self, file: FileData) -> str: + def save( + self, + file: FileData, + ) -> str: return f'data:{file.content_type};base64,{b64encode(file.data).decode("utf-8")}' @dataclass class FalCDNFileRepository(FileRepository): - def save(self, file: FileData) -> str: + def save( + self, + file: FileData, + ) -> str: headers = { **self.auth_headers, "Accept": "application/json", "Content-Type": file.content_type, + "X-Fal-Object-Lifecycle-Preference": json.dumps( + GLOBAL_LIFECYCLE_PREFERENCE + ), } url = os.getenv("FAL_CDN_HOST", _FAL_CDN) + "/files/upload"