Skip to content

Commit

Permalink
Validation for OpenAPI generated endpoints (#2786) (#2968)
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil authored Aug 15, 2024
1 parent 5294e91 commit 0255bd3
Show file tree
Hide file tree
Showing 17 changed files with 699 additions and 197 deletions.
2 changes: 1 addition & 1 deletion project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ object Dependencies {
val ZioVersion = "2.1.7"
val ZioCliVersion = "0.5.0"
val ZioJsonVersion = "0.7.1"
val ZioSchemaVersion = "1.3.0"
val ZioSchemaVersion = "1.4.1"
val SttpVersion = "3.3.18"
val ZioConfigVersion = "4.0.2"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,11 @@ private[cli] object HttpOptions {

// Should Map, Sequence and Set have implementations?
// Options cannot be used to specify an arbitrary number of parameters.
case Schema.Map(_, _, _) => emptyJson
case Schema.Sequence(_, _, _, _, _) => emptyJson
case Schema.Set(_, _) => emptyJson
case Schema.Map(_, _, _) => emptyJson
case Schema.NonEmptyMap(_, _, _) => emptyJson
case Schema.Sequence(_, _, _, _, _) => emptyJson
case Schema.NonEmptySequence(_, _, _, _, _) => emptyJson
case Schema.Set(_, _) => emptyJson

case Schema.Lazy(schema0) => loop(prefix, schema0())
case Schema.Dynamic(_) => emptyJson
Expand Down
297 changes: 219 additions & 78 deletions zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala

Large diffs are not rendered by default.

40 changes: 27 additions & 13 deletions zio-http-gen/src/main/scala/zio/http/gen/scala/Code.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ import scala.meta.prettyprinters.XtensionSyntax

import zio.http.{Method, Status}

import com.sun.tools.javac.code.TypeMetadata.Annotations

sealed trait Code extends Product with Serializable

object Code {
sealed trait ScalaType extends Code { self =>
def seq: Collection.Seq = Collection.Seq(self)
def set: Collection.Set = Collection.Set(self)
def map: Collection.Map = Collection.Map(self)
def opt: Collection.Opt = Collection.Opt(self)
def seq(nonEmpty: Boolean): Collection.Seq = Collection.Seq(self, nonEmpty)
def set(nonEmpty: Boolean): Collection.Set = Collection.Set(self, nonEmpty)
def map: Collection.Map = Collection.Map(self)
def opt: Collection.Opt = Collection.Opt(self)
}

object ScalaType {
Expand Down Expand Up @@ -79,17 +81,29 @@ object Code {
abstractMembers: List[Field] = Nil,
) extends ScalaType

sealed abstract case class Field private (name: String, fieldType: ScalaType) extends Code {
final case class Annotation(value: String)

sealed abstract case class Field private (name: String, fieldType: ScalaType, annotations: List[Annotation])
extends Code {
// only allow copy on fieldType, since name is mangled to be valid in smart constructor
def copy(fieldType: ScalaType): Field = new Field(name, fieldType) {}
def copy(fieldType: ScalaType = fieldType, annotations: List[Annotation] = annotations): Field =
new Field(name, fieldType, annotations) {}
}

object Field {

def apply(name: String): Field = apply(name, ScalaType.Inferred)
def apply(name: String, fieldType: ScalaType): Field = {
def apply(name: String): Field = apply(name, ScalaType.Inferred)
def apply(name: String, fieldType: ScalaType): Field = {
val validScalaTermName = Term.Name(name).syntax
new Field(validScalaTermName, fieldType, Nil) {}
}
def apply(name: String, fieldType: ScalaType, annotation: Annotation): Field = {
val validScalaTermName = Term.Name(name).syntax
new Field(validScalaTermName, fieldType, List(annotation)) {}
}
def apply(name: String, fieldType: ScalaType, annotations: List[Annotation]): Field = {
val validScalaTermName = Term.Name(name).syntax
new Field(validScalaTermName, fieldType) {}
new Field(validScalaTermName, fieldType, annotations) {}
}
}

Expand All @@ -98,10 +112,10 @@ object Code {
}

object Collection {
final case class Seq(elementType: ScalaType) extends Collection
final case class Set(elementType: ScalaType) extends Collection
final case class Map(elementType: ScalaType) extends Collection
final case class Opt(elementType: ScalaType) extends Collection
final case class Seq(elementType: ScalaType, nonEmpty: Boolean) extends Collection
final case class Set(elementType: ScalaType, nonEmpty: Boolean) extends Collection
final case class Map(elementType: ScalaType) extends Collection
final case class Opt(elementType: ScalaType) extends Collection
}

sealed trait Primitive extends ScalaType
Expand Down
25 changes: 16 additions & 9 deletions zio-http-gen/src/main/scala/zio/http/gen/scala/CodeGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,16 @@ object CodeGen {
val traitBodyBuilder = new StringBuilder().append(' ')
var pre = '{'
val imports = abstractMembers.foldLeft(List.empty[Code.Import]) {
case (importsAcc, Code.Field(name, fieldType)) =>
case (importsAcc, Code.Field(name, fieldType, annotations)) =>
val (imports, tpe) = render(basePackage)(fieldType)
if (tpe.isEmpty) importsAcc
else {
traitBodyBuilder += pre
pre = '\n'
annotations.foreach { annotation =>
traitBodyBuilder ++= annotation.value
traitBodyBuilder += '\n'
}
traitBodyBuilder ++= "def "
traitBodyBuilder ++= name
traitBodyBuilder ++= ": "
Expand Down Expand Up @@ -159,24 +163,27 @@ object CodeGen {

case col: Code.Collection =>
col match {
case Code.Collection.Seq(elementType) =>
case Code.Collection.Seq(elementType, nonEmpty) =>
val (imports, tpe) = render(basePackage)(elementType)
(Code.Import("zio.Chunk") :: imports) -> s"Chunk[$tpe]"
case Code.Collection.Set(elementType) =>
if (nonEmpty) (Code.Import("zio.NonEmptyChunk") :: imports) -> s"NonEmptyChunk[$tpe]"
else (Code.Import("zio.Chunk") :: imports) -> s"Chunk[$tpe]"
case Code.Collection.Set(elementType, nonEmpty) =>
val (imports, tpe) = render(basePackage)(elementType)
imports -> s"Set[$tpe]"
case Code.Collection.Map(elementType) =>
if (nonEmpty) (Code.Import("zio.prelude.NonEmptySet") :: imports) -> s"NonEmptySet[$tpe]"
else imports -> s"Set[$tpe]"
case Code.Collection.Map(elementType) =>
val (imports, tpe) = render(basePackage)(elementType)
imports -> s"Map[String, $tpe]"
case Code.Collection.Opt(elementType) =>
case Code.Collection.Opt(elementType) =>
val (imports, tpe) = render(basePackage)(elementType)
imports -> s"Option[$tpe]"
}

case Code.Field(name, fieldType) =>
case Code.Field(name, fieldType, annotations) =>
val (imports, tpe) = render(basePackage)(fieldType)
val annotationsStr = annotations.map(_.value).mkString("\n")
val content = if (tpe.isEmpty) s"val $name" else s"val $name: $tpe"
imports -> content
imports -> (annotationsStr + content)

case Code.Primitive.ScalaBoolean => Nil -> "Boolean"
case Code.Primitive.ScalaByte => Nil -> "Byte"
Expand Down
12 changes: 6 additions & 6 deletions zio-http-gen/src/test/resources/ComponentAnimal.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,19 @@ object Animal {

implicit val codec: Schema[Animal] = DeriveSchema.gen[Animal]
case class Alligator(
age: Int,
weight: Float,
num_teeth: Int,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) age: Int,
@zio.schema.annotation.validate[Float](zio.schema.validation.Validation.greaterThan(-1.0)) weight: Float,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) num_teeth: Int,
) extends Animal
object Alligator {

implicit val codec: Schema[Alligator] = DeriveSchema.gen[Alligator]

}
case class Zebra(
age: Int,
weight: Float,
num_stripes: Int,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) age: Int,
@zio.schema.annotation.validate[Float](zio.schema.validation.Validation.greaterThan(-1.0)) weight: Float,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) num_stripes: Int,
) extends Animal
object Zebra {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@ object Animal {

implicit val codec: Schema[Animal] = DeriveSchema.gen[Animal]
case class Alligator(
age: Int,
weight: Float,
num_teeth: Int,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) age: Int,
@zio.schema.annotation.validate[Float](zio.schema.validation.Validation.greaterThan(-1.0)) weight: Float,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) num_teeth: Int,
) extends Animal
object Alligator {

implicit val codec: Schema[Alligator] = DeriveSchema.gen[Alligator]

}
case class Zebra(
age: Int,
weight: Float,
num_stripes: Int,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) age: Int,
@zio.schema.annotation.validate[Float](zio.schema.validation.Validation.greaterThan(-1.0)) weight: Float,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) num_stripes: Int,
) extends Animal
object Zebra {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@ object Animal {

implicit val codec: Schema[Animal] = DeriveSchema.gen[Animal]
case class Alligator(
age: Int,
weight: Float,
num_teeth: Int,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) age: Int,
@zio.schema.annotation.validate[Float](zio.schema.validation.Validation.greaterThan(-1.0)) weight: Float,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) num_teeth: Int,
) extends Animal
object Alligator {

implicit val codec: Schema[Alligator] = DeriveSchema.gen[Alligator]

}
case class Zebra(
age: Int,
weight: Float,
num_stripes: Int,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) age: Int,
@zio.schema.annotation.validate[Float](zio.schema.validation.Validation.greaterThan(-1.0)) weight: Float,
@zio.schema.annotation.validate[Int](zio.schema.validation.Validation.greaterThan(-1)) num_stripes: Int,
dazzle: Chunk[Zebra],
) extends Animal
object Zebra {
Expand Down
15 changes: 15 additions & 0 deletions zio-http-gen/src/test/resources/ValidatedData.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package test.component

import zio.schema._

case class ValidatedData(
@zio.schema.annotation.validate[String](zio.schema.validation.Validation.minLength(10)) name: String,
@zio.schema.annotation.validate[Int](
zio.schema.validation.Validation.greaterThan(0) && zio.schema.validation.Validation.lessThan(100),
) age: Int,
)
object ValidatedData {

implicit val codec: Schema[ValidatedData] = DeriveSchema.gen[ValidatedData]

}
33 changes: 30 additions & 3 deletions zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@ import scala.meta._
import scala.meta.parsers._
import scala.util.{Failure, Success, Try}

import zio.Scope
import zio.json.{JsonDecoder, JsonEncoder}
import zio.test.Assertion.{hasSameElements, isFailure, isSuccess, throws}
import zio.test.Assertion.{hasSameElements, isFailure, isSuccess}
import zio.test.TestAspect.{blocking, flaky}
import zio.test.TestFailure.fail
import zio.test._
import zio.{Scope, ZIO}

import zio.schema.annotation.validate
import zio.schema.codec.JsonCodec
import zio.schema.validation.Validation
import zio.schema.{DeriveSchema, Schema}

import zio.http._
import zio.http.codec._
Expand All @@ -27,6 +29,14 @@ import zio.http.gen.openapi.{Config, EndpointGen}
@nowarn("msg=missing interpolator")
object CodeGenSpec extends ZIOSpecDefault {

case class ValidatedData(
@validate(Validation.maxLength(10))
name: String,
@validate(Validation.greaterThan(0) && Validation.lessThan(100))
age: Int,
)
implicit val validatedDataSchema: Schema[ValidatedData] = DeriveSchema.gen[ValidatedData]

private def fileShouldBe(dir: java.nio.file.Path, subPath: String, expectedFile: String): TestResult = {
val filePath = dir.resolve(Paths.get(subPath))
val generated = Files.readAllLines(filePath).asScala.mkString("\n")
Expand Down Expand Up @@ -791,5 +801,22 @@ object CodeGenSpec extends ZIOSpecDefault {
"/AnimalWithMap.scala",
)
},
test("Endpoint with data validation") {
val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[ValidatedData]
val openAPIJson = OpenAPIGen.fromEndpoints(endpoint).toJson
val openAPI = OpenAPI.fromJson(openAPIJson).getOrElse(OpenAPI.empty)
val code = EndpointGen.fromOpenAPI(openAPI)

val tempDir = Files.createTempDirectory("codegen")

CodeGen.writeFiles(code, java.nio.file.Paths.get(tempDir.toString, "test"), "test", Some(scalaFmtPath))

fileShouldBe(
tempDir,
"test/component/ValidatedData.scala",
"/ValidatedData.scala",
)

},
) @@ java11OrNewer @@ flaky @@ blocking // Downloading scalafmt on CI is flaky
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent](
}

object AsyncBodyReader {
private val FnUnit = () => ()

sealed trait State

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ object AuthSpec extends ZIOSpecDefault {
val endpoint = Endpoint(Method.GET / "test").out[String](MediaType.text.`plain`)
val routes =
Routes(
endpoint.implementHandler(handler((_: Unit) => ZIO.serviceWith[AuthContext](_.value))),
endpoint.implementHandler(handler((_: Unit) => withContext((ctx: AuthContext) => ctx.value))),
) @@ basicAuthContext
val response = routes.run(
Request(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2503,7 +2503,8 @@ object OpenAPIGenSpec extends ZIOSpecDefault {
| "array",
| "items" : {
| "$ref" : "#/components/schemas/Recursive"
| }
| },
| "uniqueItems" : true
| },
| "nestedEither" : {
| "oneOf" : [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ final case class HttpEndpoint(
case JsonSchema.OneOfSchema(_) => throw new Exception("OneOfSchema not supported")
case JsonSchema.AllOfSchema(_) => throw new Exception("AllOfSchema not supported")
case JsonSchema.AnyOfSchema(_) => throw new Exception("AnyOfSchema not supported")
case JsonSchema.Number(_) => s""""${getName(name)}": {{${getName(name)}}}"""
case JsonSchema.Integer(_) => s""""${getName(name)}": {{${getName(name)}}}"""
case JsonSchema.String(_, _) => s""""${getName(name)}": {{${getName(name)}}}"""
case JsonSchema.Number(_, _, _, _, _, _) => s""""${getName(name)}": {{${getName(name)}}}"""
case JsonSchema.Integer(_, _, _, _, _, _) => s""""${getName(name)}": {{${getName(name)}}}"""
case JsonSchema.String(_, _, _, _) => s""""${getName(name)}": {{${getName(name)}}}"""
case JsonSchema.Boolean => s""""${getName(name)}": {{${getName(name)}}}"""
case JsonSchema.ArrayType(_) => s""""${getName(name)}": {{${getName(name)}}}"""
case JsonSchema.ArrayType(_, _, _) => s""""${getName(name)}": {{${getName(name)}}}"""
case JsonSchema.Object(properties, _, _) =>
if (properties.isEmpty) ""
else {
Expand Down
Loading

0 comments on commit 0255bd3

Please sign in to comment.