Skip to content

Commit

Permalink
fix: prepare tests for nomad scheduler (#189)
Browse files Browse the repository at this point in the history
* flaky

* test

* kube
  • Loading branch information
chamini2 authored Apr 25, 2024
1 parent aaac9ab commit 2bdddd4
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 30 deletions.
1 change: 1 addition & 0 deletions projects/fal/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ test = [
"pytest",
"pytest-xdist",
"pytest-asyncio",
"flaky",
]
dev = [
"fal[test]",
Expand Down
23 changes: 11 additions & 12 deletions projects/fal/tests/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@

import pytest
from pydantic import BaseModel, Field, __version__ as pydantic_version
from typing import Callable

import fal
from fal import FalServerlessHost, FalServerlessKeyCredentials, local, sync_dir
from fal.api import FalServerlessError
from fal.api import FalServerlessError, IsolatedFunction
from fal.toolkit import File, clone_repository, download_file, download_model_weights
from fal.toolkit.file.file import CompressedFile
from fal.toolkit.utils.download_utils import _get_git_revision_hash, _hash_url


@pytest.mark.flaky(max_runs=3)
def test_isolated(isolated_client):
def test_isolated(isolated_client: Callable[..., Callable[..., IsolatedFunction]]):
@isolated_client("virtualenv", requirements=["pyjokes==0.5.0"])
def get_pyjokes_version():
import pyjokes
Expand All @@ -30,24 +31,22 @@ def get_pyjokes_version():
def get_hostname() -> str:
import socket

return socket.gethostname()
hostname = socket.gethostname()
return hostname

import socket
local_hostname = socket.gethostname()

first = get_hostname()
assert first.startswith("worker")
assert local_hostname != first

get_hostname_local = get_hostname.on(local)
second = get_hostname_local()
assert not second.startswith("worker-")
assert local_hostname == second

get_hostname_m = get_hostname.on(machine_type="L")
third = get_hostname_m()
assert third.startswith("worker")
assert third != first

# The machine_type should be dropped when using local
get_hostname_m_local = get_hostname_m.on(local)
fourth = get_hostname_m_local()
assert not fourth.startswith("worker-")
assert local_hostname != third


def test_isolate_setup_funcs(isolated_client):
Expand Down
31 changes: 16 additions & 15 deletions projects/fal/tests/test_apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,8 @@ def addition_app(input: Input) -> Output:
return Output(result=input.lhs + input.rhs)


@fal.function(
keep_alive=60,
machine_type="S",
serve=True,
max_concurrency=1,
requirements=[f"pydantic=={pydantic_version}"],
_scheduler="nomad",
)
def nomad_addition_app(input: Input) -> Output:
print("starting...")
for _ in range(input.wait_time):
print("sleeping...")
time.sleep(1)

return Output(result=input.lhs + input.rhs)
nomad_addition_app = addition_app.on(_scheduler="nomad")
kubernetes_addition_app = addition_app.on(_scheduler="kubernetes")


@fal.function(
Expand Down Expand Up @@ -217,6 +204,20 @@ def test_nomad_app():
yield f"{user_id}/{app_revision}"


@pytest.fixture(scope="module")
def test_kubernetes_app():
# Create a temporary app, register it, and return the ID of it.

from fal.cli import _get_user_id

app_revision = kubernetes_addition_app.host.register(
func=nomad_addition_app.func,
options=kubernetes_addition_app.options,
)
user_id = _get_user_id()
yield f"{user_id}/{app_revision}"


@pytest.fixture(scope="module")
def test_fastapi_app():
# Create a temporary app, register it, and return the ID of it.
Expand Down
34 changes: 31 additions & 3 deletions projects/fal/tests/test_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,30 @@ def regular_function(n):

def test_conda_environment(isolated_client):
@isolated_client(
"conda", packages=["pyjokes=0.6.0"], machine_type="L", resolver="conda"
"conda",
packages=["pyjokes=0.6.0"],
machine_type="L",
resolver="conda",
# conda is only supported on Kubernetes
_scheduler="kubernetes",
)
def regular_function():
import pyjokes

return pyjokes.__version__

assert regular_function() == "0.6.0"


@pytest.mark.xfail(reason="Nomad does not support conda")
def test_conda_environment_on_nomad(isolated_client):
@isolated_client(
"conda",
packages=["pyjokes=0.6.0"],
machine_type="L",
resolver="conda",
# conda is only supported on Kubernetes
_scheduler="nomad",
)
def regular_function():
import pyjokes
Expand Down Expand Up @@ -386,8 +409,13 @@ def factorial(n: int) -> int:
time.sleep(1) # slow CPU
return math.factorial(n)

# HACK: make this machine is not shared with others by using a unique requirements
@isolated_client("virtualenv", keep_alive=30, requirements=["pyjokes "])
# make sure this machine is not shared with others by using a unique requirements
@isolated_client(
"virtualenv",
keep_alive=30,
requirements=["pyjokes "],
_scheduler="kubernetes",
)
def regular_function(n):
if get_pipe() == "pipe":
return factorial(n)
Expand Down

0 comments on commit 2bdddd4

Please sign in to comment.