Skip to content

Commit 8ed0dab

Browse files
Add float type caster and revert type hint changes to int_ and float_ (#5839)
* Revert type hint changes to int_ and float_ These two types do not support casting from int-like and float-like types. * Fix tests * Add a custom py::float_ caster The default py::object caster only works if the object is an instance of the type. py::float_ should accept python int objects as well as float. This caster will pass through float as usual and cast int to float. The caster handles the type name so the custom one is not required. * style: pre-commit fixes * Fix name * Fix variable * Try satisfying the formatter * Rename test function * Simplify type caster * Fix reference counting issue --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 30748f8 commit 8ed0dab

File tree

3 files changed

+28
-5
lines changed

3 files changed

+28
-5
lines changed

include/pybind11/cast.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,7 +1401,7 @@ struct handle_type_name<buffer> {
14011401
};
14021402
template <>
14031403
struct handle_type_name<int_> {
1404-
static constexpr auto name = io_name("typing.SupportsInt", "int");
1404+
static constexpr auto name = const_name("int");
14051405
};
14061406
template <>
14071407
struct handle_type_name<iterable> {
@@ -1413,7 +1413,7 @@ struct handle_type_name<iterator> {
14131413
};
14141414
template <>
14151415
struct handle_type_name<float_> {
1416-
static constexpr auto name = io_name("typing.SupportsFloat", "float");
1416+
static constexpr auto name = const_name("float");
14171417
};
14181418
template <>
14191419
struct handle_type_name<function> {
@@ -1534,6 +1534,21 @@ struct pyobject_caster {
15341534
template <typename T>
15351535
class type_caster<T, enable_if_t<is_pyobject<T>::value>> : public pyobject_caster<T> {};
15361536

1537+
template <>
1538+
class type_caster<float_> : public pyobject_caster<float_> {
1539+
public:
1540+
bool load(handle src, bool /* convert */) {
1541+
if (isinstance<float_>(src)) {
1542+
value = reinterpret_borrow<float_>(src);
1543+
} else if (isinstance<int_>(src)) {
1544+
value = float_(reinterpret_borrow<int_>(src));
1545+
} else {
1546+
return false;
1547+
}
1548+
return true;
1549+
}
1550+
};
1551+
15371552
// Our conditions for enabling moving are quite restrictive:
15381553
// At compile time:
15391554
// - T needs to be a non-const, non-pointer, non-reference type

tests/test_pytypes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ TEST_SUBMODULE(pytypes, m) {
209209
m.def("get_tuple_from_iterable", [](const py::iterable &iter) { return py::tuple(iter); });
210210
// test_float
211211
m.def("get_float", [] { return py::float_(0.0f); });
212+
m.def("float_roundtrip", [](py::float_ f) { return f; });
212213
// test_list
213214
m.def("list_no_args", []() { return py::list{}; });
214215
m.def("list_ssize_t", []() { return py::list{(py::ssize_t) 0}; });

tests/test_pytypes.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,13 @@ def test_iterable(doc):
6161

6262
def test_float(doc):
6363
assert doc(m.get_float) == "get_float() -> float"
64+
assert doc(m.float_roundtrip) == "float_roundtrip(arg0: float) -> float"
65+
f1 = m.float_roundtrip(5.5)
66+
assert isinstance(f1, float)
67+
assert f1 == 5.5
68+
f2 = m.float_roundtrip(5)
69+
assert isinstance(f2, float)
70+
assert f2 == 5.0
6471

6572

6673
def test_list(capture, doc):
@@ -917,7 +924,7 @@ def test_inplace_rshift(a, b):
917924
def test_tuple_nonempty_annotations(doc):
918925
assert (
919926
doc(m.annotate_tuple_float_str)
920-
== "annotate_tuple_float_str(arg0: tuple[typing.SupportsFloat, str]) -> None"
927+
== "annotate_tuple_float_str(arg0: tuple[float, str]) -> None"
921928
)
922929

923930

@@ -930,7 +937,7 @@ def test_tuple_empty_annotations(doc):
930937
def test_tuple_variable_length_annotations(doc):
931938
assert (
932939
doc(m.annotate_tuple_variable_length)
933-
== "annotate_tuple_variable_length(arg0: tuple[typing.SupportsFloat, ...]) -> None"
940+
== "annotate_tuple_variable_length(arg0: tuple[float, ...]) -> None"
934941
)
935942

936943

@@ -989,7 +996,7 @@ def test_type_annotation(doc):
989996
def test_union_annotations(doc):
990997
assert (
991998
doc(m.annotate_union)
992-
== "annotate_union(arg0: list[str | typing.SupportsInt | object], arg1: str, arg2: typing.SupportsInt, arg3: object) -> list[str | int | object]"
999+
== "annotate_union(arg0: list[str | int | object], arg1: str, arg2: int, arg3: object) -> list[str | int | object]"
9931000
)
9941001

9951002

0 commit comments

Comments
 (0)