Skip to content

Commit

Permalink
feat: add s3 repository (#339)
Browse files Browse the repository at this point in the history
* feat: add s3 repository

* chore: revert default repo change
  • Loading branch information
badayvedat authored Oct 21, 2024
1 parent d191413 commit c885032
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 2 deletions.
18 changes: 16 additions & 2 deletions projects/fal/src/fal/toolkit/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,19 +135,24 @@ def from_bytes(
FileRepository | RepositoryId
] = FALLBACK_REPOSITORY,
request: Optional[Request] = None,
save_kwargs: Optional[dict] = None,
fallback_save_kwargs: Optional[dict] = None,
) -> File:
repo = (
repository
if isinstance(repository, FileRepository)
else get_builtin_repository(repository)
)

save_kwargs = save_kwargs or {}
fallback_save_kwargs = fallback_save_kwargs or {}

fdata = FileData(data, content_type, file_name)

object_lifecycle_preference = get_lifecycle_preference(request)

try:
url = repo.save(fdata, object_lifecycle_preference)
url = repo.save(fdata, object_lifecycle_preference, **save_kwargs)
except Exception:
if not fallback_repository:
raise
Expand All @@ -158,7 +163,9 @@ def from_bytes(
else get_builtin_repository(fallback_repository)
)

url = fallback_repo.save(fdata, object_lifecycle_preference)
url = fallback_repo.save(
fdata, object_lifecycle_preference, **fallback_save_kwargs
)

return cls(
url=url,
Expand All @@ -179,6 +186,8 @@ def from_path(
FileRepository | RepositoryId
] = FALLBACK_REPOSITORY,
request: Optional[Request] = None,
save_kwargs: Optional[dict] = None,
fallback_save_kwargs: Optional[dict] = None,
) -> File:
file_path = Path(path)
if not file_path.exists():
Expand All @@ -190,6 +199,9 @@ def from_path(
else get_builtin_repository(repository)
)

save_kwargs = save_kwargs or {}
fallback_save_kwargs = fallback_save_kwargs or {}

content_type = content_type or "application/octet-stream"
object_lifecycle_preference = get_lifecycle_preference(request)

Expand All @@ -199,6 +211,7 @@ def from_path(
content_type=content_type,
multipart=multipart,
object_lifecycle_preference=object_lifecycle_preference,
**save_kwargs,
)
except Exception:
if not fallback_repository:
Expand All @@ -215,6 +228,7 @@ def from_path(
content_type=content_type,
multipart=multipart,
object_lifecycle_preference=object_lifecycle_preference,
**fallback_save_kwargs,
)

return cls(
Expand Down
80 changes: 80 additions & 0 deletions projects/fal/src/fal/toolkit/file/providers/s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

import os
import posixpath
import uuid
from dataclasses import dataclass
from io import BytesIO
from typing import Optional

from fal.toolkit.file.types import FileData, FileRepository
from fal.toolkit.utils.retry import retry

DEFAULT_URL_TIMEOUT = 60 * 15 # 15 minutes


@dataclass
class S3Repository(FileRepository):
bucket_name: str = "fal_file_storage"
url_expiration: int = DEFAULT_URL_TIMEOUT
aws_access_key_id: str | None = None
aws_secret_access_key: str | None = None

_s3_client = None

def __post_init__(self):
try:
import boto3
from botocore.client import Config
except ImportError:
raise Exception("boto3 is not installed")

if self.aws_access_key_id is None:
self.aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID")
if self.aws_access_key_id is None:
raise Exception("AWS_ACCESS_KEY_ID environment variable is not set")

if self.aws_secret_access_key is None:
self.aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
if self.aws_secret_access_key is None:
raise Exception("AWS_SECRET_ACCESS_KEY environment variable is not set")

self._s3_client = boto3.client(
"s3",
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
config=Config(signature_version="s3v4"),
)

@property
def storage_client(self):
if self._s3_client is None:
raise Exception("S3 client is not initialized")

return self._s3_client

@retry(max_retries=3, base_delay=1, backoff_type="exponential", jitter=True)
def save(
self,
data: FileData,
object_lifecycle_preference: Optional[dict[str, str]] = None,
key: Optional[str] = None,
) -> str:
destination_path = posixpath.join(
key or "",
f"{uuid.uuid4().hex}_{data.file_name}",
)

self.storage_client.upload_fileobj(
BytesIO(data.data),
self.bucket_name,
destination_path,
ExtraArgs={"ContentType": data.content_type},
)

public_url = self.storage_client.generate_presigned_url(
ClientMethod="get_object",
Params={"Bucket": self.bucket_name, "Key": destination_path},
ExpiresIn=self.url_expiration,
)
return public_url

0 comments on commit c885032

Please sign in to comment.