Skip to content

Commit

Permalink
feat: send request id data to isolate logger in apps
Browse files Browse the repository at this point in the history
  • Loading branch information
mederka committed Oct 15, 2024
1 parent 4ba761b commit 1428edd
Showing 1 changed file with 71 additions and 3 deletions.
74 changes: 71 additions & 3 deletions projects/fal/src/fal/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import inspect
import json
import os
Expand All @@ -8,12 +9,14 @@
import threading
import time
import typing
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, ClassVar, Literal, TypeVar

import grpc.aio as async_grpc
import httpx
from fastapi import FastAPI
from isolate.server import definitions

import fal.api
from fal._serialization import include_modules_from
Expand All @@ -24,6 +27,7 @@
from fal.toolkit.file.providers.fal import GLOBAL_LIFECYCLE_PREFERENCE

REALTIME_APP_REQUIREMENTS = ["websockets", "msgpack"]
REQUEST_ID_KEY = "x-fal-request-id"

EndpointT = TypeVar("EndpointT", bound=Callable[..., Any])
logger = get_logger(__name__)
Expand All @@ -36,6 +40,61 @@ async def _call_any_fn(fn, *args, **kwargs):
return fn(*args, **kwargs)


@dataclass
class IsolateChannel:
address: str
_stack: AsyncExitStack | None = field(repr=False, default=None)

async def __aenter__(self) -> async_grpc.Channel:
self._stack = AsyncExitStack()
channel = await self._stack.enter_async_context(
async_grpc.insecure_channel(
self.address,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
("grpc.min_reconnect_backoff_ms", 0),
("grpc.max_reconnect_backoff_ms", 100),
("grpc.dns_min_time_between_resolutions_ms", 100),
],
)
)
channel_status = channel.channel_ready()
try:
await asyncio.wait_for(channel_status, timeout=60)
except asyncio.TimeoutError:
raise Exception("Timed out trying to connect to local isolate")
else:
return channel

async def __aexit__(self, *args):
if self._stack is None:
return None

try:
await self._stack.aclose()
finally:
self._client = None
self._stack = None


async def _set_logger_labels(logger_labels: dict[str, str]):
grpc_port = os.environ.get("NOMAD_ALLOC_PORT_grpc")

try:
async with IsolateChannel(f"localhost:{grpc_port}") as channel:
isolate = definitions.IsolateStub(channel)
isolate_request = definitions.SetMetadataRequest(
task_id="RUN",
metadata=definitions.TaskMetadata(logger_labels=logger_labels),
)
res = isolate.SetMetadata(isolate_request)
code = await res.code()
assert str(code) == "StatusCode.OK"
except Exception:
logger.exception("Failed to set logger labels")


def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
include_modules_from(cls)

Expand Down Expand Up @@ -300,7 +359,16 @@ async def set_global_object_preference(request, call_next):
"Failed set a global lifecycle preference %s",
self.__class__.__name__,
)
return await call_next(request)

request_id = request.headers.get(REQUEST_ID_KEY)
if request_id is not None:
await _set_logger_labels({"request_id": request_id})

response = await call_next(request)

await _set_logger_labels({})

return response

@app.exception_handler(RequestCancelledException)
async def value_error_exception_handler(
Expand Down

0 comments on commit 1428edd

Please sign in to comment.