Skip to content

Commit c4cf0d9

Browse files
committed
Update derive field overwrite support
1 parent 9d2a978 commit c4cf0d9

9 files changed

+249
-33
lines changed

components/salsa-macros/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,12 @@ pub fn tracked(args: TokenStream, input: TokenStream) -> TokenStream {
8080
tracked::tracked(args, input)
8181
}
8282

83-
#[proc_macro_derive(Update)]
83+
#[proc_macro_derive(Update, attributes(update))]
8484
pub fn update(input: TokenStream) -> TokenStream {
8585
let item = parse_macro_input!(input as syn::DeriveInput);
8686
match update::update_derive(item) {
8787
Ok(tokens) => tokens.into(),
88-
Err(error) => token_stream_with_error(input, error),
88+
Err(error) => error.into_compile_error().into(),
8989
}
9090
}
9191

components/salsa-macros/src/update.rs

Lines changed: 82 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
use proc_macro2::{Literal, TokenStream};
2-
use syn::spanned::Spanned;
1+
use proc_macro2::{Literal, Span, TokenStream};
2+
use syn::{parenthesized, parse::ParseStream, spanned::Spanned, Token};
33
use synstructure::BindStyle;
44

55
use crate::hygiene::Hygiene;
66

77
pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream> {
88
let hygiene = Hygiene::from2(&input);
99

10-
if let syn::Data::Union(_) = &input.data {
10+
if let syn::Data::Union(u) = &input.data {
1111
return Err(syn::Error::new_spanned(
12-
&input.ident,
12+
u.union_token,
1313
"`derive(Update)` does not support `union`",
1414
));
1515
}
@@ -27,6 +27,24 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
2727
.variants()
2828
.iter()
2929
.map(|variant| {
30+
let err = variant
31+
.ast()
32+
.attrs
33+
.iter()
34+
.filter(|attr| attr.path().is_ident("update"))
35+
.map(|attr| {
36+
syn::Error::new(
37+
attr.path().span(),
38+
"unexpected attribute `#[update]` on variant",
39+
)
40+
})
41+
.reduce(|mut acc, err| {
42+
acc.combine(err);
43+
acc
44+
});
45+
if let Some(err) = err {
46+
return Err(err);
47+
}
3048
let variant_pat = variant.pat();
3149

3250
// First check that the `new_value` has same variant.
@@ -35,7 +53,7 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
3553
.bindings()
3654
.iter()
3755
.fold(quote!(), |tokens, binding| quote!(#tokens #binding,));
38-
let make_new_value = quote_spanned! {variant.ast().ident.span()=>
56+
let make_new_value = quote! {
3957
let #new_value = if let #variant_pat = #new_value {
4058
(#make_tuple)
4159
} else {
@@ -47,40 +65,73 @@ pub(crate) fn update_derive(input: syn::DeriveInput) -> syn::Result<TokenStream>
4765
// For each field, invoke `maybe_update` recursively to update its value.
4866
// Or the results together (using `|`, not `||`, to avoid shortcircuiting)
4967
// to get the final return value.
50-
let update_fields = variant.bindings().iter().enumerate().fold(
51-
quote!(false),
52-
|tokens, (index, binding)| {
53-
let field_ty = &binding.ast().ty;
54-
let field_index = Literal::usize_unsuffixed(index);
55-
56-
let field_span = binding
57-
.ast()
58-
.ident
59-
.as_ref()
60-
.map(Spanned::span)
61-
.unwrap_or(binding.ast().span());
62-
63-
let update_field = quote_spanned! {field_span=>
64-
salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update(
65-
#binding,
66-
#new_value.#field_index,
67-
)
68-
};
68+
let mut update_fields = quote!(false);
69+
for (index, binding) in variant.bindings().iter().enumerate() {
70+
let mut attrs = binding
71+
.ast()
72+
.attrs
73+
.iter()
74+
.filter(|attr| attr.path().is_ident("update"));
75+
let attr = attrs.next();
76+
if let Some(attr) = attrs.next() {
77+
return Err(syn::Error::new(
78+
attr.path().span(),
79+
"multiple #[update(with)] attributes on field",
80+
));
81+
}
6982

70-
quote! {
71-
#tokens | unsafe { #update_field }
83+
let field_ty = &binding.ast().ty;
84+
let field_index = Literal::usize_unsuffixed(index);
85+
86+
let (maybe_update, unsafe_token) = match attr {
87+
Some(attr) => {
88+
mod kw {
89+
syn::custom_keyword!(with);
90+
}
91+
attr.parse_args_with(|parser: ParseStream| {
92+
let mut content;
93+
94+
let unsafe_token = parser.parse::<Token![unsafe]>()?;
95+
parenthesized!(content in parser);
96+
let with_token = content.parse::<kw::with>()?;
97+
parenthesized!(content in content);
98+
let expr = content.parse::<syn::Expr>()?;
99+
Ok((
100+
quote_spanned! { with_token.span() => ({ let maybe_update: unsafe fn(*mut #field_ty, #field_ty) -> bool = #expr; maybe_update }) },
101+
// quote_spanned! { with_token.span() => ((#expr) as unsafe fn(*mut #field_ty, #field_ty) -> bool) },
102+
unsafe_token,
103+
))
104+
})?
72105
}
73-
},
74-
);
106+
None => {
107+
(
108+
quote!(
109+
salsa::plumbing::UpdateDispatch::<#field_ty>::maybe_update
110+
),
111+
Token![unsafe](Span::call_site()),
112+
)
113+
}
114+
};
115+
let update_field = quote! {
116+
#maybe_update(
117+
#binding,
118+
#new_value.#field_index,
119+
)
120+
};
121+
122+
update_fields = quote! {
123+
#update_fields | #unsafe_token { #update_field }
124+
};
125+
}
75126

76-
quote!(
127+
Ok(quote!(
77128
#variant_pat => {
78129
#make_new_value
79130
#update_fields
80131
}
81-
)
132+
))
82133
})
83-
.collect();
134+
.collect::<syn::Result<_>>()?;
84135

85136
let ident = &input.ident;
86137
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#[derive(salsa::Update)]
2+
union U {
3+
field: i32,
4+
}
5+
6+
#[derive(salsa::Update)]
7+
struct S {
8+
#[update(with(missing_unsafe))]
9+
bad: i32,
10+
}
11+
12+
fn missing_unsafe(_: *mut i32, _: i32) -> bool {
13+
true
14+
}
15+
16+
fn main() {}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
error: `derive(Update)` does not support `union`
2+
--> tests/compile-fail/derive_update_expansion_failure.rs:2:1
3+
|
4+
2 | union U {
5+
| ^^^^^
6+
7+
error: expected `unsafe`
8+
--> tests/compile-fail/derive_update_expansion_failure.rs:8:14
9+
|
10+
8 | #[update(with(missing_unsafe))]
11+
| ^^^^
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#[derive(salsa::Update)]
2+
struct S2<'a> {
3+
bad2: &'a str,
4+
}
5+
6+
fn main() {}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
error: lifetime may not live long enough
2+
--> tests/compile-fail/invalid_update_field.rs:1:10
3+
|
4+
1 | #[derive(salsa::Update)]
5+
| ^^^^^^^^^^^^^ requires that `'a` must outlive `'static`
6+
2 | struct S2<'a> {
7+
| -- lifetime `'a` defined here
8+
|
9+
= note: this error originates in the derive macro `salsa::Update` (in Nightly builds, run with -Z macro-backtrace for more info)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#[derive(salsa::Update)]
2+
struct S2 {
3+
#[update(unsafe(with(my_wrong_update)))]
4+
bad: i32,
5+
#[update(unsafe(with(my_wrong_update2)))]
6+
bad2: i32,
7+
#[update(unsafe(with(my_wrong_update3)))]
8+
bad3: i32,
9+
#[update(unsafe(with(true)))]
10+
bad4: &'static str,
11+
}
12+
13+
fn my_wrong_update() {}
14+
fn my_wrong_update2(_: (), _: ()) -> bool {
15+
true
16+
}
17+
fn my_wrong_update3(_: *mut i32, _: i32) -> () {}
18+
19+
fn main() {}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
error[E0308]: mismatched types
2+
--> tests/compile-fail/invalid_update_with.rs:3:26
3+
|
4+
3 | #[update(unsafe(with(my_wrong_update)))]
5+
| ---- ^^^^^^^^^^^^^^^ incorrect number of function parameters
6+
| |
7+
| expected due to this
8+
|
9+
= note: expected fn pointer `unsafe fn(*mut i32, i32) -> bool`
10+
found fn item `fn() -> () {my_wrong_update}`
11+
12+
error[E0308]: mismatched types
13+
--> tests/compile-fail/invalid_update_with.rs:5:26
14+
|
15+
5 | #[update(unsafe(with(my_wrong_update2)))]
16+
| ---- ^^^^^^^^^^^^^^^^ expected fn pointer, found fn item
17+
| |
18+
| expected due to this
19+
|
20+
= note: expected fn pointer `unsafe fn(*mut i32, i32) -> bool`
21+
found fn item `fn((), ()) -> bool {my_wrong_update2}`
22+
23+
error[E0308]: mismatched types
24+
--> tests/compile-fail/invalid_update_with.rs:7:26
25+
|
26+
7 | #[update(unsafe(with(my_wrong_update3)))]
27+
| ---- ^^^^^^^^^^^^^^^^ expected fn pointer, found fn item
28+
| |
29+
| expected due to this
30+
|
31+
= note: expected fn pointer `unsafe fn(*mut i32, i32) -> bool`
32+
found fn item `fn(*mut i32, i32) -> () {my_wrong_update3}`
33+
34+
error[E0308]: mismatched types
35+
--> tests/compile-fail/invalid_update_with.rs:9:26
36+
|
37+
9 | #[update(unsafe(with(true)))]
38+
| ---- ^^^^ expected fn pointer, found `bool`
39+
| |
40+
| expected due to this
41+
|
42+
= note: expected fn pointer `unsafe fn(*mut &'static str, &'static str) -> bool`
43+
found type `bool`

tests/derive_update.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
//! Test that the `Update` derive works as expected
2+
3+
#[derive(salsa::Update)]
4+
struct MyInput {
5+
field: &'static str,
6+
}
7+
8+
#[derive(salsa::Update)]
9+
struct MyInput2 {
10+
#[update(unsafe(with(custom_update)))]
11+
field: &'static str,
12+
#[update(unsafe(with(|dest, data| { *dest = data; true })))]
13+
field2: &'static str,
14+
}
15+
16+
unsafe fn custom_update(dest: *mut &'static str, _data: &'static str) -> bool {
17+
unsafe { *dest = "ill-behaved for testing purposes" };
18+
true
19+
}
20+
21+
#[test]
22+
fn derived() {
23+
let mut m = MyInput { field: "foo" };
24+
assert_eq!(m.field, "foo");
25+
assert!(unsafe { salsa::Update::maybe_update(&mut m, MyInput { field: "bar" }) });
26+
assert_eq!(m.field, "bar");
27+
assert!(!unsafe { salsa::Update::maybe_update(&mut m, MyInput { field: "bar" }) });
28+
assert_eq!(m.field, "bar");
29+
}
30+
31+
#[test]
32+
fn derived_with() {
33+
let mut m = MyInput2 {
34+
field: "foo",
35+
field2: "foo",
36+
};
37+
assert_eq!(m.field, "foo");
38+
assert_eq!(m.field2, "foo");
39+
assert!(unsafe {
40+
salsa::Update::maybe_update(
41+
&mut m,
42+
MyInput2 {
43+
field: "bar",
44+
field2: "bar",
45+
},
46+
)
47+
});
48+
assert_eq!(m.field, "ill-behaved for testing purposes");
49+
assert_eq!(m.field2, "bar");
50+
assert!(unsafe {
51+
salsa::Update::maybe_update(
52+
&mut m,
53+
MyInput2 {
54+
field: "ill-behaved for testing purposes",
55+
field2: "foo",
56+
},
57+
)
58+
});
59+
assert_eq!(m.field, "ill-behaved for testing purposes");
60+
assert_eq!(m.field2, "foo");
61+
}

0 commit comments

Comments
 (0)