Skip to content

Commit ee04df0

Browse files
authored
Updated STL casters and py::buffer to use collections.abc (#5566)
* Updated STL type hints use support collections.abc * Updated array_caster to match numpy/eigen typing.Annotated stlye * Added support for Mapping, Set and Sequence derived from collections.abc. * Fixed merge of typing.SupportsInt in new tests * Integrated collections.abc checks into convertible check functions. * Changed type hint of py::buffer to collections.abc.Buffer * Changed convertible check function names * Added comments to convertible check functions * Removed checks for methods that are already required by the abstract base class * Improved mapping caster test using more compact a1b2c3 variable * Renamed and refactored sequence, mapping and set test classes to reuse implementation * Added tests for mapping and set casters for noconvert mode * Added tests for sequence caster for noconvert mode
1 parent f3c1913 commit ee04df0

File tree

7 files changed

+237
-50
lines changed

7 files changed

+237
-50
lines changed

include/pybind11/cast.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -1328,7 +1328,7 @@ struct handle_type_name<bytes> {
13281328
};
13291329
template <>
13301330
struct handle_type_name<buffer> {
1331-
static constexpr auto name = const_name("Buffer");
1331+
static constexpr auto name = const_name("collections.abc.Buffer");
13321332
};
13331333
template <>
13341334
struct handle_type_name<int_> {

include/pybind11/stl.h

+58-37
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ PYBIND11_NAMESPACE_BEGIN(detail)
4343
// Begin: Equivalent of
4444
// https://github.com/google/clif/blob/ae4eee1de07cdf115c0c9bf9fec9ff28efce6f6c/clif/python/runtime.cc#L388-L438
4545
/*
46-
The three `PyObjectTypeIsConvertibleTo*()` functions below are
46+
The three `object_is_convertible_to_*()` functions below are
4747
the result of converging the behaviors of pybind11 and PyCLIF
4848
(http://github.com/google/clif).
4949
@@ -69,10 +69,13 @@ to prevent accidents and improve readability:
6969
are also fairly commonly used, therefore enforcing explicit conversions
7070
would have an unfavorable cost : benefit ratio; more sloppily speaking,
7171
such an enforcement would be more annoying than helpful.
72+
73+
Additional checks have been added to allow types derived from `collections.abc.Set` and
74+
`collections.abc.Mapping` (`collections.abc.Sequence` is already allowed by `PySequence_Check`).
7275
*/
7376

74-
inline bool PyObjectIsInstanceWithOneOfTpNames(PyObject *obj,
75-
std::initializer_list<const char *> tp_names) {
77+
inline bool object_is_instance_with_one_of_tp_names(PyObject *obj,
78+
std::initializer_list<const char *> tp_names) {
7679
if (PyType_Check(obj)) {
7780
return false;
7881
}
@@ -85,37 +88,48 @@ inline bool PyObjectIsInstanceWithOneOfTpNames(PyObject *obj,
8588
return false;
8689
}
8790

88-
inline bool PyObjectTypeIsConvertibleToStdVector(PyObject *obj) {
89-
if (PySequence_Check(obj) != 0) {
90-
return !PyUnicode_Check(obj) && !PyBytes_Check(obj);
91+
inline bool object_is_convertible_to_std_vector(const handle &src) {
92+
// Allow sequence-like objects, but not (byte-)string-like objects.
93+
if (PySequence_Check(src.ptr()) != 0) {
94+
return !PyUnicode_Check(src.ptr()) && !PyBytes_Check(src.ptr());
9195
}
92-
return (PyGen_Check(obj) != 0) || (PyAnySet_Check(obj) != 0)
93-
|| PyObjectIsInstanceWithOneOfTpNames(
94-
obj, {"dict_keys", "dict_values", "dict_items", "map", "zip"});
96+
// Allow generators, set/frozenset and several common iterable types.
97+
return (PyGen_Check(src.ptr()) != 0) || (PyAnySet_Check(src.ptr()) != 0)
98+
|| object_is_instance_with_one_of_tp_names(
99+
src.ptr(), {"dict_keys", "dict_values", "dict_items", "map", "zip"});
95100
}
96101

97-
inline bool PyObjectTypeIsConvertibleToStdSet(PyObject *obj) {
98-
return (PyAnySet_Check(obj) != 0) || PyObjectIsInstanceWithOneOfTpNames(obj, {"dict_keys"});
102+
inline bool object_is_convertible_to_std_set(const handle &src, bool convert) {
103+
// Allow set/frozenset and dict keys.
104+
// In convert mode: also allow types derived from collections.abc.Set.
105+
return ((PyAnySet_Check(src.ptr()) != 0)
106+
|| object_is_instance_with_one_of_tp_names(src.ptr(), {"dict_keys"}))
107+
|| (convert && isinstance(src, module_::import("collections.abc").attr("Set")));
99108
}
100109

101-
inline bool PyObjectTypeIsConvertibleToStdMap(PyObject *obj) {
102-
if (PyDict_Check(obj)) {
110+
inline bool object_is_convertible_to_std_map(const handle &src, bool convert) {
111+
// Allow dict.
112+
if (PyDict_Check(src.ptr())) {
103113
return true;
104114
}
105-
// Implicit requirement in the conditions below:
106-
// A type with `.__getitem__()` & `.items()` methods must implement these
107-
// to be compatible with https://docs.python.org/3/c-api/mapping.html
108-
if (PyMapping_Check(obj) == 0) {
109-
return false;
110-
}
111-
PyObject *items = PyObject_GetAttrString(obj, "items");
112-
if (items == nullptr) {
113-
PyErr_Clear();
114-
return false;
115+
// Allow types conforming to Mapping Protocol.
116+
// According to https://docs.python.org/3/c-api/mapping.html, `PyMappingCheck()` checks for
117+
// `__getitem__()` without checking the type of keys. In order to restrict the allowed types
118+
// closer to actual Mapping-like types, we also check for the `items()` method.
119+
if (PyMapping_Check(src.ptr()) != 0) {
120+
PyObject *items = PyObject_GetAttrString(src.ptr(), "items");
121+
if (items != nullptr) {
122+
bool is_convertible = (PyCallable_Check(items) != 0);
123+
Py_DECREF(items);
124+
if (is_convertible) {
125+
return true;
126+
}
127+
} else {
128+
PyErr_Clear();
129+
}
115130
}
116-
bool is_convertible = (PyCallable_Check(items) != 0);
117-
Py_DECREF(items);
118-
return is_convertible;
131+
// In convert mode: Allow types derived from collections.abc.Mapping
132+
return convert && isinstance(src, module_::import("collections.abc").attr("Mapping"));
119133
}
120134

121135
//
@@ -172,7 +186,7 @@ struct set_caster {
172186

173187
public:
174188
bool load(handle src, bool convert) {
175-
if (!PyObjectTypeIsConvertibleToStdSet(src.ptr())) {
189+
if (!object_is_convertible_to_std_set(src, convert)) {
176190
return false;
177191
}
178192
if (isinstance<anyset>(src)) {
@@ -203,7 +217,9 @@ struct set_caster {
203217
return s.release();
204218
}
205219

206-
PYBIND11_TYPE_CASTER(type, const_name("set[") + key_conv::name + const_name("]"));
220+
PYBIND11_TYPE_CASTER(type,
221+
io_name("collections.abc.Set", "set") + const_name("[") + key_conv::name
222+
+ const_name("]"));
207223
};
208224

209225
template <typename Type, typename Key, typename Value>
@@ -234,7 +250,7 @@ struct map_caster {
234250

235251
public:
236252
bool load(handle src, bool convert) {
237-
if (!PyObjectTypeIsConvertibleToStdMap(src.ptr())) {
253+
if (!object_is_convertible_to_std_map(src, convert)) {
238254
return false;
239255
}
240256
if (isinstance<dict>(src)) {
@@ -274,7 +290,8 @@ struct map_caster {
274290
}
275291

276292
PYBIND11_TYPE_CASTER(Type,
277-
const_name("dict[") + key_conv::name + const_name(", ") + value_conv::name
293+
io_name("collections.abc.Mapping", "dict") + const_name("[")
294+
+ key_conv::name + const_name(", ") + value_conv::name
278295
+ const_name("]"));
279296
};
280297

@@ -283,7 +300,7 @@ struct list_caster {
283300
using value_conv = make_caster<Value>;
284301

285302
bool load(handle src, bool convert) {
286-
if (!PyObjectTypeIsConvertibleToStdVector(src.ptr())) {
303+
if (!object_is_convertible_to_std_vector(src)) {
287304
return false;
288305
}
289306
if (isinstance<sequence>(src)) {
@@ -340,7 +357,9 @@ struct list_caster {
340357
return l.release();
341358
}
342359

343-
PYBIND11_TYPE_CASTER(Type, const_name("list[") + value_conv::name + const_name("]"));
360+
PYBIND11_TYPE_CASTER(Type,
361+
io_name("collections.abc.Sequence", "list") + const_name("[")
362+
+ value_conv::name + const_name("]"));
344363
};
345364

346365
template <typename Type, typename Alloc>
@@ -416,7 +435,7 @@ struct array_caster {
416435

417436
public:
418437
bool load(handle src, bool convert) {
419-
if (!PyObjectTypeIsConvertibleToStdVector(src.ptr())) {
438+
if (!object_is_convertible_to_std_vector(src)) {
420439
return false;
421440
}
422441
if (isinstance<sequence>(src)) {
@@ -474,10 +493,12 @@ struct array_caster {
474493
using cast_op_type = movable_cast_op_type<T_>;
475494

476495
static constexpr auto name
477-
= const_name<Resizable>(const_name(""), const_name("Annotated[")) + const_name("list[")
478-
+ value_conv::name + const_name("]")
479-
+ const_name<Resizable>(
480-
const_name(""), const_name(", FixedSize(") + const_name<Size>() + const_name(")]"));
496+
= const_name<Resizable>(const_name(""), const_name("typing.Annotated["))
497+
+ io_name("collections.abc.Sequence", "list") + const_name("[") + value_conv::name
498+
+ const_name("]")
499+
+ const_name<Resizable>(const_name(""),
500+
const_name(", \"FixedSize(") + const_name<Size>()
501+
+ const_name(")\"]"));
481502
};
482503

483504
template <typename Type, size_t Size>

tests/test_buffers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def test_ctypes_from_buffer():
230230
def test_buffer_docstring():
231231
assert (
232232
m.get_buffer_info.__doc__.strip()
233-
== "get_buffer_info(arg0: Buffer) -> pybind11_tests.buffers.buffer_info"
233+
== "get_buffer_info(arg0: collections.abc.Buffer) -> pybind11_tests.buffers.buffer_info"
234234
)
235235

236236

tests/test_kwargs_and_defaults.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_function_signatures(doc):
2222
assert doc(m.kw_func3) == "kw_func3(data: str = 'Hello world!') -> None"
2323
assert (
2424
doc(m.kw_func4)
25-
== "kw_func4(myList: list[typing.SupportsInt] = [13, 17]) -> str"
25+
== "kw_func4(myList: collections.abc.Sequence[typing.SupportsInt] = [13, 17]) -> str"
2626
)
2727
assert (
2828
doc(m.kw_func_udl)

tests/test_pytypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1254,7 +1254,7 @@ def test_arg_return_type_hints(doc):
12541254
# std::vector<T>
12551255
assert (
12561256
doc(m.half_of_number_vector)
1257-
== "half_of_number_vector(arg0: list[Union[float, int]]) -> list[float]"
1257+
== "half_of_number_vector(arg0: collections.abc.Sequence[Union[float, int]]) -> list[float]"
12581258
)
12591259
# Tuple<T, T>
12601260
assert (

tests/test_stl.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -648,4 +648,19 @@ TEST_SUBMODULE(stl, m) {
648648
}
649649
return zum;
650650
});
651+
m.def("roundtrip_std_vector_int", [](const std::vector<int> &v) { return v; });
652+
m.def("roundtrip_std_map_str_int", [](const std::map<std::string, int> &m) { return m; });
653+
m.def("roundtrip_std_set_int", [](const std::set<int> &s) { return s; });
654+
m.def(
655+
"roundtrip_std_vector_int_noconvert",
656+
[](const std::vector<int> &v) { return v; },
657+
py::arg("v").noconvert());
658+
m.def(
659+
"roundtrip_std_map_str_int_noconvert",
660+
[](const std::map<std::string, int> &m) { return m; },
661+
py::arg("m").noconvert());
662+
m.def(
663+
"roundtrip_std_set_int_noconvert",
664+
[](const std::set<int> &s) { return s; },
665+
py::arg("s").noconvert());
651666
}

0 commit comments

Comments
 (0)