Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(enum): u16/u32 discriminants #315

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 63 additions & 13 deletions borsh-derive/src/internals/attributes/item/mod.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,28 @@
use crate::internals::attributes::{BORSH, CRATE, INIT, USE_DISCRIMINANT};
use proc_macro2::Span;
use quote::ToTokens;
use syn::{spanned::Spanned, Attribute, DeriveInput, Error, Expr, ItemEnum, Path};
use syn::{spanned::Spanned, Attribute, DeriveInput, Error, Expr, ItemEnum, Path, TypePath};

use super::{get_one_attribute, parsing};
use super::{get_one_attribute, parsing, RUST_REPR, TAG_WIDTH};

pub fn check_attributes(derive_input: &DeriveInput) -> Result<(), Error> {
let borsh = get_one_attribute(&derive_input.attrs)?;

if let Some(attr) = borsh {
attr.parse_nested_meta(|meta| {
if meta.path != USE_DISCRIMINANT && meta.path != INIT && meta.path != CRATE {
if meta.path != USE_DISCRIMINANT && meta.path != INIT && meta.path != CRATE && meta.path != TAG_WIDTH {
return Err(syn::Error::new(
meta.path.span(),
"`crate`, `use_discriminant` or `init` are the only supported attributes for `borsh`",
"`crate`, `use_discriminant`, `tag_width` or `init` are the only supported attributes for `borsh`",
));
}
if meta.path == USE_DISCRIMINANT {
if meta.path == USE_DISCRIMINANT || meta.path == TAG_WIDTH {
let msg = if meta.path == USE_DISCRIMINANT { "borsh(use_discriminant=<bool>)"} else { "borsh(tag_width=<u8>)"};
let _expr: Expr = meta.value()?.parse()?;
if let syn::Data::Struct(ref _data) = derive_input.data {
return Err(syn::Error::new(
derive_input.ident.span(),
"borsh(use_discriminant=<bool>) does not support structs",
format!("{msg} does not support structs"),
));
}
} else if meta.path == INIT || meta.path == CRATE {
Expand All @@ -34,14 +36,13 @@ pub fn check_attributes(derive_input: &DeriveInput) -> Result<(), Error> {
}

pub(crate) fn contains_use_discriminant(input: &ItemEnum) -> Result<bool, syn::Error> {
if input.variants.len() > 256 {
if input.variants.len() > u8::MAX as usize + 1 {
return Err(syn::Error::new(
input.span(),
"up to 256 enum variants are supported",
format!("up to {} enum variants are supported", u8::MAX as usize + 1),
));
}

let attrs = &input.attrs;
let attrs: &Vec<Attribute> = &input.attrs;
let mut use_discriminant = None;
let attr = attrs.iter().find(|attr| attr.path() == BORSH);
if let Some(attr) = attr {
Expand All @@ -61,7 +62,7 @@ pub(crate) fn contains_use_discriminant(input: &ItemEnum) -> Result<bool, syn::E
));
}
};
} else if meta.path == INIT || meta.path == CRATE {
} else if meta.path == INIT || meta.path == CRATE || meta.path == TAG_WIDTH {
let _value_expr: Expr = meta.value()?.parse()?;
}
Ok(())
Expand All @@ -80,6 +81,54 @@ pub(crate) fn contains_use_discriminant(input: &ItemEnum) -> Result<bool, syn::E
Ok(use_discriminant.unwrap_or(false))
}

pub(crate) fn get_maybe_rust_repr(input: &ItemEnum) -> Option<(TypePath, Span)> {
input
.attrs
.iter()
.find(|attr| attr.path() == RUST_REPR)
.map(|attr| {
attr.parse_args::<TypePath>()
.map(|value| (attr, value))
.unwrap()
})
.map(|(attr, value)| (value, attr.span()))
}

pub(crate) fn get_maybe_borsh_tag_width(
input: &ItemEnum,
) -> Result<Option<(u8, Span)>, syn::Error> {
let mut maybe_borsh_tag_width = None;
let attr = input.attrs.iter().find(|attr| attr.path() == BORSH);
let Some(attr) = attr else {
return Ok(None);
};

attr.parse_nested_meta(|meta| {
if meta.path == TAG_WIDTH {
let value_expr: Expr = meta.value()?.parse()?;
let value = value_expr.to_token_stream().to_string();
let value = value
.parse::<u8>()
.map_err(|_| syn::Error::new(value_expr.span(), "`tag_width` accepts only u8"))?;
if value > 8 {
return Err(syn::Error::new(
value_expr.span(),
"`tag_width` accepts only values from 0 to 8",
));
}
maybe_borsh_tag_width = Some((value, value_expr.span()));
} else if meta.path == INIT
|| meta.path == CRATE
|| meta.path == TAG_WIDTH
|| meta.path == USE_DISCRIMINANT
{
let _value_expr: Expr = meta.value()?.parse()?;
}
Ok(())
})?;
Ok(maybe_borsh_tag_width)
}

pub(crate) fn contains_initialize_with(attrs: &[Attribute]) -> Result<Option<Path>, Error> {
let mut res = None;
let attr = attrs.iter().find(|attr| attr.path() == BORSH);
Expand All @@ -88,7 +137,8 @@ pub(crate) fn contains_initialize_with(attrs: &[Attribute]) -> Result<Option<Pat
if meta.path == INIT {
let value_expr: Path = meta.value()?.parse()?;
res = Some(value_expr);
} else if meta.path == USE_DISCRIMINANT || meta.path == CRATE {
} else if meta.path == USE_DISCRIMINANT || meta.path == CRATE || meta.path == TAG_WIDTH
{
let _value_expr: Expr = meta.value()?.parse()?;
}

Expand All @@ -107,7 +157,7 @@ pub(crate) fn get_crate(attrs: &[Attribute]) -> Result<Option<Path>, Error> {
if meta.path == CRATE {
let value_expr: Path = parsing::parse_lit_into(BORSH, CRATE, &meta)?;
res = Some(value_expr);
} else if meta.path == USE_DISCRIMINANT || meta.path == INIT {
} else if meta.path == USE_DISCRIMINANT || meta.path == INIT || meta.path == TAG_WIDTH {
let _value_expr: Expr = meta.value()?.parse()?;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
---
source: borsh-derive/src/internals/attributes/item/mod.rs
expression: actual.unwrap_err()
snapshot_kind: text
---
Error(
"`crate`, `use_discriminant` or `init` are the only supported attributes for `borsh`",
"`crate`, `use_discriminant`, `tag_width` or `init` are the only supported attributes for `borsh`",
)
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
---
source: borsh-derive/src/internals/attributes/item/mod.rs
expression: actual.unwrap_err()
snapshot_kind: text
---
Error(
"`crate`, `use_discriminant` or `init` are the only supported attributes for `borsh`",
"`crate`, `use_discriminant`, `tag_width` or `init` are the only supported attributes for `borsh`",
)
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
---
source: borsh-derive/src/internals/attributes/item/mod.rs
expression: actual.unwrap_err()
snapshot_kind: text
---
Error(
"`crate`, `use_discriminant` or `init` are the only supported attributes for `borsh`",
"`crate`, `use_discriminant`, `tag_width` or `init` are the only supported attributes for `borsh`",
)
5 changes: 5 additions & 0 deletions borsh-derive/src/internals/attributes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ pub const DESERIALIZE_WITH: Symbol = Symbol("deserialize_with", "deserialize_wit
/// crate - sub-borsh nested meta, item-level only, `BorshSerialize`, `BorshDeserialize`, `BorshSchema` contexts
pub const CRATE: Symbol = Symbol("crate", "crate = ...");

/// tag_width - sub-borsh nested meta, item-level only attribute in `BorshSerialize`, `BorshDeserialize`, `BorshSchema` contexts
pub const TAG_WIDTH: Symbol = Symbol("tag_width", "tag_width = ...");

pub const RUST_REPR: Symbol = Symbol("repr", "repr(...)");

#[cfg(feature = "schema")]
pub mod schema_keys {
use super::Symbol;
Expand Down
67 changes: 46 additions & 21 deletions borsh-derive/src/internals/deserialize/enums/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{Fields, ItemEnum, Path, Variant};
use syn::{Fields, ItemEnum, Path, TypePath, Variant};

use crate::internals::{attributes::item, deserialize, enum_discriminant::Discriminants, generics};

Expand All @@ -11,14 +11,21 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result<TokenStream2> {
let mut where_clause = generics::default_where(where_clause);
let mut variant_arms = TokenStream2::new();
let use_discriminant = item::contains_use_discriminant(input)?;
let discriminants = Discriminants::new(&input.variants);
let maybe_borsh_tag_width = item::get_maybe_borsh_tag_width(input)?;
let maybe_rust_repr = item::get_maybe_rust_repr(input);
let discriminants = Discriminants::new(
&input.variants,
maybe_borsh_tag_width,
maybe_rust_repr,
use_discriminant,
)?;
let mut generics_output = deserialize::GenericsOutput::new(&generics);

let discriminant_type = discriminants.discriminant_type();
for (variant_idx, variant) in input.variants.iter().enumerate() {
let variant_body = process_variant(variant, &cratename, &mut generics_output)?;
let variant_ident = &variant.ident;

let discriminant_value = discriminants.get(variant_ident, use_discriminant, variant_idx)?;
let discriminant_value = discriminants.get(variant_ident, variant_idx)?;
variant_arms.extend(quote! {
if variant_tag == #discriminant_value { #name::#variant_ident #variant_body } else
});
Expand All @@ -32,30 +39,48 @@ pub fn process(input: &ItemEnum, cratename: Path) -> syn::Result<TokenStream2> {
};
generics_output.extend(&mut where_clause, &cratename);

Ok(quote! {
let deserialize_variant = quote! {
let mut return_value =
#variant_arms {
return Err(#cratename::io::Error::new(
#cratename::io::ErrorKind::InvalidData,
#cratename::__private::maybestd::format!("Unexpected variant tag: {:?}", variant_tag),
))
};
#init
Ok(return_value)
};

let deserialize = quote! {
impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause {
fn deserialize_reader<__R: #cratename::io::Read>(reader: &mut __R) -> ::core::result::Result<Self, #cratename::io::Error> {
let tag = <u8 as #cratename::de::BorshDeserialize>::deserialize_reader(reader)?;
<Self as #cratename::de::EnumExt>::deserialize_variant(reader, tag)
let variant_tag = <#discriminant_type as #cratename::de::BorshDeserialize>::deserialize_reader(reader)?;
#deserialize_variant
}
}
};

impl #impl_generics #cratename::de::EnumExt for #name #ty_generics #where_clause {
fn deserialize_variant<__R: #cratename::io::Read>(
reader: &mut __R,
variant_tag: u8,
) -> ::core::result::Result<Self, #cratename::io::Error> {
let mut return_value =
#variant_arms {
return Err(#cratename::io::Error::new(
#cratename::io::ErrorKind::InvalidData,
#cratename::__private::maybestd::format!("Unexpected variant tag: {:?}", variant_tag),
))
};
#init
Ok(return_value)
let impl_trait = if discriminant_type.path.get_ident()
== (syn::parse_str::<TypePath>("u8").unwrap().path.get_ident())
{
quote! {
impl #impl_generics #cratename::de::EnumExt for #name #ty_generics #where_clause {
fn deserialize_variant<__R: #cratename::io::Read>(
reader: &mut __R,
variant_tag: u8,
) -> ::core::result::Result<Self, #cratename::io::Error> {
#deserialize_variant
}
}
}
} else {
quote! {}
};

Ok(quote! {
#deserialize

#impl_trait
})
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,38 @@
---
source: borsh-derive/src/internals/deserialize/enums/mod.rs
expression: pretty_print_syn_str(&actual).unwrap()
snapshot_kind: text
---
impl borsh::de::BorshDeserialize for X {
fn deserialize_reader<__R: borsh::io::Read>(
reader: &mut __R,
) -> ::core::result::Result<Self, borsh::io::Error> {
let tag = <u8 as borsh::de::BorshDeserialize>::deserialize_reader(reader)?;
<Self as borsh::de::EnumExt>::deserialize_variant(reader, tag)
let variant_tag = <u8 as borsh::de::BorshDeserialize>::deserialize_reader(
reader,
)?;
let mut return_value = if variant_tag == 0u8 {
X::A
} else if variant_tag == 1u8 {
X::B
} else if variant_tag == 2u8 {
X::C
} else if variant_tag == 3u8 {
X::D
} else if variant_tag == 4u8 {
X::E
} else if variant_tag == 5u8 {
X::F
} else {
return Err(
borsh::io::Error::new(
borsh::io::ErrorKind::InvalidData,
borsh::__private::maybestd::format!(
"Unexpected variant tag: {:?}", variant_tag
),
),
)
};
Ok(return_value)
}
}
impl borsh::de::EnumExt for X {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,38 @@
---
source: borsh-derive/src/internals/deserialize/enums/mod.rs
expression: pretty_print_syn_str(&actual).unwrap()
snapshot_kind: text
---
impl borsh::de::BorshDeserialize for X {
fn deserialize_reader<__R: borsh::io::Read>(
reader: &mut __R,
) -> ::core::result::Result<Self, borsh::io::Error> {
let tag = <u8 as borsh::de::BorshDeserialize>::deserialize_reader(reader)?;
<Self as borsh::de::EnumExt>::deserialize_variant(reader, tag)
let variant_tag = <u8 as borsh::de::BorshDeserialize>::deserialize_reader(
reader,
)?;
let mut return_value = if variant_tag == 0 {
X::A
} else if variant_tag == 20 {
X::B
} else if variant_tag == 20 + 1 {
X::C
} else if variant_tag == 20 + 1 + 1 {
X::D
} else if variant_tag == 10 {
X::E
} else if variant_tag == 10 + 1 {
X::F
} else {
return Err(
borsh::io::Error::new(
borsh::io::ErrorKind::InvalidData,
borsh::__private::maybestd::format!(
"Unexpected variant tag: {:?}", variant_tag
),
),
)
};
Ok(return_value)
}
}
impl borsh::de::EnumExt for X {
Expand Down
Loading
Loading