Skip to content

Commit

Permalink
fixed enum support in avro codec (#578)
Browse files Browse the repository at this point in the history
* fixed enum support in avro codec

* fixed tests

* linted

* fixed README

* fixed scal 3x compilation

* fixed enum support in avro codec

* fixed tests

* linted

* fixed scal 3x compilation

* fix: fixed README

* fix: fixed README

---------

Co-authored-by: Daniel Vigovszky <[email protected]>
  • Loading branch information
devsprint and vigoo authored Aug 22, 2023
1 parent 26bb289 commit 9420122
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 43 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ _ZIO Schema_ is used by a growing number of ZIO libraries, including _ZIO Flow_,
In order to use this library, we need to add the following lines in our `build.sbt` file:

```scala
libraryDependencies += "dev.zio" %% "zio-schema" % "0.4.12"
libraryDependencies += "dev.zio" %% "zio-schema-bson" % "0.4.12"
libraryDependencies += "dev.zio" %% "zio-schema-json" % "0.4.12"
libraryDependencies += "dev.zio" %% "zio-schema-protobuf" % "0.4.12"
libraryDependencies += "dev.zio" %% "zio-schema" % "0.4.13"
libraryDependencies += "dev.zio" %% "zio-schema-bson" % "0.4.13"
libraryDependencies += "dev.zio" %% "zio-schema-json" % "0.4.13"
libraryDependencies += "dev.zio" %% "zio-schema-protobuf" % "0.4.13"

// Required for automatic generic derivation of schemas
libraryDependencies += "dev.zio" %% "zio-schema-derivation" % "0.4.12",
libraryDependencies += "dev.zio" %% "zio-schema-derivation" % "0.4.13",
libraryDependencies += "org.scala-lang" % "scala-reflect" % scalaVersion.value % "provided"
```

Expand Down
114 changes: 82 additions & 32 deletions zio-schema-avro/shared/src/main/scala/zio/schema/codec/AvroCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -209,16 +209,31 @@ object AvroCodec {
private def decodeCaseClass1[A, Z](raw: Any, schema: Schema.CaseClass1[A, Z]) =
decodeValue(raw, schema.field.schema).map(schema.defaultConstruct)

private def decodeEnum[Z](raw: Any, cases: Schema.Case[Z, _]*): Either[DecodeError, Any] = {
val generic = raw.asInstanceOf[GenericData.Record]
val enumCaseName = generic.getSchema.getFullName
val enumCaseValue = generic.get("value")
private def decodeEnum[Z](raw: Any, cases: Schema.Case[Z, _]*): Either[DecodeError, Any] =
raw match {
case enums: GenericData.EnumSymbol =>
decodeGenericEnum(enums.toString, None, cases: _*)
case gr: GenericData.Record =>
val enumCaseName = gr.getSchema.getFullName
if (gr.hasField("value")) {
val enumCaseValue = gr.get("value")
decodeGenericEnum[Z](enumCaseName, Some(enumCaseValue), cases: _*)
} else {
decodeGenericEnum[Z](enumCaseName, None, cases: _*)
}
case _ => Left(DecodeError.MalformedFieldWithPath(Chunk.single("Error"), s"Unknown enum: $raw"))
}

private def decodeGenericEnum[Z](
enumCaseName: String,
enumCaseValue: Option[AnyRef],
cases: Schema.Case[Z, _]*
): Either[DecodeError, Any] =
cases
.find(_.id == enumCaseName)
.map(s => decodeValue(enumCaseValue, s.schema))
.map(s => decodeValue(enumCaseValue.getOrElse(s), s.schema))
.toRight(DecodeError.MalformedFieldWithPath(Chunk.single("Error"), s"Unknown enum value: $enumCaseName"))
.flatMap(identity)
}

private def decodeRecord[A](value: A, schema: Schema.Record[_]) = {
val record = value.asInstanceOf[GenericRecord]
Expand Down Expand Up @@ -454,41 +469,41 @@ object AvroCodec {
else decodeValue(value, schema).map(Some(_))

private def encodeValue[A](a: A, schema: Schema[A]): Any = schema match {
case Schema.Enum1(_, c1, _) => encodeEnum(a, c1)
case Schema.Enum2(_, c1, c2, _) => encodeEnum(a, c1, c2)
case Schema.Enum3(_, c1, c2, c3, _) => encodeEnum(a, c1, c2, c3)
case Schema.Enum4(_, c1, c2, c3, c4, _) => encodeEnum(a, c1, c2, c3, c4)
case Schema.Enum5(_, c1, c2, c3, c4, c5, _) => encodeEnum(a, c1, c2, c3, c4, c5)
case Schema.Enum6(_, c1, c2, c3, c4, c5, c6, _) => encodeEnum(a, c1, c2, c3, c4, c5, c6)
case Schema.Enum1(_, c1, _) => encodeEnum(schema, a, c1)
case Schema.Enum2(_, c1, c2, _) => encodeEnum(schema, a, c1, c2)
case Schema.Enum3(_, c1, c2, c3, _) => encodeEnum(schema, a, c1, c2, c3)
case Schema.Enum4(_, c1, c2, c3, c4, _) => encodeEnum(schema, a, c1, c2, c3, c4)
case Schema.Enum5(_, c1, c2, c3, c4, c5, _) => encodeEnum(schema, a, c1, c2, c3, c4, c5)
case Schema.Enum6(_, c1, c2, c3, c4, c5, c6, _) => encodeEnum(schema, a, c1, c2, c3, c4, c5, c6)
case Schema.Enum7(_, c1, c2, c3, c4, c5, c6, c7, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7)
case Schema.Enum8(_, c1, c2, c3, c4, c5, c6, c7, c8, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8)
case Schema.Enum9(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9)
case Schema.Enum10(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10)
case Schema.Enum11(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11)
case Schema.Enum12(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12)
case Schema.Enum13(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13)
case Schema.Enum14(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14)
case Schema.Enum15(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15)
case Schema.Enum16(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16)
case Schema.Enum17(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17)
case Schema.Enum18(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18)
case Schema.Enum19(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19)
case Schema
.Enum20(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20)
case Schema.Enum21(
_,
c1,
Expand All @@ -514,7 +529,31 @@ object AvroCodec {
c21,
_
) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20, c21)
encodeEnum(
schema,
a,
c1,
c2,
c3,
c4,
c5,
c6,
c7,
c8,
c9,
c10,
c11,
c12,
c13,
c14,
c15,
c16,
c17,
c18,
c19,
c20,
c21
)
case Schema.Enum22(
_,
c1,
Expand Down Expand Up @@ -542,6 +581,7 @@ object AvroCodec {
_
) =>
encodeEnum(
schema,
a,
c1,
c2,
Expand Down Expand Up @@ -580,9 +620,10 @@ object AvroCodec {
case Schema.Optional(schema, _) => encodeOption(schema, a)
case Schema.Tuple2(left, right, _) =>
encodeTuple2(left.asInstanceOf[Schema[Any]], right.asInstanceOf[Schema[Any]], a)
case Schema.Either(left, right, _) => encodeEither(left, right, a)
case Schema.Lazy(schema0) => encodeValue(a, schema0())
case Schema.CaseClass0(_, _, _) => encodePrimitive((), StandardType.UnitType)
case Schema.Either(left, right, _) => encodeEither(left, right, a)
case Schema.Lazy(schema0) => encodeValue(a, schema0())
case Schema.CaseClass0(_, _, _) =>
encodeCaseClass(schema, a, Seq.empty: _*) //encodePrimitive((), StandardType.UnitType)
case Schema.CaseClass1(_, f, _, _) => encodeCaseClass(schema, a, f)
case Schema.CaseClass2(_, f0, f1, _, _) => encodeCaseClass(schema, a, f0, f1)
case Schema.CaseClass3(_, f0, f1, f2, _, _) => encodeCaseClass(schema, a, f0, f1, f2)
Expand Down Expand Up @@ -926,11 +967,20 @@ object AvroCodec {
record
}

private def encodeEnum[Z](value: Z, cases: Schema.Case[Z, _]*): Any = {
private def encodeEnum[Z](schemaRaw: Schema[Z], value: Z, cases: Schema.Case[Z, _]*): Any = {
val schema = AvroSchemaCodec
.encodeToApacheAvro(schemaRaw)
.getOrElse(throw new Exception("Avro schema could not be generated for Enum."))
val fieldIndex = cases.indexWhere(c => c.deconstructOption(value).isDefined)
if (fieldIndex >= 0) {
val subtypeCase = cases(fieldIndex)
encodeValue(subtypeCase.deconstruct(value), subtypeCase.schema.asInstanceOf[Schema[Any]])
if (schema.getType == SchemaAvro.Type.ENUM) {
GenericData.get.createEnum(schema.getEnumSymbols.get(fieldIndex), schema)
} else {

encodeValue(subtypeCase.deconstruct(value), subtypeCase.schema.asInstanceOf[Schema[Any]])

}
} else {
throw new Exception("Could not find matching case for enum value.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -517,8 +517,8 @@ object AvroSchemaCodec extends AvroSchemaCodec {
}

def hasAvroEnumAnnotation(annotations: Chunk[Any]): Boolean = annotations.exists {
case AvroAnnotations.avroEnum => true
case _ => false
case AvroAnnotations.avroEnum() => true
case _ => false
}

def wrapAvro(schemaAvro: SchemaAvro, name: String, marker: AvroPropMarker): SchemaAvro = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ import java.time.{
import java.util.UUID

import zio._
import zio.schema.codec.AvroAnnotations.avroEnum
import zio.schema.{ DeriveSchema, Schema }
import zio.stream.ZStream
import zio.test.TestAspect.failing
import zio.test._

object AvroCodecSpec extends ZIOSpecDefault {
Expand Down Expand Up @@ -106,10 +106,12 @@ object AvroCodecSpec extends ZIOSpecDefault {

case class BooleanValue(value: Boolean) extends OneOf

case object NullValue extends OneOf

implicit val schemaOneOf: Schema[OneOf] = DeriveSchema.gen[OneOf]
}

sealed trait Enums
@avroEnum() sealed trait Enums

object Enums {
case object A extends Enums
Expand Down Expand Up @@ -649,12 +651,18 @@ object AvroCodecSpec extends ZIOSpecDefault {
val result = codec.decode(bytes)
assertTrue(result == Right(OneOf.BooleanValue(true)))
},
test("Decode Enum3 - case object") {
val codec = AvroCodec.schemaBasedBinaryCodec[OneOf]
val bytes = codec.encode(OneOf.NullValue)
val result = codec.decode(bytes)
assertTrue(result == Right(OneOf.NullValue))
},
test("Decode Enum5") {
val codec = AvroCodec.schemaBasedBinaryCodec[Enums]
val bytes = codec.encode(Enums.A)
val result = codec.decode(bytes)
assertTrue(result == Right(Enums.A))
} @@ failing, // TODO: the case object from a sealed trait are not properly encoded and decoded.
},
test("Decode Person") {
val codec = AvroCodec.schemaBasedBinaryCodec[Person]
val bytes = codec.encode(Person("John", 42))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ object AvroSchemaCodecSpec extends ZIOSpecDefault {
},
test("encodes sealed trait objects only as enum when avroEnum annotation is present") {

val schema = DeriveSchema.gen[SpecTestData.CaseObjectsOnlyAdt].annotate(AvroAnnotations.avroEnum)
val schema = DeriveSchema.gen[SpecTestData.CaseObjectsOnlyAdt].annotate(AvroAnnotations.avroEnum())
val result = AvroSchemaCodec.encode(schema)

val expected = """{"type":"enum","name":"MyEnum","symbols":["A","B","MyC"]}"""
Expand Down

0 comments on commit 9420122

Please sign in to comment.