Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(container): Add initial APIs #197

Merged
merged 6 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ dmypy.json

/coverage.xml
/.coverage
tmp/
2 changes: 2 additions & 0 deletions projects/fal/src/fal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from fal.api import FalServerlessHost, LocalHost, cached, function
from fal.api import function as isolated # noqa: F401
from fal.app import App, endpoint, realtime, wrap_app # noqa: F401
from fal.container import ContainerImage
from fal.sdk import FalServerlessKeyCredentials
from fal.sync import sync_dir

Expand All @@ -26,4 +27,5 @@
"sync_dir",
"__version__",
"version_tuple",
"ContainerImage",
]
58 changes: 54 additions & 4 deletions projects/fal/src/fal/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

import fal.flags as flags
from fal._serialization import include_modules_from, patch_pickle
from fal.container import ContainerImage
from fal.exceptions import FalServerlessException
from fal.logging.isolate import IsolateLogPrinter
from fal.sdk import (
Expand Down Expand Up @@ -523,9 +524,12 @@ def add_requirements(self, requirements: list[str]):
pip_requirements = self.environment.setdefault("requirements", [])
elif kind == "conda":
pip_requirements = self.environment.setdefault("pip", [])
elif kind == "container":
return None
else:
raise FalServerlessError(
"Only conda and virtualenv is supported as environment options"
"Only {conda, virtualenv, container} "
"are supported as environment options."
)

# Already has these.
Expand Down Expand Up @@ -743,8 +747,55 @@ def function(
_scheduler: str | None = None,
) -> Callable[
[Callable[Concatenate[ArgsT], ReturnT]], ServedIsolatedFunction[ArgsT, ReturnT]
]:
...
]: ...


@overload
def function(
kind: Literal["container"],
*,
image: ContainerImage | None = None,
# Common options
host: FalServerlessHost = _DEFAULT_HOST,
serve: Literal[False] = False,
exposed_port: int | None = None,
max_concurrency: int | None = None,
# FalServerlessHost options
metadata: dict[str, Any] | None = None,
machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
setup_function: Callable[..., None] | None = None,
_base_image: str | None = None,
_scheduler: str | None = None,
) -> Callable[
[Callable[Concatenate[ArgsT], ReturnT]], IsolatedFunction[ArgsT, ReturnT]
]: ...


@overload
def function(
kind: Literal["container"],
*,
image: ContainerImage | None = None,
# Common options
host: FalServerlessHost = _DEFAULT_HOST,
serve: Literal[True],
exposed_port: int | None = None,
max_concurrency: int | None = None,
# FalServerlessHost options
metadata: dict[str, Any] | None = None,
machine_type: str = FAL_SERVERLESS_DEFAULT_MACHINE_TYPE,
keep_alive: int = FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING,
min_concurrency: int = FAL_SERVERLESS_DEFAULT_MIN_CONCURRENCY,
setup_function: Callable[..., None] | None = None,
_base_image: str | None = None,
_scheduler: str | None = None,
) -> Callable[
[Callable[Concatenate[ArgsT], ReturnT]], ServedIsolatedFunction[ArgsT, ReturnT]
]: ...


# implementation
Expand Down Expand Up @@ -1121,4 +1172,3 @@ class Server(uvicorn.Server):

def install_signal_handlers(self) -> None:
pass

19 changes: 19 additions & 0 deletions projects/fal/src/fal/container.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
class ContainerImage:
"""ContainerImage represents a Docker image that can be built
from a Dockerfile.
"""

_known_keys = {"dockerfile_str", "build_env", "build_args"}

@classmethod
def from_dockerfile_str(cls, text: str, **kwargs):
# Check for unknown keys and return them as a dict.
return dict(
dockerfile_str=text,
**{k: v for k, v in kwargs.items() if k in cls._known_keys},
)

@classmethod
def from_dockerfile(cls, path: str, **kwargs):
with open(path) as fobj:
return cls.from_dockerfile_str(fobj.read(), **kwargs)
18 changes: 18 additions & 0 deletions projects/fal/tests/test_apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import httpx
import pytest
from fal import apps
from fal.container import ContainerImage
from fal.rest_client import REST_CLIENT
from fal.workflows import Workflow
from fastapi import WebSocket
Expand Down Expand Up @@ -49,6 +50,23 @@ def addition_app(input: Input) -> Output:

nomad_addition_app = addition_app.on(_scheduler="nomad")

@fal.function(
kind="container",
image=ContainerImage.from_dockerfile_str("FROM python:3.11"),
keep_alive=60,
machine_type="S",
serve=True,
max_concurrency=1,
)
def container_addition_app(input: Input) -> Output:
print("starting...")
for _ in range(input.wait_time):
print("sleeping...")
time.sleep(1)

return Output(result=input.lhs + input.rhs)
efiop marked this conversation as resolved.
Show resolved Hide resolved


@fal.function(
keep_alive=300,
requirements=["fastapi", "uvicorn", "pydantic==1.10.12"],
Expand Down
35 changes: 35 additions & 0 deletions projects/fal/tests/test_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import fal
import pytest
from fal.api import FalServerlessError, Options
from fal.container import ContainerImage
from fal.toolkit.file.file import File
from pydantic import __version__ as pydantic_version

Expand Down Expand Up @@ -55,6 +56,40 @@ def mult(a, b):
assert mult(5, 2) == 10


def test_regular_function_in_a_container(isolated_client):
@isolated_client("container")
def regular_function():
return 42

assert regular_function() == 42

@isolated_client("container")
def mult(a, b):
return a * b

assert mult(5, 2) == 10


def test_regular_function_in_a_container_with_custom_image(isolated_client):
@isolated_client(
"container",
image=ContainerImage.from_dockerfile_str("FROM python:3.9"),
)
def regular_function():
return 42

assert regular_function() == 42

@isolated_client(
"container",
image=ContainerImage.from_dockerfile_str("FROM python:3.9"),
)
def mult(a, b):
return a * b

assert mult(5, 2) == 10


def test_function_pipelining(isolated_client):
@isolated_client("virtualenv")
def regular_function():
Expand Down
Loading