Skip to content

Commit

Permalink
fix(grpc): fix segfault with grpc.aio streaming responses [backport 1…
Browse files Browse the repository at this point in the history
….20] (#9276)

Backport
5897cab
from #9233 to 1.20.

This PR fixes a few issues with the grpc aio integration. Most notably,
the integration was causing segfaults when wrapping async stream
responses, most likely since these spans were never being finished. This
issue was uncovered when customers upgraded their google-api-core
dependencies to 2.17.0; with this upgrade, the package changed many grpc
calls to use async streaming. In addition to fixing the segfault, this
PR also fixes the Pin object to be correctly placed on the grpcio
module.

Fixes #9139

## Checklist

- [x] Change(s) are motivated and described in the PR description
- [x] Testing strategy is described if automated tests are not included
in the PR
- [x] Risks are described (performance impact, potential for breakage,
maintainability)
- [x] Change is maintainable (easy to change, telemetry, documentation)
- [x] [Library release note
guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html)
are followed or label `changelog/no-changelog` is set
- [x] Documentation is included (in-code, generated user docs, [public
corp docs](https://github.com/DataDog/documentation/))
- [x] Backport labels are set (if
[applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting))
- [x] If this PR changes the public interface, I've notified
`@DataDog/apm-tees`.

## Reviewer Checklist

- [x] Title is accurate
- [x] All changes are related to the pull request's stated goal
- [x] Description motivates each change
- [x] Avoids breaking
[API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces)
changes
- [x] Testing strategy adequately addresses listed risks
- [x] Change is maintainable (easy to change, telemetry, documentation)
- [x] Release note makes sense to a user of the library
- [x] Author has acknowledged and discussed the performance implications
of this PR as reported in the benchmarks PR comment
- [x] Backport labels are set in a manner that is consistent with the
[release branch maintenance
policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)

---------

Co-authored-by: Emmett Butler <[email protected]>
  • Loading branch information
wconti27 and emmettbutler authored May 17, 2024
1 parent 7a0e539 commit a979c73
Show file tree
Hide file tree
Showing 9 changed files with 425 additions and 25 deletions.
64 changes: 54 additions & 10 deletions ddtrace/contrib/grpc/aio_client_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@
from ...ext import SpanKind
from ...ext import SpanTypes
from ...internal.compat import to_unicode
from ...internal.logger import get_logger
from ...propagation.http import HTTPPropagator
from ..grpc import constants
from ..grpc import utils


log = get_logger(__name__)


def create_aio_client_interceptors(pin, host, port):
# type: (Pin, str, int) -> Tuple[aio.ClientInterceptor, ...]
return (
Expand All @@ -42,7 +46,17 @@ def create_aio_client_interceptors(pin, host, port):
)


def _done_callback(span, code, details):
def _handle_add_callback(call, callback):
try:
call.add_done_callback(callback)
except NotImplementedError:
# add_done_callback is not implemented in UnaryUnaryCallResponse
# https://github.com/grpc/grpc/blob/c54c69dcdd483eba78ed8dbc98c60a8c2d069758/src/python/grpcio/grpc/aio/_interceptor.py#L1058
# If callback is not called, we need to finish the span here
callback(call)


def _done_callback_unary(span, code, details):
# type: (Span, grpc.StatusCode, str) -> Callable[[aio.Call], None]
def func(call):
# type: (aio.Call) -> None
Expand All @@ -51,15 +65,45 @@ def func(call):

# Handle server-side error in unary response RPCs
if code != grpc.StatusCode.OK:
_handle_error(span, call, code, details)
_handle_error(span, code, details)
finally:
span.finish()

return func


def _done_callback_stream(span):
# type: (Span) -> Callable[[aio.Call], None]
def func(call):
# type: (aio.Call) -> None
try:
if call.done():
# check to ensure code and details are not already set, in which case this span
# is an error span and already has all error tags from `_handle_cancelled_error`
code_tag = span.get_tag(constants.GRPC_STATUS_CODE_KEY)
details_tag = span.get_tag(ERROR_MSG)
if not code_tag or not details_tag:
# we need to call __repr__ as we cannot call code() or details() since they are both async
code, details = utils._parse_rpc_repr_string(call.__repr__(), grpc)

span.set_tag_str(constants.GRPC_STATUS_CODE_KEY, to_unicode(code))

# Handle server-side error in unary response RPCs
if code != grpc.StatusCode.OK:
_handle_error(span, code, details)
else:
log.warning("Grpc call has not completed, unable to set status code and details on span.")
except ValueError:
# ValueError is thrown from _parse_rpc_repr_string
log.warning("Unable to parse async grpc string for status code and details.")
finally:
span.finish()

return func


def _handle_error(span, call, code, details):
# type: (Span, aio.Call, grpc.StatusCode, str) -> None
def _handle_error(span, code, details):
# type: (Span, grpc.StatusCode, str) -> None
span.error = 1
span.set_tag_str(ERROR_MSG, details)
span.set_tag_str(ERROR_TYPE, to_unicode(code))
Expand Down Expand Up @@ -152,13 +196,13 @@ async def _wrap_stream_response(
):
# type: (...) -> ResponseIterableType
try:
_handle_add_callback(call, _done_callback_stream(span))
async for response in call:
yield response
code = await call.code()
details = await call.details()
# NOTE: The callback is registered after the iteration is done,
# otherwise `call.code()` and `call.details()` block indefinitely.
call.add_done_callback(_done_callback(span, code, details))
except StopAsyncIteration:
# Callback will handle span finishing
_handle_cancelled_error(call, span)
raise
except aio.AioRpcError as rpc_error:
# NOTE: We can also handle the error in done callbacks,
# but reuse this error handling function used in unary response RPCs.
Expand All @@ -184,7 +228,7 @@ async def _wrap_unary_response(
# NOTE: As both `code` and `details` are available after the RPC is done (= we get `call` object),
# and we can't call awaitable functions inside the non-async callback,
# there is no other way but to register the callback here.
call.add_done_callback(_done_callback(span, code, details))
_handle_add_callback(call, _done_callback_unary(span, code, details))
return call
except aio.AioRpcError as rpc_error:
# NOTE: `AioRpcError` is raised in `await continuation(...)`
Expand Down
7 changes: 3 additions & 4 deletions ddtrace/contrib/grpc/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def _unpatch_aio_server():
def _client_channel_interceptor(wrapped, instance, args, kwargs):
channel = wrapped(*args, **kwargs)

pin = Pin.get_from(channel)
pin = Pin.get_from(constants.GRPC_PIN_MODULE_CLIENT)
if not pin or not pin.enabled():
return channel

Expand All @@ -207,11 +207,10 @@ def _client_channel_interceptor(wrapped, instance, args, kwargs):


def _aio_client_channel_interceptor(wrapped, instance, args, kwargs):
channel = wrapped(*args, **kwargs)
pin = Pin.get_from(GRPC_AIO_PIN_MODULE_CLIENT)

pin = Pin.get_from(channel)
if not pin or not pin.enabled():
return channel
return wrapped(*args, **kwargs)

(host, port) = utils._parse_target_from_args(args, kwargs)

Expand Down
28 changes: 28 additions & 0 deletions ddtrace/contrib/grpc/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import re

from ddtrace.internal.compat import parse

Expand Down Expand Up @@ -74,3 +75,30 @@ def _parse_target_from_args(args, kwargs):
return hostname, port
except ValueError:
log.warning("Malformed target '%s'.", target)


def _parse_rpc_repr_string(rpc_string, module):
# Define the regular expression patterns to extract status and details
status_pattern = r"status\s*=\s*StatusCode\.(\w+)"
details_pattern = r'details\s*=\s*"([^"]*)"'

# Search for the status and details in the input string
status_match = re.search(status_pattern, rpc_string)
details_match = re.search(details_pattern, rpc_string)

if not status_match or not details_match:
raise ValueError("Unable to parse grpc status or details repr string")

# Extract the status and details from the matches
status_str = status_match.group(1)
details = details_match.group(1)

# Convert the status string to a grpc.StatusCode object
try:
code = module.StatusCode[status_str]
except KeyError:
code = None
raise ValueError("Invalid grpc status code: " + status_str)

# Return the status code and details
return code, details
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ MySQL
OpenTracing
Runtimes
SpanContext
aio
aiobotocore
aiohttp
aiomysql
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
fix(grpc): This change fixes a bug in the grpc.aio support specific to streaming responses.
36 changes: 36 additions & 0 deletions tests/contrib/grpc_aio/hellostreamingworld_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions tests/contrib/grpc_aio/hellostreamingworld_pb2.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# isort: off
from typing import ClassVar as _ClassVar
from typing import Optional as _Optional

from ddtrace.internal.compat import PYTHON_VERSION_INFO

if PYTHON_VERSION_INFO > (3, 7):
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message

DESCRIPTOR: _descriptor.FileDescriptor
class HelloReply(_message.Message):
__slots__ = ["message"]
MESSAGE_FIELD_NUMBER: _ClassVar[int]
message: str
def __init__(self, message: _Optional[str] = ...) -> None: ...
class HelloRequest(_message.Message):
__slots__ = ["name", "num_greetings"]
NAME_FIELD_NUMBER: _ClassVar[int]
NUM_GREETINGS_FIELD_NUMBER: _ClassVar[int]
name: str
num_greetings: str
def __init__(self, name: _Optional[str] = ..., num_greetings: _Optional[str] = ...) -> None: ...
77 changes: 77 additions & 0 deletions tests/contrib/grpc_aio/hellostreamingworld_pb2_grpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

from ddtrace.internal.compat import PYTHON_VERSION_INFO


if PYTHON_VERSION_INFO > (3, 7):
from tests.contrib.grpc_aio import hellostreamingworld_pb2 as hellostreamingworld__pb2

class MultiGreeterStub(object):
"""The greeting service definition."""

def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.sayHello = channel.unary_stream(
"/hellostreamingworld.MultiGreeter/sayHello",
request_serializer=hellostreamingworld__pb2.HelloRequest.SerializeToString,
response_deserializer=hellostreamingworld__pb2.HelloReply.FromString,
)

class MultiGreeterServicer(object):
"""The greeting service definition."""

def sayHello(self, request, context):
"""Sends multiple greetings"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")

def add_MultiGreeterServicer_to_server(servicer, server):
rpc_method_handlers = {
"sayHello": grpc.unary_stream_rpc_method_handler(
servicer.sayHello,
request_deserializer=hellostreamingworld__pb2.HelloRequest.FromString,
response_serializer=hellostreamingworld__pb2.HelloReply.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler("hellostreamingworld.MultiGreeter", rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))

# This class is part of an EXPERIMENTAL API.
class MultiGreeter(object):
"""The greeting service definition."""

@staticmethod
def sayHello(
request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.unary_stream(
request,
target,
"/hellostreamingworld.MultiGreeter/sayHello",
hellostreamingworld__pb2.HelloRequest.SerializeToString,
hellostreamingworld__pb2.HelloReply.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
Loading

0 comments on commit a979c73

Please sign in to comment.