Skip to content

Commit 7f94f24

Browse files
XuehaiPanrwgk
andauthored
feat(typing): allow annotate methods with pos_only when only have the self argument (#5403)
* feat: allow annotate methods with `pos_only` when only have the `self` argument * chore(typing): make arguments for auto-generated dunder methods positional-only * docs: add more comments to improve readability * style: fix nit suggestions * Add test_self_only_pos_only() in tests/test_methods_and_attributes * test: add docstring tests for generated dunder methods * test: remove failed tests * fix(test): run `gc.collect()` three times for refcount tests --------- Co-authored-by: Ralf W. Grosse-Kunstleve <[email protected]>
1 parent 6d98d4d commit 7f94f24

File tree

8 files changed

+155
-25
lines changed

8 files changed

+155
-25
lines changed

include/pybind11/detail/init.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ struct pickle_factory<Get, Set, RetState(Self), NewInstance(ArgState)> {
410410

411411
template <typename Class, typename... Extra>
412412
void execute(Class &cl, const Extra &...extra) && {
413-
cl.def("__getstate__", std::move(get));
413+
cl.def("__getstate__", std::move(get), pos_only());
414414

415415
#if defined(PYBIND11_CPP14)
416416
cl.def(

include/pybind11/pybind11.h

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -301,9 +301,20 @@ class cpp_function : public function {
301301
constexpr bool has_kw_only_args = any_of<std::is_same<kw_only, Extra>...>::value,
302302
has_pos_only_args = any_of<std::is_same<pos_only, Extra>...>::value,
303303
has_arg_annotations = any_of<is_keyword<Extra>...>::value;
304+
constexpr bool has_is_method = any_of<std::is_same<is_method, Extra>...>::value;
305+
// The implicit `self` argument is not present and not counted in method definitions.
306+
constexpr bool has_args = cast_in::args_pos >= 0;
307+
constexpr bool is_method_with_self_arg_only = has_is_method && !has_args;
304308
static_assert(has_arg_annotations || !has_kw_only_args,
305309
"py::kw_only requires the use of argument annotations");
306-
static_assert(has_arg_annotations || !has_pos_only_args,
310+
static_assert(((/* Need `py::arg("arg_name")` annotation in function/method. */
311+
has_arg_annotations)
312+
|| (/* Allow methods with no arguments `def method(self, /): ...`.
313+
* A method has at least one argument `self`. There can be no
314+
* `py::arg` annotation. E.g. `class.def("method", py::pos_only())`.
315+
*/
316+
is_method_with_self_arg_only))
317+
|| !has_pos_only_args,
307318
"py::pos_only requires the use of argument annotations (for docstrings "
308319
"and aligning the annotations to the argument)");
309320

@@ -2022,17 +2033,20 @@ struct enum_base {
20222033
.format(std::move(type_name), enum_name(arg), int_(arg));
20232034
},
20242035
name("__repr__"),
2025-
is_method(m_base));
2036+
is_method(m_base),
2037+
pos_only());
20262038

2027-
m_base.attr("name") = property(cpp_function(&enum_name, name("name"), is_method(m_base)));
2039+
m_base.attr("name")
2040+
= property(cpp_function(&enum_name, name("name"), is_method(m_base), pos_only()));
20282041

20292042
m_base.attr("__str__") = cpp_function(
20302043
[](handle arg) -> str {
20312044
object type_name = type::handle_of(arg).attr("__name__");
20322045
return pybind11::str("{}.{}").format(std::move(type_name), enum_name(arg));
20332046
},
20342047
name("__str__"),
2035-
is_method(m_base));
2048+
is_method(m_base),
2049+
pos_only());
20362050

20372051
if (options::show_enum_members_docstring()) {
20382052
m_base.attr("__doc__") = static_property(
@@ -2087,7 +2101,8 @@ struct enum_base {
20872101
}, \
20882102
name(op), \
20892103
is_method(m_base), \
2090-
arg("other"))
2104+
arg("other"), \
2105+
pos_only())
20912106

20922107
#define PYBIND11_ENUM_OP_CONV(op, expr) \
20932108
m_base.attr(op) = cpp_function( \
@@ -2097,7 +2112,8 @@ struct enum_base {
20972112
}, \
20982113
name(op), \
20992114
is_method(m_base), \
2100-
arg("other"))
2115+
arg("other"), \
2116+
pos_only())
21012117

21022118
#define PYBIND11_ENUM_OP_CONV_LHS(op, expr) \
21032119
m_base.attr(op) = cpp_function( \
@@ -2107,7 +2123,8 @@ struct enum_base {
21072123
}, \
21082124
name(op), \
21092125
is_method(m_base), \
2110-
arg("other"))
2126+
arg("other"), \
2127+
pos_only())
21112128

21122129
if (is_convertible) {
21132130
PYBIND11_ENUM_OP_CONV_LHS("__eq__", !b.is_none() && a.equal(b));
@@ -2127,7 +2144,8 @@ struct enum_base {
21272144
m_base.attr("__invert__")
21282145
= cpp_function([](const object &arg) { return ~(int_(arg)); },
21292146
name("__invert__"),
2130-
is_method(m_base));
2147+
is_method(m_base),
2148+
pos_only());
21312149
}
21322150
} else {
21332151
PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)), return false);
@@ -2147,11 +2165,15 @@ struct enum_base {
21472165
#undef PYBIND11_ENUM_OP_CONV
21482166
#undef PYBIND11_ENUM_OP_STRICT
21492167

2150-
m_base.attr("__getstate__") = cpp_function(
2151-
[](const object &arg) { return int_(arg); }, name("__getstate__"), is_method(m_base));
2168+
m_base.attr("__getstate__") = cpp_function([](const object &arg) { return int_(arg); },
2169+
name("__getstate__"),
2170+
is_method(m_base),
2171+
pos_only());
21522172

2153-
m_base.attr("__hash__") = cpp_function(
2154-
[](const object &arg) { return int_(arg); }, name("__hash__"), is_method(m_base));
2173+
m_base.attr("__hash__") = cpp_function([](const object &arg) { return int_(arg); },
2174+
name("__hash__"),
2175+
is_method(m_base),
2176+
pos_only());
21552177
}
21562178

