diff --git a/src/gen.rs b/src/gen.rs index 146c19c..01c70dc 100644 --- a/src/gen.rs +++ b/src/gen.rs @@ -323,7 +323,7 @@ impl GeneratorBuilder { /// Adds support for deserializing union types from the `apache-avro` crate. /// /// Only necessary for unions of 3 or more types or 2-type unions without "null". - /// Note that only int, long, float, double, and boolean values are currently supported. + /// Note that only int, long, float, double, boolean and bytes values are currently supported. pub fn use_avro_rs_unions(mut self, use_avro_rs_unions: bool) -> GeneratorBuilder { self.use_avro_rs_unions = use_avro_rs_unions; self diff --git a/src/templates.rs b/src/templates.rs index 2a1ac10..bbe9587 100644 --- a/src/templates.rs +++ b/src/templates.rs @@ -179,7 +179,7 @@ impl<'de> serde::Deserialize<'de> for {{ name }} { {%- for v in visitors %} {%- if v.serde_visitor %} - fn visit_{{ v.serde_visitor | trim_start_matches(pat="&") }}(self, value: {{ v.serde_visitor }}) -> Result + fn visit_{{ v.serde_visitor | replace(from="&[u8]", to="bytes") | trim_start_matches(pat="&") }}(self, value: {{ v.serde_visitor }}) -> Result where E: serde::de::Error, { diff --git a/tests/schemas/multi_valued_union_with_avro_rs_unions.avsc b/tests/schemas/multi_valued_union_with_avro_rs_unions.avsc index 044d949..fbfc9fd 100644 --- a/tests/schemas/multi_valued_union_with_avro_rs_unions.avsc +++ b/tests/schemas/multi_valued_union_with_avro_rs_unions.avsc @@ -5,6 +5,6 @@ "fields": [ { "name": "extra", "type": "map", - "values" : [ "null", "string", "long", "double", "boolean" ] + "values" : [ "null", "string", "long", "double", "boolean", "bytes" ] } ] } diff --git a/tests/schemas/multi_valued_union_with_avro_rs_unions.rs b/tests/schemas/multi_valued_union_with_avro_rs_unions.rs index 7854aef..d9bbddf 100644 --- a/tests/schemas/multi_valued_union_with_avro_rs_unions.rs +++ b/tests/schemas/multi_valued_union_with_avro_rs_unions.rs @@ -1,24 +1,25 @@ /// Auto-generated type for unnamed Avro union variants. #[derive(Debug, PartialEq, Clone, serde::Serialize)] -pub enum UnionStringLongDoubleBoolean { +pub enum UnionStringLongDoubleBooleanBytes { String(String), Long(i64), Double(f64), Boolean(bool), + Bytes(#[serde(with = "apache_avro::serde_avro_bytes")] Vec), } -impl From for UnionStringLongDoubleBoolean { +impl From for UnionStringLongDoubleBooleanBytes { fn from(v: String) -> Self { Self::String(v) } } -impl TryFrom for String { - type Error = UnionStringLongDoubleBoolean; +impl TryFrom for String { + type Error = UnionStringLongDoubleBooleanBytes; - fn try_from(v: UnionStringLongDoubleBoolean) -> Result { - if let UnionStringLongDoubleBoolean::String(v) = v { + fn try_from(v: UnionStringLongDoubleBooleanBytes) -> Result { + if let UnionStringLongDoubleBooleanBytes::String(v) = v { Ok(v) } else { Err(v) @@ -26,17 +27,17 @@ impl TryFrom for String { } } -impl From for UnionStringLongDoubleBoolean { +impl From for UnionStringLongDoubleBooleanBytes { fn from(v: i64) -> Self { Self::Long(v) } } -impl TryFrom for i64 { - type Error = UnionStringLongDoubleBoolean; +impl TryFrom for i64 { + type Error = UnionStringLongDoubleBooleanBytes; - fn try_from(v: UnionStringLongDoubleBoolean) -> Result { - if let UnionStringLongDoubleBoolean::Long(v) = v { + fn try_from(v: UnionStringLongDoubleBooleanBytes) -> Result { + if let UnionStringLongDoubleBooleanBytes::Long(v) = v { Ok(v) } else { Err(v) @@ -44,17 +45,17 @@ impl TryFrom for i64 { } } -impl From for UnionStringLongDoubleBoolean { +impl From for UnionStringLongDoubleBooleanBytes { fn from(v: f64) -> Self { Self::Double(v) } } -impl TryFrom for f64 { - type Error = UnionStringLongDoubleBoolean; +impl TryFrom for f64 { + type Error = UnionStringLongDoubleBooleanBytes; - fn try_from(v: UnionStringLongDoubleBoolean) -> Result { - if let UnionStringLongDoubleBoolean::Double(v) = v { + fn try_from(v: UnionStringLongDoubleBooleanBytes) -> Result { + if let UnionStringLongDoubleBooleanBytes::Double(v) = v { Ok(v) } else { Err(v) @@ -62,17 +63,17 @@ impl TryFrom for f64 { } } -impl From for UnionStringLongDoubleBoolean { +impl From for UnionStringLongDoubleBooleanBytes { fn from(v: bool) -> Self { Self::Boolean(v) } } -impl TryFrom for bool { - type Error = UnionStringLongDoubleBoolean; +impl TryFrom for bool { + type Error = UnionStringLongDoubleBooleanBytes; - fn try_from(v: UnionStringLongDoubleBoolean) -> Result { - if let UnionStringLongDoubleBoolean::Boolean(v) = v { + fn try_from(v: UnionStringLongDoubleBooleanBytes) -> Result { + if let UnionStringLongDoubleBooleanBytes::Boolean(v) = v { Ok(v) } else { Err(v) @@ -80,55 +81,80 @@ impl TryFrom for bool { } } -impl<'de> serde::Deserialize<'de> for UnionStringLongDoubleBoolean { - fn deserialize(deserializer: D) -> Result +impl From> for UnionStringLongDoubleBooleanBytes { + fn from(v: Vec) -> Self { + Self::Bytes(v) + } +} + +impl TryFrom for Vec { + type Error = UnionStringLongDoubleBooleanBytes; + + fn try_from(v: UnionStringLongDoubleBooleanBytes) -> Result { + if let UnionStringLongDoubleBooleanBytes::Bytes(v) = v { + Ok(v) + } else { + Err(v) + } + } +} + +impl<'de> serde::Deserialize<'de> for UnionStringLongDoubleBooleanBytes { + fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { /// Serde visitor for the auto-generated unnamed Avro union type. - struct UnionStringLongDoubleBooleanVisitor; + struct UnionStringLongDoubleBooleanBytesVisitor; - impl<'de> serde::de::Visitor<'de> for UnionStringLongDoubleBooleanVisitor { - type Value = UnionStringLongDoubleBoolean; + impl<'de> serde::de::Visitor<'de> for UnionStringLongDoubleBooleanBytesVisitor { + type Value = UnionStringLongDoubleBooleanBytes; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a UnionStringLongDoubleBoolean") + formatter.write_str("a UnionStringLongDoubleBooleanBytes") } fn visit_str(self, value: &str) -> Result where E: serde::de::Error, { - Ok(UnionStringLongDoubleBoolean::String(value.into())) + Ok(UnionStringLongDoubleBooleanBytes::String(value.into())) } fn visit_i64(self, value: i64) -> Result where E: serde::de::Error, { - Ok(UnionStringLongDoubleBoolean::Long(value.into())) + Ok(UnionStringLongDoubleBooleanBytes::Long(value.into())) } fn visit_f64(self, value: f64) -> Result where E: serde::de::Error, { - Ok(UnionStringLongDoubleBoolean::Double(value.into())) + Ok(UnionStringLongDoubleBooleanBytes::Double(value.into())) } fn visit_bool(self, value: bool) -> Result where E: serde::de::Error, { - Ok(UnionStringLongDoubleBoolean::Boolean(value.into())) + Ok(UnionStringLongDoubleBooleanBytes::Boolean(value.into())) + } + + fn visit_bytes(self, value: &[u8]) -> Result + where + E: serde::de::Error, + { + Ok(UnionStringLongDoubleBooleanBytes::Bytes(value.into())) } } - deserializer.deserialize_any(UnionStringLongDoubleBooleanVisitor) + deserializer.deserialize_any(UnionStringLongDoubleBooleanBytesVisitor) } } #[derive(Debug, PartialEq, Clone, serde::Deserialize, serde::Serialize)] pub struct Contact { - pub extra: ::std::collections::HashMap>, + pub extra: ::std::collections::HashMap>, }