Skip to content

Commit 8f87b86

Browse files
authored
refactor #[setter] argument extraction (#4002)
1 parent 63ba371 commit 8f87b86

File tree

5 files changed

+97
-70
lines changed

5 files changed

+97
-70
lines changed

pyo3-macros-backend/src/method.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ impl<'a> FnArg<'a> {
6363
}
6464
}
6565
}
66+
67+
pub fn is_regular(&self) -> bool {
68+
!self.py && !self.is_cancel_handle && !self.is_kwargs && !self.is_varargs
69+
}
6670
}
6771

6872
fn handle_argument_error(pat: &syn::Pat) -> syn::Error {

pyo3-macros-backend/src/params.rs

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ pub fn is_forwarded_args(signature: &FunctionSignature<'_>) -> bool {
7373
)
7474
}
7575

76-
fn check_arg_for_gil_refs(
76+
pub(crate) fn check_arg_for_gil_refs(
7777
tokens: TokenStream,
7878
gil_refs_checker: syn::Ident,
7979
ctx: &Ctx,
@@ -120,7 +120,11 @@ pub fn impl_arg_params(
120120
.iter()
121121
.enumerate()
122122
.map(|(i, arg)| {
123-
impl_arg_param(arg, i, &mut 0, &args_array, holders, ctx).map(|tokens| {
123+
let from_py_with =
124+
syn::Ident::new(&format!("from_py_with_{}", i), Span::call_site());
125+
let arg_value = quote!(#args_array[0].as_deref());
126+
127+
impl_arg_param(arg, from_py_with, arg_value, holders, ctx).map(|tokens| {
124128
check_arg_for_gil_refs(
125129
tokens,
126130
holders.push_gil_refs_checker(arg.ty.span()),
@@ -161,14 +165,20 @@ pub fn impl_arg_params(
161165

162166
let num_params = positional_parameter_names.len() + keyword_only_parameters.len();
163167

164-
let mut option_pos = 0;
168+
let mut option_pos = 0usize;
165169
let param_conversion = spec
166170
.signature
167171
.arguments
168172
.iter()
169173
.enumerate()
170174
.map(|(i, arg)| {
171-
impl_arg_param(arg, i, &mut option_pos, &args_array, holders, ctx).map(|tokens| {
175+
let from_py_with = syn::Ident::new(&format!("from_py_with_{}", i), Span::call_site());
176+
let arg_value = quote!(#args_array[#option_pos].as_deref());
177+
if arg.is_regular() {
178+
option_pos += 1;
179+
}
180+
181+
impl_arg_param(arg, from_py_with, arg_value, holders, ctx).map(|tokens| {
172182
check_arg_for_gil_refs(tokens, holders.push_gil_refs_checker(arg.ty.span()), ctx)
173183
})
174184
})
@@ -234,11 +244,10 @@ pub fn impl_arg_params(
234244

235245
/// Re option_pos: The option slice doesn't contain the py: Python argument, so the argument
236246
/// index and the index in option diverge when using py: Python
237-
fn impl_arg_param(
247+
pub(crate) fn impl_arg_param(
238248
arg: &FnArg<'_>,
239-
pos: usize,
240-
option_pos: &mut usize,
241-
args_array: &syn::Ident,
249+
from_py_with: syn::Ident,
250+
arg_value: TokenStream, // expected type: Option<&'a Bound<'py, PyAny>>
242251
holders: &mut Holders,
243252
ctx: &Ctx,
244253
) -> Result<TokenStream> {
@@ -291,9 +300,6 @@ fn impl_arg_param(
291300
});
292301
}
293302

294-
let arg_value = quote_arg_span!(#args_array[#option_pos]);
295-
*option_pos += 1;
296-
297303
let mut default = arg.default.as_ref().map(|expr| quote!(#expr));
298304

299305
// Option<T> arguments have special treatment: the default should be specified _without_ the
@@ -312,11 +318,10 @@ fn impl_arg_param(
312318
.map(|attr| &attr.value)
313319
.is_some()
314320
{
315-
let from_py_with = syn::Ident::new(&format!("from_py_with_{}", pos), Span::call_site());
316321
if let Some(default) = default {
317322
quote_arg_span! {
318323
#pyo3_path::impl_::extract_argument::from_py_with_with_default(
319-
#arg_value.as_deref(),
324+
#arg_value,
320325
#name_str,
321326
#from_py_with as fn(_) -> _,
322327
#[allow(clippy::redundant_closure)]
@@ -328,7 +333,7 @@ fn impl_arg_param(
328333
} else {
329334
quote_arg_span! {
330335
#pyo3_path::impl_::extract_argument::from_py_with(
331-
&#pyo3_path::impl_::extract_argument::unwrap_required_argument(#arg_value),
336+
#pyo3_path::impl_::extract_argument::unwrap_required_argument(#arg_value),
332337
#name_str,
333338
#from_py_with as fn(_) -> _,
334339
)?
@@ -338,7 +343,7 @@ fn impl_arg_param(
338343
let holder = holders.push_holder(arg.name.span());
339344
quote_arg_span! {
340345
#pyo3_path::impl_::extract_argument::extract_optional_argument(
341-
#arg_value.as_deref(),
346+
#arg_value,
342347
&mut #holder,
343348
#name_str,
344349
#[allow(clippy::redundant_closure)]
@@ -351,7 +356,7 @@ fn impl_arg_param(
351356
let holder = holders.push_holder(arg.name.span());
352357
quote_arg_span! {
353358
#pyo3_path::impl_::extract_argument::extract_argument_with_default(
354-
#arg_value.as_deref(),
359+
#arg_value,
355360
&mut #holder,
356361
#name_str,
357362
#[allow(clippy::redundant_closure)]
@@ -364,7 +369,7 @@ fn impl_arg_param(
364369
let holder = holders.push_holder(arg.name.span());
365370
quote_arg_span! {
366371
#pyo3_path::impl_::extract_argument::extract_argument(
367-
&#pyo3_path::impl_::extract_argument::unwrap_required_argument(#arg_value),
372+
#pyo3_path::impl_::extract_argument::unwrap_required_argument(#arg_value),
368373
&mut #holder,
369374
#name_str
370375
)?

pyo3-macros-backend/src/pymethod.rs

Lines changed: 56 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::borrow::Cow;
22

33
use crate::attributes::{NameAttribute, RenamingRule};
44
use crate::method::{CallingConvention, ExtractErrorMode};
5-
use crate::params::Holders;
5+
use crate::params::{check_arg_for_gil_refs, impl_arg_param, Holders};
66
use crate::utils::Ctx;
77
use crate::utils::PythonDoc;
88
use crate::{
@@ -586,48 +586,63 @@ pub fn impl_py_setter_def(
586586
}
587587
};
588588

589-
// TODO: rework this to make use of `impl_::params::impl_arg_param` which
590-
// handles all these cases already.
591-
let extract = if let PropertyType::Function { spec, .. } = &property_type {
592-
Some(spec)
593-
} else {
594-
None
595-
}
596-
.and_then(|spec| {
597-
let (_, args) = split_off_python_arg(&spec.signature.arguments);
598-
let value_arg = &args[0];
599-
let from_py_with = &value_arg.attrs.from_py_with.as_ref()?.value;
600-
let name = value_arg.name.to_string();
601-
602-
Some(quote_spanned! { from_py_with.span() =>
603-
let e = #pyo3_path::impl_::deprecations::GilRefs::new();
604-
let from_py_with = #pyo3_path::impl_::deprecations::inspect_fn(#from_py_with, &e);
605-
e.from_py_with_arg();
606-
let _val = #pyo3_path::impl_::extract_argument::from_py_with(
607-
&_value.into(),
608-
#name,
609-
from_py_with as fn(_) -> _,
610-
)?;
611-
})
612-
})
613-
.unwrap_or_else(|| {
614-
let (span, name) = match &property_type {
615-
PropertyType::Descriptor { field, .. } => (field.ty.span(), field.ident.as_ref().map(|i|i.to_string()).unwrap_or_default()),
616-
PropertyType::Function { spec, .. } => {
617-
let (_, args) = split_off_python_arg(&spec.signature.arguments);
618-
(args[0].ty.span(), args[0].name.to_string())
619-
}
620-
};
589+
let extract = match &property_type {
590+
PropertyType::Function { spec, .. } => {
591+
let (_, args) = split_off_python_arg(&spec.signature.arguments);
592+
let value_arg = &args[0];
593+
let (from_py_with, ident) = if let Some(from_py_with) =
594+
&value_arg.attrs.from_py_with.as_ref().map(|f| &f.value)
595+
{
596+
let ident = syn::Ident::new("from_py_with", from_py_with.span());
597+
(
598+
quote_spanned! { from_py_with.span() =>
599+
let e = #pyo3_path::impl_::deprecations::GilRefs::new();
600+
let #ident = #pyo3_path::impl_::deprecations::inspect_fn(#from_py_with, &e);
601+
e.from_py_with_arg();
602+
},
603+
ident,
604+
)
605+
} else {
606+
(quote!(), syn::Ident::new("dummy", Span::call_site()))
607+
};
621608

622-
let holder = holders.push_holder(span);
623-
let gil_refs_checker = holders.push_gil_refs_checker(span);
624-
quote! {
625-
let _val = #pyo3_path::impl_::deprecations::inspect_type(
626-
#pyo3_path::impl_::extract_argument::extract_argument(_value.into(), &mut #holder, #name)?,
627-
&#gil_refs_checker
628-
);
609+
let extract = impl_arg_param(
610+
&args[0],
611+
ident,
612+
quote!(::std::option::Option::Some(_value.into())),
613+
&mut holders,
614+
ctx,
615+
)
616+
.map(|tokens| {
617+
check_arg_for_gil_refs(
618+
tokens,
619+
holders.push_gil_refs_checker(value_arg.ty.span()),
620+
ctx,
621+
)
622+
})?;
623+
quote! {
624+
#from_py_with
625+
let _val = #extract;
626+
}
627+
}
628+
PropertyType::Descriptor { field, .. } => {
629+
let span = field.ty.span();
630+
let name = field
631+
.ident
632+
.as_ref()
633+
.map(|i| i.to_string())
634+
.unwrap_or_default();
635+
636+
let holder = holders.push_holder(span);
637+
let gil_refs_checker = holders.push_gil_refs_checker(span);
638+
quote! {
639+
let _val = #pyo3_path::impl_::deprecations::inspect_type(
640+
#pyo3_path::impl_::extract_argument::extract_argument(_value.into(), &mut #holder, #name)?,
641+
&#gil_refs_checker
642+
);
643+
}
629644
}
630-
});
645+
};
631646

632647
let mut cfg_attrs = TokenStream::new();
633648
if let PropertyType::Descriptor { field, .. } = &property_type {

src/impl_/extract_argument.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,9 @@ pub fn argument_extraction_error(py: Python<'_>, arg_name: &str, error: PyErr) -
223223
/// `argument` must not be `None`
224224
#[doc(hidden)]
225225
#[inline]
226-
pub unsafe fn unwrap_required_argument(argument: Option<PyArg<'_>>) -> PyArg<'_> {
226+
pub unsafe fn unwrap_required_argument<'a, 'py>(
227+
argument: Option<&'a Bound<'py, PyAny>>,
228+
) -> &'a Bound<'py, PyAny> {
227229
match argument {
228230
Some(value) => value,
229231
#[cfg(debug_assertions)]

tests/ui/static_ref.stderr

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,18 @@ error: lifetime may not live long enough
99
|
1010
= note: this error originates in the attribute macro `pyfunction` (in Nightly builds, run with -Z macro-backtrace for more info)
1111

12+
error[E0597]: `output[_]` does not live long enough
13+
--> tests/ui/static_ref.rs:4:1
14+
|
15+
4 | #[pyfunction]
16+
| ^^^^^^^^^^^^-
17+
| | |
18+
| | `output[_]` dropped here while still borrowed
19+
| borrowed value does not live long enough
20+
| argument requires that `output[_]` is borrowed for `'static`
21+
|
22+
= note: this error originates in the attribute macro `pyfunction` (in Nightly builds, run with -Z macro-backtrace for more info)
23+
1224
error[E0597]: `holder_0` does not live long enough
1325
--> tests/ui/static_ref.rs:5:15
1426
|
@@ -21,17 +33,6 @@ error[E0597]: `holder_0` does not live long enough
2133
5 | fn static_ref(list: &'static Bound<'_, PyList>) -> usize {
2234
| ^^^^^^^ borrowed value does not live long enough
2335

24-
error[E0716]: temporary value dropped while borrowed
25-
--> tests/ui/static_ref.rs:5:21
26-
|
27-
4 | #[pyfunction]
28-
| -------------
29-
| | |
30-
| | temporary value is freed at the end of this statement
31-
| argument requires that borrow lasts for `'static`
32-
5 | fn static_ref(list: &'static Bound<'_, PyList>) -> usize {
33-
| ^ creates a temporary value which is freed while still in use
34-
3536
error: lifetime may not live long enough
3637
--> tests/ui/static_ref.rs:9:1
3738
|

0 commit comments

Comments
 (0)