Skip to content

Commit 2f3ba19

Browse files
committed
WIP add skip_if macro attribute
WIP Default impl
1 parent 664714c commit 2f3ba19

File tree

10 files changed

+243
-33
lines changed

10 files changed

+243
-33
lines changed

protocol-derive/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ proc-macro = true
1616
# N.B. protocol-derive should not depend on the protocol crate.
1717
# This causes releasing to be a pain - which one first - neither is possible!
1818
[dependencies]
19-
syn = { version = "1.0.60", features = ["default", "extra-traits"] }
19+
syn = { version = "1.0.60", features = ["default", "extra-traits", "parsing"] }
2020
quote = "1.0.9"
2121
proc-macro2 = "1.0.24"
2222

protocol-derive/src/attr.rs

+86-12
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ use crate::format::{self, Format};
22

33
use proc_macro2::{Span, TokenStream};
44
use syn;
5+
use syn::{ExprPath, ExprBinary, ExprUnary, Expr};
6+
use quote::ToTokens;
57

68
#[derive(Debug)]
79
pub enum Protocol {
@@ -13,6 +15,35 @@ pub enum Protocol {
1315
prefix_subfield_names: Vec<syn::Ident>,
1416
},
1517
FixedLength(usize),
18+
SkipIf(SkipExpression),
19+
}
20+
21+
#[derive(Clone, Debug, PartialEq, Eq)]
22+
pub enum SkipExpression {
23+
PathExp(ExprPath),
24+
BinaryExp(ExprBinary),
25+
UnaryExp(ExprUnary),
26+
}
27+
28+
impl SkipExpression {
29+
pub fn parse_from(exp: &str) -> SkipExpression {
30+
let expr = syn::parse_str::<Expr>(exp).unwrap();
31+
32+
match expr {
33+
Expr::Binary(e) => SkipExpression::BinaryExp(e),
34+
Expr::Unary(e) => SkipExpression::UnaryExp(e),
35+
Expr::Path(e) => SkipExpression::PathExp(e),
36+
_ => panic!("Unexpected skip expression")
37+
}
38+
}
39+
40+
pub fn to_token_stream(&self) -> TokenStream {
41+
match self {
42+
SkipExpression::PathExp(e) => e.to_token_stream(),
43+
SkipExpression::BinaryExp(e) => e.to_token_stream(),
44+
SkipExpression::UnaryExp(ref e) => e.to_token_stream()
45+
}
46+
}
1647
}
1748

1849
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
@@ -37,7 +68,7 @@ pub fn repr(attrs: &[syn::Attribute]) -> Option<syn::Ident> {
3768
}
3869

