Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion utils/yoke/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::quote;
use syn::spanned::Spanned;
use syn::{parse_macro_input, parse_quote, DeriveInput, Ident, Lifetime, Type, WherePredicate};
use syn::{
parse_macro_input, parse_quote, DeriveInput, Ident, Lifetime, Path, Type, WherePredicate,
};
use synstructure::Structure;

mod visitor;
Expand Down Expand Up @@ -93,7 +95,11 @@ fn yokeable_derive_impl(input: &DeriveInput) -> TokenStream2 {
.to_compile_error();
}
let name = &input.ident;
let attr_path: Path = syn::parse_str("yoke").unwrap();
let manual_covariance = input.attrs.iter().any(|a| {
if a.path != attr_path {
return false;
}
if let Ok(i) = a.parse_args::<Ident>() {
if i == "prove_covariance_manually" {
return true;
Expand Down
52 changes: 50 additions & 2 deletions utils/zerovec/derive/examples/derives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,20 @@

use zerovec::ule::AsULE;
use zerovec::ule::EncodeAsVarULE;
use zerovec::ule::VarULE;
use zerovec::ule::ULE;
use zerovec::*;

fn validate_foo_ule(ule: &FooULE) -> Result<(), ZeroVecError> {
if ule.a == 0 {
return Err(ZeroVecError::parse::<Foo>());
}
Ok(())
}

#[repr(packed)]
#[derive(ule::ULE, Copy, Clone)]
#[derive(ule::ULE, Copy, Clone, Debug)]
#[ule(validate_with = "validate_foo_ule")]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this should be zerovec::ule. we do that consistently elsewhere.

It's also fine for it to be zerovec::validate_with, or zerovec(validate_with =) though that last pattern isn't used so far in this crate

pub struct FooULE {
a: u8,
b: <u32 as AsULE>::ULE,
Expand Down Expand Up @@ -40,8 +50,17 @@ impl AsULE for Foo {
}
}

fn validate_relation_ule(ule: &RelationULE) -> Result<(), ZeroVecError> {
// A synthetic condition to test this code path
if ule.andor_polarity_operand == 0 && ule.modulo.as_unsigned_int() == 0 {
return Err(ZeroVecError::parse::<Relation>());
}
Ok(())
}

#[repr(packed)]
#[derive(ule::VarULE)]
#[derive(ule::VarULE, Debug)]
#[varule(validate_with = "validate_relation_ule")]
pub struct RelationULE {
/// This maps to (AndOr, Polarity, Operand),
/// with the first bit mapping to AndOr (1 == And), the second bit
Expand Down Expand Up @@ -110,6 +129,13 @@ const TEST_SLICE2: &[Foo] = &[
c: '±',
},
];

const BAD_SLICE: &[Foo] = &[Foo {
a: 0,
b: 0,
c: '\0',
}];

fn test_zerovec() {
let zerovec: ZeroVec<Foo> = TEST_SLICE.iter().copied().collect();

Expand All @@ -119,6 +145,10 @@ fn test_zerovec() {
let reparsed: ZeroVec<Foo> = ZeroVec::parse_byte_slice(bytes).expect("Parsing should succeed");

assert_eq!(reparsed, TEST_SLICE);

let bad_foo_vec: ZeroVec<Foo> = BAD_SLICE.iter().copied().collect();
let bad_bytes = bad_foo_vec.as_bytes();
FooULE::parse_byte_slice(bad_bytes).expect_err("Should fail validation");
}

fn test_varzerovec() {
Expand Down Expand Up @@ -149,6 +179,24 @@ fn test_varzerovec() {
for (ule, stack) in recovered.iter().zip(relations.iter()) {
assert_eq!(*stack, ule.as_relation());
}

let bad_relation = Relation {
andor_polarity_operand: 0,
modulo: 0,
range_list: ZeroVec::default(),
};
let bad_relationule = zerovec::ule::encode_varule_to_box(&bad_relation);
let bad_bytes = bad_relationule.as_byte_slice();
RelationULE::parse_byte_slice(bad_bytes).expect_err("Should fail validation");

let bad_relation = Relation {
andor_polarity_operand: 1,
modulo: 5004,
range_list: BAD_SLICE.iter().copied().collect(),
};
let bad_relationule = zerovec::ule::encode_varule_to_box(&bad_relation);
let bad_bytes = bad_relationule.as_byte_slice();
RelationULE::parse_byte_slice(bad_bytes).expect_err("Should fail validation");
}

fn main() {
Expand Down
4 changes: 2 additions & 2 deletions utils/zerovec/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ mod utils;
mod varule;

/// Full docs for this proc macro can be found on the [`zerovec`](docs.rs/zerovec) crate.
#[proc_macro_derive(ULE)]
#[proc_macro_derive(ULE, attributes(ule))]
pub fn ule_derive(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
TokenStream::from(ule::derive_impl(&input))
}

/// Full docs for this proc macro can be found on the [`zerovec`](docs.rs/zerovec) crate.
#[proc_macro_derive(VarULE)]
#[proc_macro_derive(VarULE, attributes(varule))]
pub fn varule_derive(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
TokenStream::from(varule::derive_impl(&input, None))
Expand Down
16 changes: 15 additions & 1 deletion utils/zerovec/derive/src/ule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use quote::quote;

use crate::utils::{self, FieldInfo};
use syn::spanned::Spanned;
use syn::{Data, DeriveInput, Error};
use syn::{Data, DeriveInput, Error, Path};

pub fn derive_impl(input: &DeriveInput) -> TokenStream2 {
if !utils::has_valid_repr(&input.attrs, |r| r == "packed" || r == "transparent") {
Expand Down Expand Up @@ -46,6 +46,19 @@ pub fn derive_impl(input: &DeriveInput) -> TokenStream2 {

let name = &input.ident;

let attr_path: Path = syn::parse_str("ule").unwrap();
let validate_with = match utils::find_validate_with_path(&input.attrs, &attr_path) {
Ok(Some(fn_path)) => {
quote! {
#fn_path(unsafe { &Self::from_byte_slice_unchecked(chunk)[0] })?;
}
}
Ok(None) => {
quote! {}
}
Err(e) => return e.to_compile_error(),
};

// Safety (based on the safety checklist on the ULE trait):
// 1. #name does not include any uninitialized or padding bytes.
// (achieved by enforcing #[repr(transparent)] or #[repr(packed)] on a struct of only ULE types)
Expand All @@ -68,6 +81,7 @@ pub fn derive_impl(input: &DeriveInput) -> TokenStream2 {
#[allow(clippy::indexing_slicing)] // We're slicing a chunk of known size
for chunk in bytes.chunks_exact(SIZE) {
#validators
#validate_with
debug_assert_eq!(#remaining_offset, SIZE);
}
Ok(())
Expand Down
45 changes: 44 additions & 1 deletion utils/zerovec/derive/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ use proc_macro2::TokenStream as TokenStream2;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::{parenthesized, parse2, Attribute, Error, Field, Fields, Ident, Index, Result, Token};
use syn::{
parenthesized, parse2, Attribute, Error, Field, Fields, Ident, Index, Lit, Meta, MetaNameValue,
Path, Result, Token,
};

// Check that there are repr attributes satisfying the given predicate
pub fn has_valid_repr(attrs: &[Attribute], predicate: impl Fn(&Ident) -> bool + Copy) -> bool {
Expand Down Expand Up @@ -275,3 +278,43 @@ pub fn extract_attributes_common(

Ok(attrs)
}

pub(crate) fn find_validate_with_path(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue: This should be a part of the attribute parsing above in extract_attributes_common

attrs: &[Attribute],
attr_path: &Path,
) -> Result<Option<Path>> {
let mut validate_with: Option<Path> = None;
let validate_with_path: Path = syn::parse_str("validate_with").unwrap();
for attr in attrs.iter() {
if &attr.path != attr_path {
continue;
}
match attr.parse_args::<Meta>() {
Ok(Meta::NameValue(MetaNameValue {
path,
lit: Lit::Str(lit_str),
..
})) if path == validate_with_path => {
if validate_with.is_some() {
return Err(Error::new(attr.span(), "multiple varule are not allowed"));
}
validate_with = match syn::parse_str(&lit_str.value()) {
Ok(p) => Some(p),
Err(_) => {
return Err(Error::new(
attr.span(),
"varule value must be a path to a function",
));
}
}
}
_ => {
return Err(Error::new(
attr.span(),
"varule takes a single name value, validate_with = \"fn_path\"",
));
}
}
}
Ok(validate_with)
}
16 changes: 15 additions & 1 deletion utils/zerovec/derive/src/varule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use proc_macro2::Span;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::spanned::Spanned;
use syn::{Data, DeriveInput, Error, Ident};
use syn::{Data, DeriveInput, Error, Ident, Path};

/// Implementation for derive(VarULE). `custom_varule_validator` validates the last field bytes `last_field_bytes`
/// if specified, if not, the VarULE implementation will be used.
Expand Down Expand Up @@ -86,6 +86,19 @@ pub fn derive_impl(
quote!(<#unsized_field as zerovec::ule::VarULE>::validate_byte_slice(last_field_bytes)?;)
};

let attr_path: Path = syn::parse_str("varule").unwrap();
let validate_with = match utils::find_validate_with_path(&input.attrs, &attr_path) {
Ok(Some(fn_path)) => {
quote! {
#fn_path(unsafe { Self::from_byte_slice_unchecked(bytes) })?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thought: for ULE this should be free but for VarULE this will be a couple extra instructions. i suppose we don't care about that?

this way of doing it does make the derive code simpler.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are you comparing against?

Note that the case where validate_with is not used, there is no code being added.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, sorry, here we construct an &Self (involving a minor amount of ptr metadata arithmetic) and immediately throw it away, we could theoretically merge it with the last_field_validator but doing that consistently and well will be a mess so it doesn't really matter anyway

}
}
Ok(None) => {
quote! {}
}
Err(e) => return e.to_compile_error(),
};

// Safety (based on the safety checklist on the ULE trait):
// 1. #name does not include any uninitialized or padding bytes
// (achieved by enforcing #[repr(transparent)] or #[repr(packed)] on a struct of only ULE types)
Expand All @@ -111,6 +124,7 @@ pub fn derive_impl(
#[allow(clippy::indexing_slicing)] // TODO explain
let last_field_bytes = &bytes[#remaining_offset..];
#last_field_validator
#validate_with
Ok(())
}
#[inline]
Expand Down