Skip to content

Commit

Permalink
feat: instrument served apps with Prometheus metrics
Browse files Browse the repository at this point in the history
This commit adds basic Prometheus instrumentation of the HTTP server of
all applications. This lets us conveniently monitor how models are
performing.

Signed-off-by: Lucas Servén Marín <[email protected]>
  • Loading branch information
squat committed Apr 17, 2024
1 parent 9237e96 commit c3e03cc
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 7 deletions.
2 changes: 2 additions & 0 deletions projects/fal/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ dependencies = [
"pydantic!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,<3",
# serve=True dependencies
"fastapi>=0.99.1,<1",
"starlette-exporter>=0.21.0",
# rest-api-client dependencies
"httpx>=0.15.4",
"attrs>=21.3.0",
Expand All @@ -52,6 +53,7 @@ dependencies = [
"websockets>=12.0,<13",
"pillow>=10.2.0,<11",
"pyjwt[crypto]>=2.8.0,<3",
"uvicorn>=0.29.0,<1",
]

[project.optional-dependencies]
Expand Down
79 changes: 72 additions & 7 deletions projects/fal/src/fal/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import asyncio
import inspect
import socket
import sys
import threading
from collections import defaultdict
Expand All @@ -27,8 +29,10 @@
import grpc
import isolate
import tblib
import uvicorn
import yaml
from fastapi import FastAPI, __version__ as fastapi_version
from fastapi import FastAPI
from fastapi import __version__ as fastapi_version
from isolate.backends.common import active_python
from isolate.backends.settings import DEFAULT_SETTINGS
from isolate.connections import PythonIPC
Expand All @@ -37,7 +41,7 @@
from typing_extensions import Concatenate, ParamSpec

import fal.flags as flags
from fal._serialization import patch_pickle, include_modules_from
from fal._serialization import include_modules_from, patch_pickle
from fal.logging.isolate import IsolateLogPrinter
from fal.sdk import (
FAL_SERVERLESS_DEFAULT_KEEP_ALIVE,
Expand All @@ -58,7 +62,7 @@
BasicConfig = Dict[str, Any]
_UNSET = object()

SERVE_REQUIREMENTS = [f"fastapi=={fastapi_version}", "uvicorn"]
SERVE_REQUIREMENTS = [f"fastapi=={fastapi_version}", "uvicorn", "starlette_exporter"]


@dataclass
Expand Down Expand Up @@ -212,7 +216,10 @@ class LocalHost(Host):
# packages for isolate agent to run.
_AGENT_ENVIRONMENT = isolate.prepare_environment(
"virtualenv",
requirements=[f"cloudpickle=={cloudpickle.__version__}", f"tblib=={tblib.__version__}"],
requirements=[
f"cloudpickle=={cloudpickle.__version__}",
f"tblib=={tblib.__version__}",
],
)
_log_printer = IsolateLogPrinter(debug=flags.DEBUG)

Expand All @@ -223,7 +230,11 @@ def run(
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> ReturnT:
settings = replace(DEFAULT_SETTINGS, serialization_method="cloudpickle", log_hook=self._log_printer.print)
settings = replace(
DEFAULT_SETTINGS,
serialization_method="cloudpickle",
log_hook=self._log_printer.print,
)
environment = isolate.prepare_environment(
**options.environment,
context=settings,
Expand Down Expand Up @@ -422,6 +433,8 @@ def register(
if partial_result.result:
return partial_result.result.application_id

return None

@_handle_grpc_error()
def run(
self,
Expand Down Expand Up @@ -818,6 +831,7 @@ def _build_app(self) -> FastAPI:
from fastapi import HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette_exporter import PrometheusMiddleware

_app = FalFastAPI(lifespan=self.lifespan)

Expand All @@ -828,6 +842,13 @@ def _build_app(self) -> FastAPI:
allow_methods=("*"),
allow_origins=("*"),
)
_app.add_middleware(
PrometheusMiddleware,
prefix="http",
group_paths=True,
filter_unhandled_paths=True,
app_name="fal",
)

self._add_extra_middlewares(_app)

Expand Down Expand Up @@ -873,10 +894,35 @@ def openapi(self) -> dict[str, Any]:
return self._build_app().openapi()

def serve(self) -> None:
import uvicorn
import asyncio

from starlette_exporter import handle_metrics
from uvicorn import Config

app = self._build_app()
uvicorn.run(app, host="0.0.0.0", port=8080)
server = Server(config=Config(app, host="0.0.0.0", port=8080))
metrics_app = FastAPI()
metrics_app.add_route("/metrics", handle_metrics)
metrics_server = Server(config=Config(metrics_app, host="0.0.0.0", port=9090))

async def _serve() -> None:
event = asyncio.Event()
# TODO(squat): handle shutdowns gracefully.
# You cannot add signal handlers to any loop if you're not
# on the main thread.
# How can we detect that we are being shut down and stop the
# uvicorn servers gracefully?
# loop = asyncio.get_running_loop()
# loop.add_signal_handler(signal.SIGINT, event.set)
# loop.add_signal_handler(signal.SIGTERM, event.set)
await asyncio.gather(
server.serve_until_event(event),
metrics_server.serve_until_event(event),
)

with suppress(asyncio.CancelledError):
asyncio.set_event_loop(asyncio.new_event_loop())
asyncio.run(_serve())


class ServeWrapper(BaseServable):
Expand Down Expand Up @@ -1035,3 +1081,22 @@ def on(
self, host: Host | None = None, *, serve: Literal[False], **config: Any
) -> IsolatedFunction[ArgsT, ReturnT]:
...


class Server(uvicorn.Server):
"""Server is a uvicorn.Server that actually plays nicely with signals.
By default, uvicorn's Server class overwrites the signal handler for SIGINT, swallowing the signal and preventing other tasks from cancelling.
This class allows the task to be gracefully cancelled using asyncio's built-in task cancellation or with an event, like aiohttp.
"""

def install_signal_handlers(self) -> None:
pass

async def serve_until_event(
self, finish_event: asyncio.Event, sockets: list[socket.socket] | None = None
) -> None:
serve = asyncio.create_task(super().serve(sockets))
await finish_event.wait()
self.should_exit = True
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(serve, timeout=10)

0 comments on commit c3e03cc

Please sign in to comment.