3970
pub fn protocol(attrs: &[syn::Attribute])
40-
-> Option<Protocol> {
71+
-> Option<Protocol> {
4172
let meta_list = attrs.iter().filter_map(|attr| match attr.parse_meta() {
4273
Ok(syn::Meta::List(meta_list)) => {
4374
if meta_list.path.get_ident() == Some(&syn::Ident::new("protocol", proc_macro2::Span::call_site())) {
@@ -46,17 +77,31 @@ pub fn protocol(attrs: &[syn::Attribute])
4677
// Unrelated attribute.
4778
None
4879
}
49-
},
80+
}
5081
_ => None,
5182
}).next();
5283

53-
let meta_list: syn::MetaList = if let Some(meta_list) = meta_list { meta_list } else { return None };
84+
let meta_list: syn::MetaList = if let Some(meta_list) = meta_list { meta_list } else { return None; };
5485
let mut nested_metas = meta_list.nested.into_iter();
5586

5687
match nested_metas.next() {
5788
Some(syn::NestedMeta::Meta(syn::Meta::List(nested_list))) => {
5889
match &nested_list.path.get_ident().expect("meta is not an ident").to_string()[..] {
5990
// #[protocol(length_prefix(<kind>(<prefix field name>)))]
91+
"skip_if" => {
92+
let expression = expect::meta_list::single_element(nested_list).unwrap();
93+
let expression = match expression {
94+
syn::NestedMeta::Lit(syn::Lit::Str(s)) => {
95+
SkipExpression::parse_from(&s.value())
96+
}
97+
syn::NestedMeta::Meta(syn::Meta::Path(path)) => {
98+
todo!("Path literal not implemented yet")
99+
}
100+
_ => panic!("OH no! ! ")
101+
};
102+
103+
Some(Protocol::SkipIf(expression))
104+
}
60105
"fixed_length" => {
61106
let nested_list = expect::meta_list::single_literal(nested_list)
62107
.expect("expected a nested list");
@@ -71,7 +116,7 @@ pub fn protocol(attrs: &[syn::Attribute])
71116
}
72117
"length_prefix" => {
73118
let nested_list = expect::meta_list::nested_list(nested_list)
74-
.expect("expected a nested list");
119+
.expect("expected a nested list");
75120
let prefix_kind = match &nested_list.path.get_ident().expect("nested list is not an ident").to_string()[..] {
76121
"bytes" => LengthPrefixKind::Bytes,
77122
"elements" => LengthPrefixKind::Elements,
@@ -82,9 +127,9 @@ pub fn protocol(attrs: &[syn::Attribute])
82127
let (prefix_field_name, prefix_subfield_names) = match length_prefix_expr {
83128
syn::NestedMeta::Lit(syn::Lit::Str(s)) => {
84129
let mut parts: Vec<_> = s.value()
85-
.split(".")
86-
.map(|s| syn::Ident::new(s, Span::call_site()))
87-
.collect();
130+
.split(".")
131+
.map(|s| syn::Ident::new(s, Span::call_site()))
132+
.collect();
88133

89134
if parts.len() < 1 {
90135
panic!("there must be at least one field mentioned");
@@ -94,7 +139,7 @@ pub fn protocol(attrs: &[syn::Attribute])
94139
let subfield_idents = parts.into_iter().collect();
95140

96141
(field_ident, subfield_idents)
97-
},
142+
}
98143
syn::NestedMeta::Meta(syn::Meta::Path(path)) => match path.get_ident() {
99144
Some(field_ident) => (field_ident.clone(), Vec::new()),
100145
None => panic!("path is not an ident"),
@@ -103,15 +148,15 @@ pub fn protocol(attrs: &[syn::Attribute])
103148
};
104149

105150
Some(Protocol::LengthPrefix { kind: prefix_kind, prefix_field_name, prefix_subfield_names })
106-
},
151+
}
107152
"discriminator" => {
108153
let literal = expect::meta_list::single_literal(nested_list)
109-
.expect("expected a single literal");
154+
.expect("expected a single literal");
110155
Some(Protocol::Discriminator(literal))
111-
},
156+
}
112157
name => panic!("#[protocol({})] is not valid", name),
113158
}
114-
},
159+
}
115160
Some(syn::NestedMeta::Meta(syn::Meta::NameValue(name_value))) => {
116161
match name_value.path.get_ident() {
117162
Some(ident) => {
@@ -198,3 +243,32 @@ mod attribute {
198243
}
199244
}
200245

246+
#[cfg(test)]
247+
mod test {
248+
use crate::attr::SkipExpression;
249+
250+
#[test]
251+
fn should_parse_skip_expression() {
252+
let binary = "a == b";
253+
let parse_result = SkipExpression::parse_from(binary);
254+
assert!(matches!(parse_result, SkipExpression::BinaryExp(_)));
255+
256+
let unary = "!b";
257+
let parse_result = SkipExpression::parse_from(unary);
258+
assert!(matches!(parse_result, SkipExpression::UnaryExp(_)));
259+
260+
let path = "hello";
261+
let parse_result = SkipExpression::parse_from(path);
262+
assert!(matches!(parse_result, SkipExpression::PathExp(_)));
263+
}
264+
265+
#[test]
266+
fn should_convert_expression_to_token() {
267+
let binary = "a == b";
268+
let parse_result = SkipExpression::parse_from(binary);
269+
let tokens = parse_result.to_token_stream();
270+
let expression = quote! { #tokens };
271+
assert_eq!(expression.to_string(), "a == b");
272+
}
273+
}
274+

protocol-derive/src/codegen/enums.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ pub fn read_variant(plan: &plan::Enum)
4747
let discriminator_match_branches = plan.variants.iter().map(|variant| {
4848
let variant_name = &variant.ident;
4949
let discriminator_literal = variant.discriminator_literal();
50-
let initializer = codegen::read_fields(&variant.fields);
50+
let initializer = codegen::read_enum_fields(&variant.fields);
5151

5252
quote! {
5353
#discriminator_literal => {

protocol-derive/src/codegen/mod.rs

+82-6
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
11
use proc_macro2::TokenStream;
22
use syn;
3-
use syn::Field;
3+
use syn::{Field};
44

55
use crate::attr;
6+
use syn::spanned::Spanned;
67

78
pub mod enums;
89

9-
pub fn read_fields(fields: &syn::Fields)
10-
-> TokenStream {
10+
pub fn read_struct_field(fields: &syn::Fields)
11+
-> TokenStream {
1112
match *fields {
12-
syn::Fields::Named(ref fields_named) => read_named_fields(fields_named),
13+
syn::Fields::Named(ref fields_named) => read_named_fields_struct(fields_named),
14+
syn::Fields::Unnamed(ref fields_unnamed) => read_unnamed_fields(fields_unnamed),
15+
syn::Fields::Unit => quote!(),
16+
}
17+
}
18+
19+
pub fn read_enum_fields(fields: &syn::Fields)
20+
-> TokenStream {
21+
match *fields {
22+
syn::Fields::Named(ref fields_named) => read_named_fields_enum(fields_named),
1323
syn::Fields::Unnamed(ref fields_unnamed) => read_unnamed_fields(fields_unnamed),
1424
syn::Fields::Unit => quote!(),
1525
}
@@ -24,13 +34,56 @@ pub fn write_fields(fields: &syn::Fields)
2434
}
2535
}
2636

37+
pub fn name_fields_declarations(fields: &syn::Fields) -> TokenStream {
38+
if let syn::Fields::Named(ref fields_named) = fields {
39+
let fields_variables: Vec<TokenStream> = fields_named.named.iter().map(|field| {
40+
let field_name = &field.ident;
41+
let field_ty = &field.ty;
42+
// This field may store the length prefix of one or more other field.
43+
let update_hints = update_hints_after_read(field, &fields_named.named);
44+
let update_hints_fixed = update_hint_fixed_length(field, &fields_named.named);
45+
46+
if let Some(skip_condition) = maybe_skip(field.clone()) {
47+
quote! {
48+
#update_hints_fixed
49+
let skip_condition_path = if let Ok(skip_condition) = #skip_condition {
50+
skip_condition
51+
} else {
52+
false
53+
};
54+
55+
__hints.set_skip(skip_condition_path);
56+
let #field_name = protocol::Parcel::read_field(__io_reader, __settings, &mut __hints);
57+
let res = &#field_name;
58+
#update_hints
59+
__hints.next_field();
60+
}
61+
} else {
62+
quote! {
63+
#update_hints_fixed
64+
let #field_name: Result<#field_ty, _> = protocol::Parcel::read_field(__io_reader, __settings, &mut __hints);
65+
let res = &#field_name;
66+
#update_hints
67+
__hints.next_field();
68+
}
69+
}
70+
}).collect();
71+
72+
quote! {
73+
#( #fields_variables)*
74+
}
75+
} else {
76+
quote!()
77+
}
78+
}
79+
2780
/// Generates code that builds a initializes
2881
/// an item with named fields by parsing
2982
/// each of the fields.
3083
///
3184
/// Returns `{ ..field initializers.. }`.
32-
fn read_named_fields(fields_named: &syn::FieldsNamed)
33-
-> TokenStream {
85+
fn read_named_fields_enum(fields_named: &syn::FieldsNamed)
86+
-> TokenStream {
3487
let field_initializers: Vec<_> = fields_named.named.iter().map(|field| {
3588
let field_name = &field.ident;
3689
let field_ty = &field.ty;
@@ -52,6 +105,21 @@ fn read_named_fields(fields_named: &syn::FieldsNamed)
52105
quote! { { #( #field_initializers ),* } }
53106
}
54107

108+
/// Generates code that builds a initializes
109+
/// an item with named fields by parsing
110+
/// each of the fields.
111+
///
112+
/// Returns `{ ..field initializers.. }`.
113+
fn read_named_fields_struct(fields_named: &syn::FieldsNamed)
114+
-> TokenStream {
115+
let field_initializers: Vec<_> = fields_named.named.iter().map(|field| {
116+
let field_name = &field.ident;
117+
quote! { #field_name: #field_name? }
118+
}).collect();
119+
120+
quote! { { #( #field_initializers ),* } }
121+
}
122+
55123
fn update_hints_after_read<'a>(field: &'a syn::Field,
56124
fields: impl IntoIterator<Item=&'a syn::Field> + Clone)
57125
-> TokenStream {
@@ -89,6 +157,14 @@ fn update_hint_fixed_length<'a>(field: &'a syn::Field,
89157
}
90158
}
91159

160+
fn maybe_skip(field: syn::Field) -> Option<TokenStream> {
161+
if let Some(attr::Protocol::SkipIf(expr)) = attr::protocol(&field.attrs) {
162+
Some(expr.to_token_stream())
163+
} else {
164+
None
165+
}
166+
}
167+
92168
fn update_hints_after_write<'a>(field: &'a syn::Field,
93169
fields: impl IntoIterator<Item=&'a syn::Field> + Clone)
94170
-> TokenStream {

protocol-derive/src/lib.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,9 @@ fn build_generics(ast: &syn::DeriveInput) -> (Vec<proc_macro2::TokenStream>, Vec
7373
fn impl_parcel_for_struct(ast: &syn::DeriveInput,
7474
strukt: &syn::DataStruct) -> proc_macro2::TokenStream {
7575
let strukt_name = &ast.ident;
76-
let read_fields = codegen::read_fields(&strukt.fields);
76+
let read_fields = codegen::read_struct_field(&strukt.fields);
7777
let write_fields = codegen::write_fields(&strukt.fields);
78+
let named_field_variables = codegen::name_fields_declarations(&strukt.fields);
7879

7980
impl_trait_for(ast, quote!(protocol::Parcel), quote! {
8081
const TYPE_NAME: &'static str = stringify!(#strukt_name);
@@ -87,7 +88,7 @@ fn impl_parcel_for_struct(ast: &syn::DeriveInput,
8788
// Each type gets its own hints.
8889
let mut __hints = protocol::hint::Hints::default();
8990
__hints.begin_fields();
90-
91+
#named_field_variables
9192
Ok(#strukt_name # read_fields)
9293
}
9394

protocol/src/hint.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ pub type FieldIndex = usize;
55
/// Hints given when reading parcels.
66
#[derive(Clone, Debug, PartialEq)]
77
pub struct Hints {
8+
pub skip_hint: Option<bool>,
89
pub current_field_index: Option<FieldIndex>,
910
/// The fields for which a length prefix
1011
/// was already present earlier in the layout.
@@ -31,6 +32,7 @@ pub enum LengthPrefixKind {
3132
impl Default for Hints {
3233
fn default() -> Self {
3334
Hints {
35+
skip_hint: None,
3436
current_field_index: None,
3537
known_field_lengths: HashMap::new(),
3638
}
@@ -60,7 +62,7 @@ mod protocol_derive_helpers {
6062
#[doc(hidden)]
6163
pub fn next_field(&mut self) {
6264
*self.current_field_index.as_mut()
63-
.expect("cannot increment next field when not in a struct")+= 1;
65+
.expect("cannot increment next field when not in a struct") += 1;
6466
}
6567

6668
// Sets the length of a variable-sized field by its 0-based index.
@@ -71,6 +73,13 @@ mod protocol_derive_helpers {
7173
kind: LengthPrefixKind) {
7274
self.known_field_lengths.insert(field_index, FieldLength { kind, length });
7375
}
76+
77+
// A type skipped is assumed to be Option<T>, we need to set this to bypass
78+
// the default Option read method
79+
#[doc(hidden)]
80+
pub fn set_skip(&mut self, do_skip: bool) {
81+
self.skip_hint = Some(do_skip);
82+
}
7483
}
7584
}
7685

0 commit comments

Comments
 (0)