diff --git a/projects/fal/src/fal/toolkit/file/file.py b/projects/fal/src/fal/toolkit/file/file.py
index 20783651..0206630c 100644
--- a/projects/fal/src/fal/toolkit/file/file.py
+++ b/projects/fal/src/fal/toolkit/file/file.py
@@ -135,6 +135,8 @@ def from_bytes(
             FileRepository | RepositoryId
         ] = FALLBACK_REPOSITORY,
         request: Optional[Request] = None,
+        save_kwargs: Optional[dict] = None,
+        fallback_save_kwargs: Optional[dict] = None,
     ) -> File:
         repo = (
             repository
@@ -142,12 +144,15 @@ def from_bytes(
             else get_builtin_repository(repository)
         )
 
+        save_kwargs = save_kwargs or {}
+        fallback_save_kwargs = fallback_save_kwargs or {}
+
         fdata = FileData(data, content_type, file_name)
 
         object_lifecycle_preference = get_lifecycle_preference(request)
 
         try:
-            url = repo.save(fdata, object_lifecycle_preference)
+            url = repo.save(fdata, object_lifecycle_preference, **save_kwargs)
         except Exception:
             if not fallback_repository:
                 raise
@@ -158,7 +163,9 @@ def from_bytes(
                 else get_builtin_repository(fallback_repository)
             )
 
-            url = fallback_repo.save(fdata, object_lifecycle_preference)
+            url = fallback_repo.save(
+                fdata, object_lifecycle_preference, **fallback_save_kwargs
+            )
 
         return cls(
             url=url,
@@ -179,6 +186,8 @@ def from_path(
             FileRepository | RepositoryId
         ] = FALLBACK_REPOSITORY,
         request: Optional[Request] = None,
+        save_kwargs: Optional[dict] = None,
+        fallback_save_kwargs: Optional[dict] = None,
     ) -> File:
         file_path = Path(path)
         if not file_path.exists():
@@ -190,6 +199,9 @@ def from_path(
             else get_builtin_repository(repository)
         )
 
+        save_kwargs = save_kwargs or {}
+        fallback_save_kwargs = fallback_save_kwargs or {}
+
         content_type = content_type or "application/octet-stream"
         object_lifecycle_preference = get_lifecycle_preference(request)
 
@@ -199,6 +211,7 @@ def from_path(
                 content_type=content_type,
                 multipart=multipart,
                 object_lifecycle_preference=object_lifecycle_preference,
+                **save_kwargs,
             )
         except Exception:
             if not fallback_repository:
@@ -215,6 +228,7 @@ def from_path(
                 content_type=content_type,
                 multipart=multipart,
                 object_lifecycle_preference=object_lifecycle_preference,
+                **fallback_save_kwargs,
             )
 
         return cls(
diff --git a/projects/fal/src/fal/toolkit/file/providers/s3.py b/projects/fal/src/fal/toolkit/file/providers/s3.py
new file mode 100644
index 00000000..b60f2296
--- /dev/null
+++ b/projects/fal/src/fal/toolkit/file/providers/s3.py
@@ -0,0 +1,109 @@
+from __future__ import annotations
+
+import os
+import posixpath
+import uuid
+from dataclasses import dataclass
+from io import BytesIO
+from typing import Optional
+
+from fal.toolkit.file.types import FileData, FileRepository
+from fal.toolkit.utils.retry import retry
+
+DEFAULT_URL_TIMEOUT = 60 * 15  # 15 minutes
+
+
+@dataclass
+class S3Repository(FileRepository):
+    """File repository that uploads to an S3 bucket and returns presigned URLs.
+
+    A credential that is not passed to the constructor is read from the
+    ``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY`` environment variables.
+    """
+
+    # NOTE(review): current S3 naming rules do not allow underscores in bucket
+    # names, so callers probably have to override this default -- confirm.
+    bucket_name: str = "fal_file_storage"
+    # Passed as ``ExpiresIn`` when presigning, i.e. the URL lifetime in seconds.
+    url_expiration: int = DEFAULT_URL_TIMEOUT
+    aws_access_key_id: str | None = None
+    aws_secret_access_key: str | None = None
+
+    # Unannotated, so this is a plain class attribute, not a dataclass field.
+    _s3_client = None
+
+    def __post_init__(self):
+        """Resolve credentials and create the boto3 S3 client.
+
+        Raises:
+            Exception: If boto3 cannot be imported, or if a credential is neither
+                passed in nor set in the environment.
+        """
+        try:
+            import boto3
+            from botocore.client import Config
+        except ImportError as exc:
+            # Chain the cause so the original import failure stays visible.
+            raise Exception("boto3 is not installed") from exc
+
+        if self.aws_access_key_id is None:
+            self.aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID")
+            if self.aws_access_key_id is None:
+                raise Exception("AWS_ACCESS_KEY_ID environment variable is not set")
+
+        if self.aws_secret_access_key is None:
+            self.aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
+            if self.aws_secret_access_key is None:
+                raise Exception("AWS_SECRET_ACCESS_KEY environment variable is not set")
+
+        self._s3_client = boto3.client(
+            "s3",
+            aws_access_key_id=self.aws_access_key_id,
+            aws_secret_access_key=self.aws_secret_access_key,
+            config=Config(signature_version="s3v4"),
+        )
+
+    @property
+    def storage_client(self):
+        """Return the S3 client, raising if it has not been initialized."""
+        if self._s3_client is None:
+            raise Exception("S3 client is not initialized")
+
+        return self._s3_client
+
+    @retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
+    def save(
+        self,
+        data: FileData,
+        object_lifecycle_preference: Optional[dict[str, str]] = None,
+        key: Optional[str] = None,
+    ) -> str:
+        """Upload ``data`` to the bucket and return a presigned download URL.
+
+        Args:
+            data: File contents, content type and file name to upload.
+            object_lifecycle_preference: Not used by this method.
+            key: Optional prefix joined in front of the generated object name.
+
+        Returns:
+            A presigned ``get_object`` URL for the uploaded object.
+        """
+        # The random hex prefix gives every upload a distinct object name.
+        destination_path = posixpath.join(
+            key or "",
+            f"{uuid.uuid4().hex}_{data.file_name}",
+        )
+
+        self.storage_client.upload_fileobj(
+            BytesIO(data.data),
+            self.bucket_name,
+            destination_path,
+            ExtraArgs={"ContentType": data.content_type},
+        )
+
+        public_url = self.storage_client.generate_presigned_url(
+            ClientMethod="get_object",
+            Params={"Bucket": self.bucket_name, "Key": destination_path},
+            ExpiresIn=self.url_expiration,
+        )
+        return public_url