From c04af39a97169485e5bfae43446bcc890b4693ed Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Fri, 18 Oct 2024 16:05:44 +0300 Subject: [PATCH] feat(fal): don't scale by default on deploy --- projects/fal/src/fal/api.py | 10 ++++++---- projects/fal/src/fal/cli/deploy.py | 9 +++++++++ projects/fal/src/fal/sdk.py | 2 ++ 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/projects/fal/src/fal/api.py b/projects/fal/src/fal/api.py index 1b39d3b1..8fac456e 100644 --- a/projects/fal/src/fal/api.py +++ b/projects/fal/src/fal/api.py @@ -171,6 +171,7 @@ def register( application_name: str | None = None, application_auth_mode: Literal["public", "shared", "private"] | None = None, metadata: dict[str, Any] | None = None, + scale: bool = True, ) -> str | None: """Register the given function on the host for API call execution.""" raise NotImplementedError @@ -430,6 +431,7 @@ def register( application_auth_mode: Literal["public", "shared", "private"] | None = None, metadata: dict[str, Any] | None = None, deployment_strategy: Literal["recreate", "rolling"] = "recreate", + scale: bool = True, ) -> str | None: environment_options = options.environment.copy() environment_options.setdefault("python_version", active_python()) @@ -439,15 +441,14 @@ def register( "machine_type", FAL_SERVERLESS_DEFAULT_MACHINE_TYPE ) keep_alive = options.host.get("keep_alive", FAL_SERVERLESS_DEFAULT_KEEP_ALIVE) - max_concurrency = options.host.get("max_concurrency") - min_concurrency = options.host.get("min_concurrency") - max_multiplexing = options.host.get("max_multiplexing") base_image = options.host.get("_base_image", None) scheduler = options.host.get("_scheduler", None) scheduler_options = options.host.get("_scheduler_options", None) + max_concurrency = options.host.get("max_concurrency") + min_concurrency = options.host.get("min_concurrency") + max_multiplexing = options.host.get("max_multiplexing") exposed_port = options.get_exposed_port() request_timeout = options.host.get("request_timeout") - machine_requirements = MachineRequirements( machine_types=machine_type, # type: ignore num_gpus=options.host.get("num_gpus"), @@ -486,6 +487,7 @@ def register( machine_requirements=machine_requirements, metadata=metadata, deployment_strategy=deployment_strategy, + scale=scale, ): for log in partial_result.logs: self._log_printer.print(log) diff --git a/projects/fal/src/fal/cli/deploy.py b/projects/fal/src/fal/cli/deploy.py index 4afdb25b..40febf42 100644 --- a/projects/fal/src/fal/cli/deploy.py +++ b/projects/fal/src/fal/cli/deploy.py @@ -106,6 +106,7 @@ def _deploy_from_reference( application_auth_mode=app_auth, metadata=isolated_function.options.host.get("metadata", {}), deployment_strategy=deployment_strategy, + scale=not args.no_scale, ) if app_id: @@ -219,5 +220,13 @@ def valid_auth_option(option): help="Deployment strategy.", default="recreate", ) + parser.add_argument( + "--no-scale", + action="store_true", + help=( + "Use min_concurrency/max_concurrency/max_multiplexing from previous " + "deployment of application with this name." + ), + ) parser.set_defaults(func=_deploy) diff --git a/projects/fal/src/fal/sdk.py b/projects/fal/src/fal/sdk.py index 7f5768b4..3a90b67d 100644 --- a/projects/fal/src/fal/sdk.py +++ b/projects/fal/src/fal/sdk.py @@ -497,6 +497,7 @@ def register( machine_requirements: MachineRequirements | None = None, metadata: dict[str, Any] | None = None, deployment_strategy: Literal["recreate", "rolling"] = "recreate", + scale: bool = True, ) -> Iterator[isolate_proto.RegisterApplicationResult]: wrapped_function = to_serialized_object(function, serialization_method) if machine_requirements: @@ -544,6 +545,7 @@ def register( auth_mode=auth_mode, metadata=struct_metadata, deployment_strategy=deployment_strategy_proto, + scale=scale, ) for partial_result in self.stub.RegisterApplication(request): yield from_grpc(partial_result)