Skip to content

Commit

Permalink
Expose config option for avoiding context switching when handling req…
Browse files Browse the repository at this point in the history
…uests (#2944)
  • Loading branch information
kyri-petrou authored Aug 14, 2024
1 parent 1c939b3 commit 5294e91
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 52 deletions.
7 changes: 5 additions & 2 deletions zio-http/jvm/src/main/scala/zio/http/netty/NettyConfig.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ final case class NettyConfig(
def channelType(channelType: ChannelType): NettyConfig = self.copy(channelType = channelType)

/**
* Configure the server to use the leak detection level provided (@see <a
* href="https://netty.io/4.1/api/io/netty/util/ResourceLeakDetector.Level.html">ResourceLeakDetector.Level</a>).
* Configure the server to use the leak detection level provided.
*
* @see
* <a
* href="https://netty.io/4.1/api/io/netty/util/ResourceLeakDetector.Level.html">ResourceLeakDetector.Level</a>
*/
def leakDetection(level: LeakDetectionLevel): NettyConfig = self.copy(leakDetectionLevel = level)

Expand Down
72 changes: 33 additions & 39 deletions zio-http/jvm/src/main/scala/zio/http/netty/NettyRuntime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,55 +23,49 @@ import io.netty.channel._
import io.netty.util.concurrent.{Future, GenericFutureListener}

private[zio] final class NettyRuntime(zioRuntime: Runtime[Any]) {
private[this] val rtm = zioRuntime.unsafe

private val rtm = zioRuntime.unsafe

def run(ctx: ChannelHandlerContext, ensured: () => Unit, interruptOnClose: Boolean = true)(
def run(
ctx: ChannelHandlerContext,
ensured: () => Unit,
preferOnCurrentThread: Boolean,
)(
program: ZIO[Any, Throwable, Any],
)(implicit unsafe: Unsafe, trace: Trace): Unit = {

def onFailure(cause: Cause[Throwable], ctx: ChannelHandlerContext): Unit = {
cause.failureOption.orElse(cause.dieOption) match {
case None => ()
case Some(error) =>
ctx.fireExceptionCaught(error)
def onExit(exit: Exit[Throwable, Any]): Unit = {
ensured()
exit match {
case Exit.Success(_) =>
case Exit.Failure(cause) =>
cause.failureOption.orElse(cause.dieOption) match {
case None => ()
case Some(error) => ctx.fireExceptionCaught(error)
}
if (ctx.channel().isOpen) ctx.close(): Unit
}
if (ctx.channel().isOpen) ctx.close(): Unit
}

def removeListener(close: GenericFutureListener[Future[_ >: Void]]): Unit = {
if (close ne null)
ctx.channel().closeFuture().removeListener(close): Unit
}

// See https://github.com/zio/zio-http/pull/2782 on why forking is preferable over runOrFork
val fiber = rtm.fork(program)

// Close the connection if the program fails
// When connection closes, interrupt the program
val close: GenericFutureListener[Future[_ >: Void]] =
if (interruptOnClose) {
val close0 = closeListener(fiber)
ctx.channel().closeFuture.addListener(close0)
close0
} else null

fiber.unsafe.addObserver {
case Exit.Success(_) =>
removeListener(close)
ensured()
case Exit.Failure(cause) =>
removeListener(close)
onFailure(cause, ctx)
ensured()
def removeListener(close: GenericFutureListener[Future[_ >: Void]]): Unit =
ctx.channel().closeFuture().removeListener(close): Unit

val forkOrExit = if (preferOnCurrentThread) rtm.runOrFork(program) else Left(rtm.fork(program))

forkOrExit match {
case Left(fiber) =>
// Close the connection if the program fails
// When connection closes, interrupt the program
val close = closeListener(fiber)
ctx.channel().closeFuture.addListener(close)
fiber.unsafe.addObserver { exit =>
removeListener(close)
onExit(exit)
}
case Right(exit) =>
onExit(exit)
}
}

def runUninterruptible(ctx: ChannelHandlerContext, ensured: () => Unit)(
program: ZIO[Any, Throwable, Any],
)(implicit unsafe: Unsafe, trace: Trace): Unit =
run(ctx, ensured, interruptOnClose = false)(program)

@throws[Throwable]("Any errors that occur during the execution of the ZIO effect")
def unsafeRunSync[A](program: ZIO[Any, Throwable, A])(implicit unsafe: Unsafe, trace: Trace): A =
rtm.run(program).getOrThrowFiberFailure()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ final class ClientInboundHandler(
case _: HttpRequest =>
ctx.write(jReq)
NettyBodyWriter.writeAndFlush(req.body, None, ctx).foreach { effect =>
rtm.run(ctx, NettyRuntime.noopEnsuring)(effect)(Unsafe.unsafe, trace)
rtm.run(ctx, NettyRuntime.noopEnsuring, preferOnCurrentThread = true)(effect)(Unsafe.unsafe, trace)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,7 @@ object NettyConnectionPool {

override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
cause match {
case _: ReadTimeoutException =>
nettyRuntime.run(ctx, () => {}) { ZIO.logDebug("ReadTimeoutException caught") }
case _: ReadTimeoutException => nettyRuntime.unsafeRunSync(ZIO.logDebug("ReadTimeoutException caught"))
case _ => super.exceptionCaught(ctx, cause)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ private[zio] final case class ServerInboundHandler(

val inFlightRequests: LongAdder = new LongAdder()
private val readClientCert = config.sslConfig.exists(_.includeClientCert)
private val avoidCtxSwitching = config.avoidContextSwitching

def refreshApp(): Unit = {
val pair = appRef.get()
Expand Down Expand Up @@ -110,11 +111,11 @@ private[zio] final case class ServerInboundHandler(
(msg ne null) && msg.contains("Connection reset")
} =>
case t =>
if (runtime ne null) {
runtime.run(ctx, () => {}) {
if ((runtime ne null) && config.logWarningOnFatalError) {
runtime.unsafeRunSync {
// We cannot return the generated response from here, but still calling the handler for its side effect
// for example logging.
ZIO.logWarningCause(s"Fatal exception in Netty", Cause.die(t)).when(config.logWarningOnFatalError)
ZIO.logWarningCause(s"Fatal exception in Netty", Cause.die(t))
}
}
cause match {
Expand Down Expand Up @@ -307,7 +308,7 @@ private[zio] final case class ServerInboundHandler(
exit: ZIO[Any, Response, Response],
req: Request,
)(ensured: () => Unit): Unit = {
runtime.run(ctx, ensured) {
runtime.run(ctx, ensured, preferOnCurrentThread = avoidCtxSwitching) {
exit.sandbox.catchAll { error =>
error.failureOrCause
.fold[UIO[Response]](
Expand Down Expand Up @@ -365,7 +366,6 @@ object ServerInboundHandler {
for {
appRef <- ZIO.service[AppRef]
config <- ZIO.service[Server.Config]

} yield ServerInboundHandler(appRef, config)
}
}
Expand Down
28 changes: 25 additions & 3 deletions zio-http/shared/src/main/scala/zio/http/Server.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,33 @@ object Server extends ServerPlatformSpecific {
gracefulShutdownTimeout: Duration,
webSocketConfig: WebSocketConfig,
idleTimeout: Option[Duration],
) {
self =>
avoidContextSwitching: Boolean,
) { self =>

/**
* Configure the server to use HttpServerExpectContinueHandler to send a 100
* HttpResponse if necessary.
*/
def acceptContinue(enable: Boolean): Config = self.copy(acceptContinue = enable)

/**
* Attempts to avoid context switching between thread pools by executing
* requests within the server's IO thread-pool (e.g., Netty's EventLoop)
* until the first async/blocking boundary.
*
* Enabling this setting can improve performance for short-lived CPU-bound
* tasks, but can also lead to degraded performance if the request handler
* performs CPU-heavy work prior to the first async boundary.
*
* '''WARNING:''' Do not use this mode if the ZIO executor is configured to
* use virtual threads!
*
* @see
* For more info on caveats of this mode, see <a
* href="https://github.com/zio/zio-http/pull/2782">related issue </a>
*/
def avoidContextSwitching(value: Boolean): Config = self.copy(avoidContextSwitching = value)

/**
* Configure the server to listen on the provided hostname and port.
*/
Expand Down Expand Up @@ -173,7 +191,8 @@ object Server extends ServerPlatformSpecific {
zio.Config.int("max-header-size").withDefault(Config.default.maxHeaderSize) ++
zio.Config.boolean("log-warning-on-fatal-error").withDefault(Config.default.logWarningOnFatalError) ++
zio.Config.duration("graceful-shutdown-timeout").withDefault(Config.default.gracefulShutdownTimeout) ++
zio.Config.duration("idle-timeout").optional.withDefault(Config.default.idleTimeout)
zio.Config.duration("idle-timeout").optional.withDefault(Config.default.idleTimeout) ++
zio.Config.boolean("avoid-context-switching").withDefault(Config.default.avoidContextSwitching)
}.map {
case (
sslConfig,
Expand All @@ -189,6 +208,7 @@ object Server extends ServerPlatformSpecific {
logWarningOnFatalError,
gracefulShutdownTimeout,
idleTimeout,
avoidCtxSwitch,
) =>
default.copy(
sslConfig = sslConfig,
Expand All @@ -203,6 +223,7 @@ object Server extends ServerPlatformSpecific {
logWarningOnFatalError = logWarningOnFatalError,
gracefulShutdownTimeout = gracefulShutdownTimeout,
idleTimeout = idleTimeout,
avoidContextSwitching = avoidCtxSwitch,
)
}

Expand All @@ -220,6 +241,7 @@ object Server extends ServerPlatformSpecific {
gracefulShutdownTimeout = 10.seconds,
webSocketConfig = WebSocketConfig.default,
idleTimeout = None,
avoidContextSwitching = false,
)

final case class ResponseCompressionConfig(
Expand Down

0 comments on commit 5294e91

Please sign in to comment.