21572179
PYBIND11_NOINLINE void value(char const *name_, object value, const char *doc = nullptr) {
@@ -2243,9 +2265,9 @@ class enum_ : public class_<Type> {
22432265
m_base.init(is_arithmetic, is_convertible);
22442266

22452267
def(init([](Scalar i) { return static_cast<Type>(i); }), arg("value"));
2246-
def_property_readonly("value", [](Type value) { return (Scalar) value; });
2247-
def("__int__", [](Type value) { return (Scalar) value; });
2248-
def("__index__", [](Type value) { return (Scalar) value; });
2268+
def_property_readonly("value", [](Type value) { return (Scalar) value; }, pos_only());
2269+
def("__int__", [](Type value) { return (Scalar) value; }, pos_only());
2270+
def("__index__", [](Type value) { return (Scalar) value; }, pos_only());
22492271
attr("__setstate__") = cpp_function(
22502272
[](detail::value_and_holder &v_h, Scalar arg) {
22512273
detail::initimpl::setstate<Base>(
@@ -2254,7 +2276,8 @@ class enum_ : public class_<Type> {
22542276
detail::is_new_style_constructor(),
22552277
pybind11::name("__setstate__"),
22562278
is_method(*this),
2257-
arg("state"));
2279+
arg("state"),
2280+
pos_only());
22582281
}
22592282

22602283
/// Export enumeration entries into the parent scope
@@ -2440,7 +2463,8 @@ iterator make_iterator_impl(Iterator first, Sentinel last, Extra &&...extra) {
24402463

24412464
if (!detail::get_type_info(typeid(state), false)) {
24422465
class_<state>(handle(), "iterator", pybind11::module_local())
2443-
.def("__iter__", [](state &s) -> state & { return s; })
2466+
.def(
2467+
"__iter__", [](state &s) -> state & { return s; }, pos_only())
24442468
.def(
24452469
"__next__",
24462470
[](state &s) -> ValueType {
@@ -2457,6 +2481,7 @@ iterator make_iterator_impl(Iterator first, Sentinel last, Extra &&...extra) {
24572481
// NOLINTNEXTLINE(readability-const-return-type) // PR #3263
24582482
},
24592483
std::forward<Extra>(extra)...,
2484+
pos_only(),
24602485
Policy);
24612486
}
24622487

tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,10 +198,11 @@ def pytest_assertrepr_compare(op, left, right): # noqa: ARG001
198198

199199

200200
def gc_collect():
201-
"""Run the garbage collector twice (needed when running
201+
"""Run the garbage collector three times (needed when running
202202
reference counting tests with PyPy)"""
203203
gc.collect()
204204
gc.collect()
205+
gc.collect()
205206

206207

207208
def pytest_configure():

tests/test_enum.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# ruff: noqa: SIM201 SIM300 SIM202
22
from __future__ import annotations
33

4+
import re
5+
46
import pytest
57

68
import env # noqa: F401
@@ -271,3 +273,61 @@ def test_docstring_signatures():
271273
def test_str_signature():
272274
for enum_type in [m.ScopedEnum, m.UnscopedEnum]:
273275
assert enum_type.__str__.__doc__.startswith("__str__")
276+
277+
278+
def test_generated_dunder_methods_pos_only():
279+
for enum_type in [m.ScopedEnum, m.UnscopedEnum]:
280+
for binary_op in [
281+
"__eq__",
282+
"__ne__",
283+
"__ge__",
284+
"__gt__",
285+
"__lt__",
286+
"__le__",
287+
"__and__",
288+
"__rand__",
289+
# "__or__", # fail with some compilers (__doc__ = "Return self|value.")
290+
# "__ror__", # fail with some compilers (__doc__ = "Return value|self.")
291+
"__xor__",
292+
"__rxor__",
293+
"__rxor__",
294+
]:
295+
method = getattr(enum_type, binary_op, None)
296+
if method is not None:
297+
assert (
298+
re.match(
299+
rf"^{binary_op}\(self: [\w\.]+, other: [\w\.]+, /\)",
300+
method.__doc__,
301+
)
302+
is not None
303+
)
304+
for unary_op in [
305+
"__int__",
306+
"__index__",
307+
"__hash__",
308+
"__str__",
309+
"__repr__",
310+
]:
311+
method = getattr(enum_type, unary_op, None)
312+
if method is not None:
313+
assert (
314+
re.match(
315+
rf"^{unary_op}\(self: [\w\.]+, /\)",
316+
method.__doc__,
317+
)
318+
is not None
319+
)
320+
assert (
321+
re.match(
322+
r"^__getstate__\(self: [\w\.]+, /\)",
323+
enum_type.__getstate__.__doc__,
324+
)
325+
is not None
326+
)
327+
assert (
328+
re.match(
329+
r"^__setstate__\(self: [\w\.]+, state: [\w\.]+, /\)",
330+
enum_type.__setstate__.__doc__,
331+
)
332+
is not None
333+
)

tests/test_methods_and_attributes.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ TEST_SUBMODULE(methods_and_attributes, m) {
294294
static_cast<py::str (ExampleMandA::*)(int, int)>(
295295
&ExampleMandA::overloaded));
296296
})
297-
.def("__str__", &ExampleMandA::toString)
297+
.def("__str__", &ExampleMandA::toString, py::pos_only())
298298
.def_readwrite("value", &ExampleMandA::value);
299299

300300
// test_copy_method

tests/test_methods_and_attributes.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@
1919
)
2020

2121

22+
def test_self_only_pos_only():
23+
assert (
24+
m.ExampleMandA.__str__.__doc__
25+
== "__str__(self: pybind11_tests.methods_and_attributes.ExampleMandA, /) -> str\n"
26+
)
27+
28+
2229
def test_methods_and_attributes():
2330
instance1 = m.ExampleMandA()
2431
instance2 = m.ExampleMandA(32)

tests/test_pickling.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,20 @@ def test_roundtrip_simple_cpp_derived():
9393
# Issue #3062: pickleable base C++ classes can incur object slicing
9494
# if derived typeid is not registered with pybind11
9595
assert not m.check_dynamic_cast_SimpleCppDerived(p2)
96+
97+
98+
def test_new_style_pickle_getstate_pos_only():
99+
assert (
100+
re.match(
101+
r"^__getstate__\(self: [\w\.]+, /\)", m.PickleableNew.__getstate__.__doc__
102+
)
103+
is not None
104+
)
105+
if hasattr(m, "PickleableWithDictNew"):
106+
assert (
107+
re.match(
108+
r"^__getstate__\(self: [\w\.]+, /\)",
109+
m.PickleableWithDictNew.__getstate__.__doc__,
110+
)
111+
is not None
112+
)

tests/test_sequences_and_iterators.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import re
4+
35
import pytest
46
from pytest import approx # noqa: PT013
57

@@ -253,16 +255,12 @@ def bad_next_call():
253255

254256
def test_iterator_passthrough():
255257
"""#181: iterator passthrough did not compile"""
256-
from pybind11_tests.sequences_and_iterators import iterator_passthrough
257-
258258
values = [3, 5, 7, 9, 11, 13, 15]
259-
assert list(iterator_passthrough(iter(values))) == values
259+
assert list(m.iterator_passthrough(iter(values))) == values
260260

261261

262262
def test_iterator_rvp():
263263
"""#388: Can't make iterators via make_iterator() with different r/v policies"""
264-
import pybind11_tests.sequences_and_iterators as m
265-
266264
assert list(m.make_iterator_1()) == [1, 2, 3]
267265
assert list(m.make_iterator_2()) == [1, 2, 3]
268266
assert not isinstance(m.make_iterator_1(), type(m.make_iterator_2()))
@@ -274,3 +272,25 @@ def test_carray_iterator():
274272
arr_h = m.CArrayHolder(*args_gt)
275273
args = list(arr_h)
276274
assert args_gt == args
275+
276+
277+
def test_generated_dunder_methods_pos_only():
278+
string_map = m.StringMap({"hi": "bye", "black": "white"})
279+
for it in (
280+
m.make_iterator_1(),
281+
m.make_iterator_2(),
282+
m.iterator_passthrough(iter([3, 5, 7])),
283+
iter(m.Sequence(5)),
284+
iter(string_map),
285+
string_map.items(),
286+
string_map.values(),
287+
iter(m.CArrayHolder(*[float(i) for i in range(3)])),
288+
):
289+
assert (
290+
re.match(r"^__iter__\(self: [\w\.]+, /\)", type(it).__iter__.__doc__)
291+
is not None
292+
)
293+
assert (
294+
re.match(r"^__next__\(self: [\w\.]+, /\)", type(it).__next__.__doc__)
295+
is not None
296+
)

0 commit comments

Comments
 (0)