Skip to content

Commit 756d055

Browse files
committed
Updated STL type hints use support collections.abc
1 parent d28904f commit 756d055

File tree

4 files changed

+31
-15
lines changed

4 files changed

+31
-15
lines changed

include/pybind11/stl.h

+11-5
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,9 @@ struct set_caster {
203203
return s.release();
204204
}
205205

206-
PYBIND11_TYPE_CASTER(type, const_name("set[") + key_conv::name + const_name("]"));
206+
PYBIND11_TYPE_CASTER(type,
207+
io_name("collections.abc.Set", "set") + const_name("[") + key_conv::name
208+
+ const_name("]"));
207209
};
208210

209211
template <typename Type, typename Key, typename Value>
@@ -274,7 +276,8 @@ struct map_caster {
274276
}
275277

276278
PYBIND11_TYPE_CASTER(Type,
277-
const_name("dict[") + key_conv::name + const_name(", ") + value_conv::name
279+
io_name("collections.abc.Mapping", "dict") + const_name("[")
280+
+ key_conv::name + const_name(", ") + value_conv::name
278281
+ const_name("]"));
279282
};
280283

@@ -340,7 +343,9 @@ struct list_caster {
340343
return l.release();
341344
}
342345

343-
PYBIND11_TYPE_CASTER(Type, const_name("list[") + value_conv::name + const_name("]"));
346+
PYBIND11_TYPE_CASTER(Type,
347+
io_name("collections.abc.Sequence", "list") + const_name("[")
348+
+ value_conv::name + const_name("]"));
344349
};
345350

346351
template <typename Type, typename Alloc>
@@ -474,8 +479,9 @@ struct array_caster {
474479
using cast_op_type = movable_cast_op_type<T_>;
475480

476481
static constexpr auto name
477-
= const_name<Resizable>(const_name(""), const_name("Annotated[")) + const_name("list[")
478-
+ value_conv::name + const_name("]")
482+
= const_name<Resizable>(const_name(""), const_name("Annotated["))
483+
+ io_name("collections.abc.Sequence", "list") + const_name("[") + value_conv::name
484+
+ const_name("]")
479485
+ const_name<Resizable>(
480486
const_name(""), const_name(", FixedSize(") + const_name<Size>() + const_name(")]"));
481487
};

tests/test_kwargs_and_defaults.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ def test_function_signatures(doc):
1111
assert doc(m.kw_func1) == "kw_func1(x: int, y: int) -> str"
1212
assert doc(m.kw_func2) == "kw_func2(x: int = 100, y: int = 200) -> str"
1313
assert doc(m.kw_func3) == "kw_func3(data: str = 'Hello world!') -> None"
14-
assert doc(m.kw_func4) == "kw_func4(myList: list[int] = [13, 17]) -> str"
14+
assert (
15+
doc(m.kw_func4)
16+
== "kw_func4(myList: collections.abc.Sequence[int] = [13, 17]) -> str"
17+
)
1518
assert doc(m.kw_func_udl) == "kw_func_udl(x: int, y: int = 300) -> str"
1619
assert doc(m.kw_func_udl_z) == "kw_func_udl_z(x: int, y: int = 0) -> str"
1720
assert doc(m.args_function) == "args_function(*args) -> tuple"

tests/test_pytypes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1243,7 +1243,7 @@ def test_arg_return_type_hints(doc):
12431243
# std::vector<T>
12441244
assert (
12451245
doc(m.half_of_number_vector)
1246-
== "half_of_number_vector(arg0: list[Union[float, int]]) -> list[float]"
1246+
== "half_of_number_vector(arg0: collections.abc.Sequence[Union[float, int]]) -> list[float]"
12471247
)
12481248
# Tuple<T, T>
12491249
assert (

tests/test_stl.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ def test_vector(doc):
2020
assert m.load_bool_vector((True, False))
2121

2222
assert doc(m.cast_vector) == "cast_vector() -> list[int]"
23-
assert doc(m.load_vector) == "load_vector(arg0: list[int]) -> bool"
23+
assert (
24+
doc(m.load_vector) == "load_vector(arg0: collections.abc.Sequence[int]) -> bool"
25+
)
2426

2527
# Test regression caused by 936: pointers to stl containers weren't castable
2628
assert m.cast_ptr_vector() == ["lvalue", "lvalue"]
@@ -45,7 +47,7 @@ def test_array(doc):
4547
assert doc(m.cast_array) == "cast_array() -> Annotated[list[int], FixedSize(2)]"
4648
assert (
4749
doc(m.load_array)
48-
== "load_array(arg0: Annotated[list[int], FixedSize(2)]) -> bool"
50+
== "load_array(arg0: Annotated[collections.abc.Sequence[int], FixedSize(2)]) -> bool"
4951
)
5052

5153

@@ -64,7 +66,10 @@ def test_valarray(doc):
6466
assert m.load_valarray(tuple(lst))
6567

6668
assert doc(m.cast_valarray) == "cast_valarray() -> list[int]"
67-
assert doc(m.load_valarray) == "load_valarray(arg0: list[int]) -> bool"
69+
assert (
70+
doc(m.load_valarray)
71+
== "load_valarray(arg0: collections.abc.Sequence[int]) -> bool"
72+
)
6873

6974

7075
def test_map(doc):
@@ -77,7 +82,9 @@ def test_map(doc):
7782
assert m.load_map(d)
7883

7984
assert doc(m.cast_map) == "cast_map() -> dict[str, str]"
80-
assert doc(m.load_map) == "load_map(arg0: dict[str, str]) -> bool"
85+
assert (
86+
doc(m.load_map) == "load_map(arg0: collections.abc.Mapping[str, str]) -> bool"
87+
)
8188

8289

8390
def test_set(doc):
@@ -89,7 +96,7 @@ def test_set(doc):
8996
assert m.load_set(frozenset(s))
9097

9198
assert doc(m.cast_set) == "cast_set() -> set[str]"
92-
assert doc(m.load_set) == "load_set(arg0: set[str]) -> bool"
99+
assert doc(m.load_set) == "load_set(arg0: collections.abc.Set[str]) -> bool"
93100

94101

95102
def test_recursive_casting():
@@ -271,7 +278,7 @@ def __fspath__(self):
271278
assert m.parent_paths(["foo/bar", "foo/baz"]) == [Path("foo"), Path("foo")]
272279
assert (
273280
doc(m.parent_paths)
274-
== "parent_paths(arg0: list[Union[os.PathLike, str, bytes]]) -> list[pathlib.Path]"
281+
== "parent_paths(arg0: collections.abc.Sequence[Union[os.PathLike, str, bytes]]) -> list[pathlib.Path]"
275282
)
276283
# py::typing::List
277284
assert m.parent_paths_list(["foo/bar", "foo/baz"]) == [Path("foo"), Path("foo")]
@@ -361,7 +368,7 @@ def test_stl_pass_by_pointer(msg):
361368
msg(excinfo.value)
362369
== """
363370
stl_pass_by_pointer(): incompatible function arguments. The following argument types are supported:
364-
1. (v: list[int] = None) -> list[int]
371+
1. (v: collections.abc.Sequence[int] = None) -> list[int]
365372
366373
Invoked with:
367374
"""
@@ -373,7 +380,7 @@ def test_stl_pass_by_pointer(msg):
373380
msg(excinfo.value)
374381
== """
375382
stl_pass_by_pointer(): incompatible function arguments. The following argument types are supported:
376-
1. (v: list[int] = None) -> list[int]
383+
1. (v: collections.abc.Sequence[int] = None) -> list[int]
377384
378385
Invoked with: None
379386
"""

0 commit comments

Comments
 (0)