Skip to content

Commit 8c25375

Browse files
committed
Add upcast_hook for exposing non-primary base relationships
1 parent b4b9331 commit 8c25375

File tree

7 files changed

+132
-8
lines changed

7 files changed

+132
-8
lines changed

docs/api_core.rst

+38
Original file line numberDiff line numberDiff line change
@@ -2086,6 +2086,44 @@ declarations in generated :ref:`stubs <stubs>`,
20862086
Declares a callback that will be invoked when a C++ instance is first
20872087
cast into a Python object.
20882088

2089+
.. cpp:struct:: upcast_hook
2090+
2091+
.. cpp:function:: upcast_hook(void* (* hook)(PyObject*, const std::type_info*) noexcept)
2092+
2093+
Allow Python instances of the class being bound to be passed to C++
2094+
functions that expect a pointer to a subobject of that class.
2095+
Since nanobind only acknowledges at most one base class of each bound type,
2096+
the upcast hook can be helpful for providing some minimal emulation of
2097+
additional bases.
2098+
2099+
The hook receives a nanobind instance as its first argument and the
2100+
desired subobject type as its second. If it can make the cast, it
2101+
returns a pointer to something of the requested type; if not, it
2102+
returns nullptr.
2103+
2104+
**Example:**
2105+
2106+
.. code-block:: cpp
2107+
2108+
struct A { int a = 10; };
2109+
struct B { int b = 20; };
2110+
struct D : A, B { int d = 30; };
2111+
2112+
nb::class_<A>(m, "A").def_rw("a", &A::a);
2113+
auto clsB = nb::class_<B>(m, "B").def_rw("b", &B::b);
2114+
2115+
auto try_D_to_B = [](PyObject *self_py, const std::type_info *target) noexcept -> void* {
2116+
D *self = nb::inst_ptr<D>(self_py);
2117+
if (*target == &typeid(B)) {
2118+
return static_cast<B*>(self);
2119+
}
2120+
return nullptr;
2121+
};
2122+
2123+
auto clsD = nb::class_<D, A>(m, "D", nb::upcast_hook(try_D_to_B))
2124+
.def_rw("d", &D::d);
2125+
clsD.attr("b") = clsB.attr("b");
2126+
20892127
20902128
.. _enum_binding_annotations:
20912129

docs/changelog.rst

+6
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ Upcoming version (TBA)
2525
long-standing inconvenience. (PR `#778
2626
<https://github.com/wjakob/nanobind/pull/778>`__).
2727

28+
- Added the class binding annotation :cpp:class:`nb::upcast_hook()
29+
<upcast_hook>` which allows the bound type to describe how to
30+
extract self-pointers of other types from its instances. This can
31+
be useful as part of a strategy for mimicking multiple inheritance.
32+
(PR `#920 <https://github.com/wjakob/nanobind/pull/920>`__)
33+
2834
* ABI version 16.
2935

3036

include/nanobind/nb_attr.h

+4-5
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,10 @@ struct type_slots {
137137
const PyType_Slot *value;
138138
};
139139

