From fdfde8e73a3bca1e875a1ae96bb9519d68cd8ebf Mon Sep 17 00:00:00 2001 From: Nabil Abdel-Hafeez <7283535+987Nabil@users.noreply.github.com> Date: Tue, 30 Jul 2024 21:22:54 +0200 Subject: [PATCH] Encode SSE based on HttpContentCodec (#2695) (#2951) --- .../main/scala/zio/http/ServerSentEvent.scala | 23 ++++++++++++- .../http/codec/internal/EncoderDecoder.scala | 34 ++++--------------- 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/zio-http/shared/src/main/scala/zio/http/ServerSentEvent.scala b/zio-http/shared/src/main/scala/zio/http/ServerSentEvent.scala index 743926a818..cf1661d785 100644 --- a/zio-http/shared/src/main/scala/zio/http/ServerSentEvent.scala +++ b/zio-http/shared/src/main/scala/zio/http/ServerSentEvent.scala @@ -16,10 +16,15 @@ package zio.http -import zio.stacktracer.TracingImplicits.disableAutoTrace +import zio._ +import zio.stream.ZPipeline + +import zio.schema.codec.{BinaryCodec, DecodeError} import zio.schema.{DeriveSchema, Schema} +import zio.http.codec.{BinaryCodecWithSchema, HttpContentCodec} + /** * Server-Sent Event (SSE) as defined by * https://html.spec.whatwg.org/multipage/server-sent-events.html#server-sent-events @@ -61,5 +66,21 @@ final case class ServerSentEvent( object ServerSentEvent { implicit lazy val schema: Schema[ServerSentEvent] = DeriveSchema.gen[ServerSentEvent] + implicit val contentCodec: HttpContentCodec[ServerSentEvent] = HttpContentCodec.from( + MediaType.text.`event-stream` -> BinaryCodecWithSchema.fromBinaryCodec(new BinaryCodec[ServerSentEvent] { + override def decode(whole: Chunk[Byte]): Either[DecodeError, ServerSentEvent] = + throw new UnsupportedOperationException("ServerSentEvent decoding is not yet supported.") + + override def streamDecoder: ZPipeline[Any, DecodeError, Byte, ServerSentEvent] = + throw new UnsupportedOperationException("ServerSentEvent decoding is not yet supported.") + + override def encode(value: ServerSentEvent): Chunk[Byte] = + Chunk.fromArray(value.encode.getBytes) + + override def streamEncoder: ZPipeline[Any, Nothing, ServerSentEvent, Byte] = + ZPipeline.mapChunks(value => value.flatMap(c => c.encode.getBytes)) + }), + ) + def heartbeat: ServerSentEvent = new ServerSentEvent("") } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala index b4ddc3e2ee..7663d2461c 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala @@ -216,11 +216,6 @@ private[codec] object EncoderDecoder { } else { false } - private val isEventStream = if (flattened.content.length == 1) { - isEventStreamBody(flattened.content(0)) - } else { - false - } private val onlyTheLastFieldIsStreaming = if (flattened.content.size > 1) { !flattened.content.init.exists(isByteStreamBody) && isByteStreamBody(flattened.content.last) @@ -533,26 +528,20 @@ private[codec] object EncoderDecoder { case SimpleCodec.Specified(method) => Some(method) } } else None - private def encodeBody(inputs: Array[Any], outputTypes: Chunk[MediaTypeWithQFactor]): Body = { + private def encodeBody(inputs: Array[Any], outputTypes: Chunk[MediaTypeWithQFactor]): Body = if (isByteStream) { Body.fromStreamChunked(inputs(0).asInstanceOf[ZStream[Any, Nothing, Byte]]) } else { - if (inputs.length > 1) { - Body.fromMultipartForm(encodeMultipartFormData(inputs, outputTypes), formBoundary) - } else { - if (isEventStream) { - Body.fromCharSequenceStreamChunked( - inputs(0).asInstanceOf[ZStream[Any, Nothing, ServerSentEvent]].map(_.encode), - ) - } else if (inputs.length < 1) { + inputs.length match { + case 0 => Body.empty - } else { + case 1 => val bodyCodec = flattened.content(0) bodyCodec.erase.encodeToBody(inputs(0), outputTypes) - } + case _ => + Body.fromMultipartForm(encodeMultipartFormData(inputs, outputTypes), formBoundary) } } - } private def encodeMultipartFormData(inputs: Array[Any], outputTypes: Chunk[MediaTypeWithQFactor]): Form = { Form( @@ -581,8 +570,7 @@ private[codec] object EncoderDecoder { if (inputs.length > 1) { Headers(Header.ContentType(MediaType.multipart.`form-data`)) } else { - if (isEventStream) Headers(Header.ContentType(MediaType.text.`event-stream`)) - else if (flattened.content.length < 1) Headers.empty + if (flattened.content.length < 1) Headers.empty else { val mediaType = flattened .content(0) @@ -599,14 +587,6 @@ private[codec] object EncoderDecoder { case BodyCodec.Multiple(codec, _) if codec.defaultMediaType.binary => true case _ => false } - - private def isEventStreamBody(codec: BodyCodec[_]): Boolean = - codec match { - case BodyCodec.Multiple(codec, _) - if codec.lookup(MediaType.text.`event-stream`).exists(_.schema == Schema[ServerSentEvent]) => - true - case _ => false - } } }