Skip to content

Commit

Permalink
Merge pull request #70 from mati865/fix-avro-rs-unions
Browse files Browse the repository at this point in the history
Fix serde visitor for byts with `use_avro_rs_unions` option
  • Loading branch information
lerouxrgd authored Aug 17, 2024
2 parents e9c96c4 + 6dda5e5 commit 7e0f600
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 36 deletions.
2 changes: 1 addition & 1 deletion src/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ impl GeneratorBuilder {
/// Adds support for deserializing union types from the `apache-avro` crate.
///
/// Only necessary for unions of 3 or more types or 2-type unions without "null".
/// Note that only int, long, float, double, and boolean values are currently supported.
/// Note that only int, long, float, double, boolean and bytes values are currently supported.
pub fn use_avro_rs_unions(mut self, use_avro_rs_unions: bool) -> GeneratorBuilder {
self.use_avro_rs_unions = use_avro_rs_unions;
self
Expand Down
2 changes: 1 addition & 1 deletion src/templates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ impl<'de> serde::Deserialize<'de> for {{ name }} {
{%- for v in visitors %}
{%- if v.serde_visitor %}
fn visit_{{ v.serde_visitor | trim_start_matches(pat="&") }}<E>(self, value: {{ v.serde_visitor }}) -> Result<Self::Value, E>
fn visit_{{ v.serde_visitor | replace(from="&[u8]", to="bytes") | trim_start_matches(pat="&") }}<E>(self, value: {{ v.serde_visitor }}) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Expand Down
2 changes: 1 addition & 1 deletion tests/schemas/multi_valued_union_with_avro_rs_unions.avsc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
"fields": [ {
"name": "extra",
"type": "map",
"values" : [ "null", "string", "long", "double", "boolean" ]
"values" : [ "null", "string", "long", "double", "boolean", "bytes" ]
} ]
}
92 changes: 59 additions & 33 deletions tests/schemas/multi_valued_union_with_avro_rs_unions.rs
Original file line number Diff line number Diff line change
@@ -1,134 +1,160 @@

/// Auto-generated type for unnamed Avro union variants.
#[derive(Debug, PartialEq, Clone, serde::Serialize)]
pub enum UnionStringLongDoubleBoolean {
pub enum UnionStringLongDoubleBooleanBytes {
String(String),
Long(i64),
Double(f64),
Boolean(bool),
Bytes(#[serde(with = "apache_avro::serde_avro_bytes")] Vec<u8>),
}

impl From<String> for UnionStringLongDoubleBoolean {
impl From<String> for UnionStringLongDoubleBooleanBytes {
fn from(v: String) -> Self {
Self::String(v)
}
}

impl TryFrom<UnionStringLongDoubleBoolean> for String {
type Error = UnionStringLongDoubleBoolean;
impl TryFrom<UnionStringLongDoubleBooleanBytes> for String {
type Error = UnionStringLongDoubleBooleanBytes;

fn try_from(v: UnionStringLongDoubleBoolean) -> Result<Self, Self::Error> {
if let UnionStringLongDoubleBoolean::String(v) = v {
fn try_from(v: UnionStringLongDoubleBooleanBytes) -> Result<Self, Self::Error> {
if let UnionStringLongDoubleBooleanBytes::String(v) = v {
Ok(v)
} else {
Err(v)
}
}
}

impl From<i64> for UnionStringLongDoubleBoolean {
impl From<i64> for UnionStringLongDoubleBooleanBytes {
fn from(v: i64) -> Self {
Self::Long(v)
}
}

impl TryFrom<UnionStringLongDoubleBoolean> for i64 {
type Error = UnionStringLongDoubleBoolean;
impl TryFrom<UnionStringLongDoubleBooleanBytes> for i64 {
type Error = UnionStringLongDoubleBooleanBytes;

fn try_from(v: UnionStringLongDoubleBoolean) -> Result<Self, Self::Error> {
if let UnionStringLongDoubleBoolean::Long(v) = v {
fn try_from(v: UnionStringLongDoubleBooleanBytes) -> Result<Self, Self::Error> {
if let UnionStringLongDoubleBooleanBytes::Long(v) = v {
Ok(v)
} else {
Err(v)
}
}
}

impl From<f64> for UnionStringLongDoubleBoolean {
impl From<f64> for UnionStringLongDoubleBooleanBytes {
fn from(v: f64) -> Self {
Self::Double(v)
}
}

impl TryFrom<UnionStringLongDoubleBoolean> for f64 {
type Error = UnionStringLongDoubleBoolean;
impl TryFrom<UnionStringLongDoubleBooleanBytes> for f64 {
type Error = UnionStringLongDoubleBooleanBytes;

fn try_from(v: UnionStringLongDoubleBoolean) -> Result<Self, Self::Error> {
if let UnionStringLongDoubleBoolean::Double(v) = v {
fn try_from(v: UnionStringLongDoubleBooleanBytes) -> Result<Self, Self::Error> {
if let UnionStringLongDoubleBooleanBytes::Double(v) = v {
Ok(v)
} else {
Err(v)
}
}
}

impl From<bool> for UnionStringLongDoubleBoolean {
impl From<bool> for UnionStringLongDoubleBooleanBytes {
fn from(v: bool) -> Self {
Self::Boolean(v)
}
}

impl TryFrom<UnionStringLongDoubleBoolean> for bool {
type Error = UnionStringLongDoubleBoolean;
impl TryFrom<UnionStringLongDoubleBooleanBytes> for bool {
type Error = UnionStringLongDoubleBooleanBytes;

fn try_from(v: UnionStringLongDoubleBoolean) -> Result<Self, Self::Error> {
if let UnionStringLongDoubleBoolean::Boolean(v) = v {
fn try_from(v: UnionStringLongDoubleBooleanBytes) -> Result<Self, Self::Error> {
if let UnionStringLongDoubleBooleanBytes::Boolean(v) = v {
Ok(v)
} else {
Err(v)
}
}
}

impl<'de> serde::Deserialize<'de> for UnionStringLongDoubleBoolean {
fn deserialize<D>(deserializer: D) -> Result<UnionStringLongDoubleBoolean, D::Error>
impl From<Vec<u8>> for UnionStringLongDoubleBooleanBytes {
fn from(v: Vec<u8>) -> Self {
Self::Bytes(v)
}
}

impl TryFrom<UnionStringLongDoubleBooleanBytes> for Vec<u8> {
type Error = UnionStringLongDoubleBooleanBytes;

fn try_from(v: UnionStringLongDoubleBooleanBytes) -> Result<Self, Self::Error> {
if let UnionStringLongDoubleBooleanBytes::Bytes(v) = v {
Ok(v)
} else {
Err(v)
}
}
}

impl<'de> serde::Deserialize<'de> for UnionStringLongDoubleBooleanBytes {
fn deserialize<D>(deserializer: D) -> Result<UnionStringLongDoubleBooleanBytes, D::Error>
where
D: serde::Deserializer<'de>,
{
/// Serde visitor for the auto-generated unnamed Avro union type.
struct UnionStringLongDoubleBooleanVisitor;
struct UnionStringLongDoubleBooleanBytesVisitor;

impl<'de> serde::de::Visitor<'de> for UnionStringLongDoubleBooleanVisitor {
type Value = UnionStringLongDoubleBoolean;
impl<'de> serde::de::Visitor<'de> for UnionStringLongDoubleBooleanBytesVisitor {
type Value = UnionStringLongDoubleBooleanBytes;

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a UnionStringLongDoubleBoolean")
formatter.write_str("a UnionStringLongDoubleBooleanBytes")
}

fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(UnionStringLongDoubleBoolean::String(value.into()))
Ok(UnionStringLongDoubleBooleanBytes::String(value.into()))
}

fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(UnionStringLongDoubleBoolean::Long(value.into()))
Ok(UnionStringLongDoubleBooleanBytes::Long(value.into()))
}

fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(UnionStringLongDoubleBoolean::Double(value.into()))
Ok(UnionStringLongDoubleBooleanBytes::Double(value.into()))
}

fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(UnionStringLongDoubleBoolean::Boolean(value.into()))
Ok(UnionStringLongDoubleBooleanBytes::Boolean(value.into()))
}

fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(UnionStringLongDoubleBooleanBytes::Bytes(value.into()))
}
}

deserializer.deserialize_any(UnionStringLongDoubleBooleanVisitor)
deserializer.deserialize_any(UnionStringLongDoubleBooleanBytesVisitor)
}
}

#[derive(Debug, PartialEq, Clone, serde::Deserialize, serde::Serialize)]
pub struct Contact {
pub extra: ::std::collections::HashMap<String, Option<UnionStringLongDoubleBoolean>>,
pub extra: ::std::collections::HashMap<String, Option<UnionStringLongDoubleBooleanBytes>>,
}

0 comments on commit 7e0f600

Please sign in to comment.