140-
struct type_slots_callback {
141-
using cb_t = void (*)(const detail::type_init_data *t,
142-
PyType_Slot *&slots, size_t max_slots) noexcept;
143-
type_slots_callback(cb_t callback) : callback(callback) { }
144-
cb_t callback;
140+
struct upcast_hook {
141+
using cb_t = void* (*)(PyObject *, const std::type_info *) noexcept;
142+
upcast_hook(cb_t hook) : hook(hook) { }
143+
cb_t hook;
145144
};
146145

147146
struct sig {

include/nanobind/nb_class.h

+12-2
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,13 @@ enum class type_flags : uint32_t {
6363

6464
/// Does the type implement a custom __new__ operator that can take no args
6565
/// (except the type object)?
66-
has_nullary_new = (1 << 17)
66+
has_nullary_new = (1 << 17),
6767

68-
// One more bit available without needing a larger reorganization
68+
/// Does the type provide a upcast_hook?
69+
has_upcast_hook = (1 << 18)
70+
71+
// Reorganization will be needed to add any more flags;
72+
// try splitting type_init_flags into a separate field in type_init_data
6973
};
7074

7175
/// Flags about a type that are only relevant when it is being created.
@@ -125,6 +129,7 @@ struct type_data {
125129
};
126130
void (*set_self_py)(void *, PyObject *) noexcept;
127131
bool (*keep_shared_from_this_alive)(PyObject *) noexcept;
132+
void* (*upcast_hook)(PyObject *, const std::type_info *) noexcept;
128133
#if defined(Py_LIMITED_API)
129134
uint32_t dictoffset;
130135
uint32_t weaklistoffset;
@@ -183,6 +188,11 @@ NB_INLINE void type_extra_apply(type_init_data & t, const sig &s) {
183188
t.name = s.value;
184189
}
185190

191+
NB_INLINE void type_extra_apply(type_init_data &t, upcast_hook h) {
192+
t.flags |= (uint32_t) type_flags::has_upcast_hook;
193+
t.upcast_hook = h.hook;
194+
}
195+
186196
template <typename T>
187197
NB_INLINE void type_extra_apply(type_init_data &t, supplement<T>) {
188198
static_assert(std::is_trivially_default_constructible_v<T>,

src/nb_type.cpp

+18-1
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,8 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
10611061
is_generic = t->flags & (uint32_t) type_flags::is_generic,
10621062
intrusive_ptr = t->flags & (uint32_t) type_flags::intrusive_ptr,
10631063
has_shared_from_this = t->flags & (uint32_t) type_flags::has_shared_from_this,
1064-
has_signature = t->flags & (uint32_t) type_flags::has_signature;
1064+
has_signature = t->flags & (uint32_t) type_flags::has_signature,
1065+
has_upcast_hook = t->flags & (uint32_t) type_flags::has_upcast_hook;
10651066

10661067
const char *t_name = t->name;
10671068
if (has_signature)
@@ -1346,6 +1347,12 @@ PyObject *nb_type_new(const type_init_data *t) noexcept {
13461347
to->keep_shared_from_this_alive = tb->keep_shared_from_this_alive;
13471348
}
13481349

1350+
if (!has_upcast_hook && tb &&
1351+
(tb->flags & (uint32_t) type_flags::has_upcast_hook)) {
1352+
to->flags |= (uint32_t) type_flags::has_upcast_hook;
1353+
to->upcast_hook = tb->upcast_hook;
1354+
}
1355+
13491356
#if defined(Py_LIMITED_API)
13501357
to->vectorcall = type_vectorcall;
13511358
#else
@@ -1551,6 +1558,16 @@ bool nb_type_get(const std::type_info *cpp_type, PyObject *src, uint8_t flags,
15511558

15521559
return true;
15531560
}
1561+
1562+
// This is a nanobind type but not the right one; try an upcast hook
1563+
// if one was provided
1564+
if (t->flags & (uint32_t) type_flags::has_upcast_hook) {
1565+
void *ptr = t->upcast_hook(src, cpp_type);
1566+
if (ptr) {
1567+
*out = ptr;
1568+
return true;
1569+
}
1570+
}
15541571
}
15551572

15561573
// Try an implicit conversion as last resort (if possible & requested)

tests/test_classes.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -718,4 +718,26 @@ NB_MODULE(test_classes_ext, m) {
718718
.def_prop_ro_static("x", [](nb::handle /*unused*/) { return 42; });
719719
nb::class_<StaticPropertyOverride2, StaticPropertyOverride>(m, "StaticPropertyOverride2")
720720
.def_prop_ro_static("x", [](nb::handle /*unused*/) { return 43; });
721+
722+
struct MultA { int a = 10; };
723+
struct MultB { int b = 20; };
724+
struct MultD : MultA, MultB { int d = 30; };
725+
struct MultE : MultD { int e = 40; };
726+
727+
nb::class_<MultA>(m, "MultA").def(nb::init<>()).def_rw("a", &MultA::a);
728+
auto clsB = nb::class_<MultB>(m, "MultB").def(nb::init<>()).def_rw("b", &MultB::b);
729+
730+
auto try_D_to_B = [](PyObject *self_py, const std::type_info *target) noexcept -> void* {
731+
MultD *self = nb::inst_ptr<MultD>(self_py);
732+
if (*target == typeid(MultB)) {
733+
return static_cast<MultB*>(self);
734+
}
735+
return nullptr;
736+
};
737+
738+
auto clsD = nb::class_<MultD, MultA>(m, "MultD", nb::upcast_hook(try_D_to_B))
739+
.def(nb::init<>())
740+
.def_rw("d", &MultD::d);
741+
clsD.attr("b") = clsB.attr("b");
742+
nb::class_<MultE, MultD>(m, "MultE").def(nb::init<>()).def_rw("e", &MultE::e);
721743
}

tests/test_classes.py

+32
Original file line numberDiff line numberDiff line change
@@ -941,3 +941,35 @@ def my_init(self):
941941
def test49_static_property_override():
942942
assert t.StaticPropertyOverride.x == 42
943943
assert t.StaticPropertyOverride2.x == 43
944+
945+
def test50_i_cant_believe_its_not_multiple_inheritance(monkeypatch):
946+
objs = [t.MultB(), t.MultD(), t.MultE()]
947+
for i, obj in enumerate(objs):
948+
assert obj.b == 20
949+
obj.b += i
950+
try:
951+
assert obj.d == 30
952+
obj.d += 100 * i
953+
except AttributeError:
954+
if i != 0:
955+
raise
956+
957+
assert objs[0].b == 20
958+
assert objs[1].b == 21
959+
assert objs[2].b == 22
960+
assert objs[1].d == 130
961+
assert objs[2].d == 230
962+
963+
def patched_instancecheck(cls, inst, *, _orig=type(t.MultB).__instancecheck__):
964+
if _orig(t.MultD, inst) and cls is t.MultB:
965+
return True
966+
return _orig(cls, inst)
967+
968+
monkeypatch.setattr(type(t.MultB), "__instancecheck__", patched_instancecheck)
969+
assert isinstance(objs[0], t.MultB)
970+
assert not isinstance(objs[0], t.MultD)
971+
assert isinstance(objs[1], t.MultB)
972+
assert isinstance(objs[1], t.MultD)
973+
assert isinstance(objs[2], t.MultB)
974+
assert isinstance(objs[2], t.MultD)
975+
assert isinstance(objs[2], t.MultE)

0 commit comments

Comments
 (0)