Skip to content

Commit

Permalink
pr comments 2
Browse files Browse the repository at this point in the history
  • Loading branch information
mederka committed Oct 15, 2024
1 parent 45a98bb commit cb4fa33
Showing 1 changed file with 43 additions and 79 deletions.
122 changes: 43 additions & 79 deletions projects/fal/src/fal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import time
import typing
from contextlib import AsyncExitStack, asynccontextmanager, contextmanager
from dataclasses import dataclass, field
from dataclasses import dataclass
from functools import cache
from typing import Any, Callable, ClassVar, Literal, TypeVar

import grpc.aio as async_grpc
Expand All @@ -31,7 +32,6 @@

EndpointT = TypeVar("EndpointT", bound=Callable[..., Any])
logger = get_logger(__name__)
GRPC_PORT = os.environ.get("NOMAD_ALLOC_PORT_grpc")


async def _call_any_fn(fn, *args, **kwargs):
Expand All @@ -41,67 +41,50 @@ async def _call_any_fn(fn, *args, **kwargs):
return fn(*args, **kwargs)


@dataclass
class IsolateChannel:
address: str
_stack: AsyncExitStack | None = field(repr=False, default=None)

async def __aenter__(self) -> async_grpc.Channel:
self._stack = AsyncExitStack()
channel = await self._stack.enter_async_context(
async_grpc.insecure_channel(
self.address,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
("grpc.min_reconnect_backoff_ms", 0),
("grpc.max_reconnect_backoff_ms", 100),
("grpc.dns_min_time_between_resolutions_ms", 100),
],
)
)
channel_status = channel.channel_ready()
try:
await asyncio.wait_for(channel_status, timeout=60)
except asyncio.TimeoutError:
raise Exception("Timed out trying to connect to local isolate")
else:
return channel
@cache
def get_grpc_port() -> str | None:
return os.environ.get("NOMAD_ALLOC_PORT_grpc")

async def __aexit__(self, *args):
if self._stack is None:
return None

try:
await self._stack.aclose()
finally:
self._client = None
self._stack = None
async def open_isolate_channel(address: str) -> async_grpc.Channel:
_stack = AsyncExitStack()
channel = await _stack.enter_async_context(
async_grpc.insecure_channel(
address,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
("grpc.min_reconnect_backoff_ms", 0),
("grpc.max_reconnect_backoff_ms", 100),
("grpc.dns_min_time_between_resolutions_ms", 100),
],
)
)

channel_status = channel.channel_ready()
try:
await asyncio.wait_for(channel_status, timeout=60)
except asyncio.TimeoutError:
await _stack.aclose()
raise Exception("Timed out trying to connect to local isolate")

@asynccontextmanager
async def assure_isolate_channel(channel: typing.Optional[IsolateChannel] = None):
if channel is None:
async with IsolateChannel(f"localhost:{GRPC_PORT}") as new_channel:
yield new_channel
else:
yield channel
return channel


async def _set_logger_labels(
logger_labels: dict[str, str], channel: typing.Optional[IsolateChannel] = None
logger_labels: dict[str, str], channel: async_grpc.Channel
):
try:
async with assure_isolate_channel(channel) as assured_channel:
isolate = definitions.IsolateStub(assured_channel)
isolate_request = definitions.SetMetadataRequest(
task_id="RUN",
metadata=definitions.TaskMetadata(logger_labels=logger_labels),
)
res = isolate.SetMetadata(isolate_request)
code = await res.code()
assert str(code) == "StatusCode.OK"
except Exception:
isolate = definitions.IsolateStub(channel)
isolate_request = definitions.SetMetadataRequest(
# TODO: when submit is shipped, get task_id from an env var
task_id="RUN",
metadata=definitions.TaskMetadata(logger_labels=logger_labels),
)
res = isolate.SetMetadata(isolate_request)
code = await res.code()
assert str(code) == "StatusCode.OK"
except BaseException:
logger.exception("Failed to set logger labels")


Expand Down Expand Up @@ -279,7 +262,7 @@ class App(fal.api.BaseServable):
app_auth: ClassVar[Literal["private", "public", "shared"]] = "private"
request_timeout: ClassVar[int | None] = None

isolate_channel: IsolateChannel | None = None
isolate_channel: async_grpc.Channel | None = None

def __init_subclass__(cls, **kwargs):
app_name = kwargs.pop("name", None) or _to_fal_app_name(cls.__name__)
Expand Down Expand Up @@ -372,44 +355,25 @@ async def set_global_object_preference(request, call_next):
self.__class__.__name__,
)

request_id = request.headers.get(REQUEST_ID_KEY)
if request_id is not None:
await _set_logger_labels({"fal_request_id": request_id})

response = await call_next(request)

await _set_logger_labels({})

return response
return await call_next(request)

@app.middleware("http")
async def set_request_id(request, call_next):
if self.isolate_channel is None:
async with assure_isolate_channel() as channel:
self.isolate_channel = channel
self.isolate_channel = await open_isolate_channel(
f"localhost:{get_grpc_port()}"
)

request_id = request.headers.get(REQUEST_ID_KEY)
if request_id is not None:
await _set_logger_labels(
{"fal_request_id": request_id}, channel=self.isolate_channel
)

response = None
exception = None
try:
response = await call_next(request)
except BaseException as e:
exception = e
return await call_next(request)
finally:
await _set_logger_labels({}, channel=self.isolate_channel)

if response is not None:
return response
elif exception is not None:
raise exception

raise Exception("Both response and exception are None.")

@app.exception_handler(RequestCancelledException)
async def value_error_exception_handler(
request, exc: RequestCancelledException
Expand Down

0 comments on commit cb4fa33

Please sign in to comment.