Skip to content

Commit

Permalink
Prevent OOM when receiving large request streams
Browse files Browse the repository at this point in the history
  • Loading branch information
davidlar committed Sep 26, 2024
1 parent 502331c commit f1e6350
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 27 deletions.
39 changes: 22 additions & 17 deletions zio-http/jvm/src/main/scala/zio/http/netty/AsyncBodyReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,23 +95,28 @@ private[netty] abstract class AsyncBodyReader extends SimpleChannelInboundHandle
onLastMessage()
}

state match {
case State.Buffering =>
// `connect` method hasn't been called yet, add all incoming content to the buffer
buffer0.addAll(content)
case State.Direct(callback) if isLast && buffer0.knownSize == 0 =>
// Buffer is empty, we can just use the array directly
callback(Chunk.fromArray(content), isLast = true)
case State.Direct(callback: UnsafeAsync.Aggregating) =>
// We're aggregating the full response, only call the callback on the last message
buffer0.addAll(content)
if (isLast) callback(result(buffer0), isLast = true)
case State.Direct(callback) =>
// We're streaming, emit chunks as they come
callback(Chunk.fromArray(content), isLast)
}

if (!isLast) ctx.read(): Unit
val streaming =
state match {
case State.Buffering =>
// `connect` method hasn't been called yet, add all incoming content to the buffer
buffer0.addAll(content)
false
case State.Direct(callback) if isLast && buffer0.knownSize == 0 =>
// Buffer is empty, we can just use the array directly
callback(Chunk.fromArray(content), isLast = true)
false
case State.Direct(callback: UnsafeAsync.Aggregating) =>
// We're aggregating the full response, only call the callback on the last message
buffer0.addAll(content)
if (isLast) callback(result(buffer0), isLast = true)
false
case State.Direct(callback) =>
// We're streaming, emit chunks as they come
callback(Chunk.fromArray(content), isLast)
true
}

if (!isLast && !streaming) ctx.read(): Unit
}
}

Expand Down
24 changes: 16 additions & 8 deletions zio-http/jvm/src/main/scala/zio/http/netty/NettyBody.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@ object NettyBody extends BodyEncoding {
unsafeAsync: UnsafeAsync => Unit,
knownContentLength: Option[Long],
contentTypeHeader: Option[Header.ContentType] = None,
readMore: () => Unit = () => (),
): Body = {
AsyncBody(
unsafeAsync,
knownContentLength,
contentTypeHeader.map(Body.ContentType.fromHeader),
readMore,
)
}

Expand Down Expand Up @@ -92,6 +94,7 @@ object NettyBody extends BodyEncoding {
unsafeAsync: UnsafeAsync => Unit,
knownContentLength: Option[Long],
override val contentType: Option[Body.ContentType] = None,
nettyRead: () => Unit,
) extends Body {

override def asArray(implicit trace: Trace): Task[Array[Byte]] = asChunk.map {
Expand All @@ -110,12 +113,14 @@ object NettyBody extends BodyEncoding {
}

override def asStream(implicit trace: Trace): ZStream[Any, Throwable, Byte] = {
asyncUnboundedStream[Any, Throwable, Byte](emit =>
try {
unsafeAsync(new UnsafeAsync.Streaming(emit))
} catch {
case e: Throwable => emit(ZIO.fail(Option(e)))
},
asyncUnboundedStream[Any, Throwable, Byte](
emit =>
try {
unsafeAsync(new UnsafeAsync.Streaming(emit))
} catch {
case e: Throwable => emit(ZIO.fail(Option(e)))
},
ZIO.succeed(nettyRead()),
)
}

Expand All @@ -137,10 +142,13 @@ object NettyBody extends BodyEncoding {
}

/**
* Code ported from zio.stream to use an unbounded queue
* Code ported from zio.stream to use an unbounded queue On top of that the
* nettyRead() function is added. It is used to call netty ctx.read() when the
* queue is empty
*/
private def asyncUnboundedStream[R, E, A](
register: ZStream.Emit[R, E, A, Unit] => Unit,
nettyRead: UIO[Unit],
)(implicit trace: Trace): ZStream[R, E, A] =
ZStream.unwrapScoped[R](for {
queue <- ZIO.acquireRelease(Queue.unbounded[Take[E, A]])(_.shutdown)
Expand All @@ -166,7 +174,7 @@ object NettyBody extends BodyEncoding {
maybeError =>
ZChannel.fromZIO(queue.shutdown) *>
maybeError.fold[ZChannel[Any, Any, Any, Any, E, Chunk[A], Unit]](ZChannel.unit)(ZChannel.fail(_)),
a => ZChannel.write(a) *> loop,
a => ZChannel.write(a) *> ZChannel.fromZIO(nettyRead.whenZIO(queue.isEmpty)) *> loop,
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ private[netty] object NettyBodyWriter {
val stream = ZStream.fromFile(body.file)
val s = StreamBody(stream, None, contentType = body.contentType)
NettyBodyWriter.writeAndFlush(s, None, ctx)
case AsyncBody(async, _, _) =>
case AsyncBody(async, _, _, _) =>
async(
new UnsafeAsync {
override def apply(message: Chunk[Byte], isLast: Boolean): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ private[netty] object NettyResponse {
callback => responseHandler.connect(callback),
knownContentLength,
contentType,
() => ctx.read(): Unit,
)
Response(status, headers, data)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,12 @@ private[zio] final case class ServerInboundHandler(
case nettyReq: HttpRequest =>
val knownContentLength = headers.get(Header.ContentLength).map(_.length)
val handler = addAsyncBodyHandler(ctx)
val body = NettyBody.fromAsync(async => handler.connect(async), knownContentLength, contentTypeHeader)
val body = NettyBody.fromAsync(
async => handler.connect(async),
knownContentLength,
contentTypeHeader,
() => ctx.read(): Unit,
)

Request(
body = body,
Expand Down

0 comments on commit f1e6350

Please sign in to comment.