Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiple Python class inheritance #926

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@ test_*_ext*.so
test_*_ext*.pyd
py\.typed
.mypy_cache
.vscode/
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,28 @@ Please see the following links for tutorial and reference documentation in
[HTML](https://nanobind.readthedocs.io/en/latest/) and
[PDF](https://nanobind.readthedocs.io/_/downloads/en/latest/pdf/) formats.

## Development of nanobind

To compile the project and run the tests:

```bash
# If Linux, install Eigen
sudo apt-get -y install libeigen3-dev

# Python dev dependencies
python -m pip install pytest typing_extension

# Configure cmake build directory
cmake -S . -B build

# Build C++ project
cmake --build build -j 2

# Run tests
cd build
python -m pytest
```

## License and attribution

All material in this repository is licensed under a three-clause [BSD
Expand Down
48 changes: 42 additions & 6 deletions src/nb_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "nb_internals.h"
#include "nb_ft.h"
#include <iostream>

#if defined(_MSC_VER)
# pragma warning(disable: 4706) // assignment within conditional expression
Expand Down Expand Up @@ -460,26 +461,61 @@ static int nb_type_init(PyObject *self, PyObject *args, PyObject *kwds) {
}

PyObject *bases = NB_TUPLE_GET_ITEM(args, 1);
if (!PyTuple_CheckExact(bases) || NB_TUPLE_GET_SIZE(bases) != 1) {
if (!PyTuple_CheckExact(bases)) {
PyErr_SetString(PyExc_RuntimeError,
"nb_type_init(): invalid number of bases!");
"nb_type_init(): expected a base type object!");
return -1;
}

PyObject *base = NB_TUPLE_GET_ITEM(bases, 0);
if (!PyType_Check(base)) {
PyObject *nb_base = nullptr;
Py_ssize_t nb_base_index = 0;
for (Py_ssize_t i = 0; i < NB_TUPLE_GET_SIZE(bases); i++) {
PyObject *base = NB_TUPLE_GET_ITEM(bases, i);
if (PyType_Check(base)) {
if (nb_type_check(base)) {
if (nb_base) {
PyErr_SetString(PyExc_TypeError, "nb_type_init(): multiple inheritance of multiple nanobound classes are not allowed!");
return -1;
}
nb_base = base;
nb_base_index = i;
}
}
}

if (!nb_base) {
PyErr_SetString(PyExc_RuntimeError, "nb_type_init(): expected a base type object!");
return -1;
}

type_data *t_b = nb_type_data((PyTypeObject *) base);
// Reorder the bases tuple to ensure that the nanobind base is at the front
PyObject *bases_new = PyTuple_New(NB_TUPLE_GET_SIZE(bases));
NB_TUPLE_SET_ITEM(bases_new, 0, nb_base); // Set base at position 0
for (Py_ssize_t i = 0; i < NB_TUPLE_GET_SIZE(bases); i++) {
if (i != nb_base_index) {
PyObject *base = NB_TUPLE_GET_ITEM(bases, i);
PyTuple_SET_ITEM(bases_new, i < nb_base_index ? i + 1 : i, base);
}
}

// Create a new args tuple with the reordered bases
PyObject *args_new = PyTuple_New(NB_TUPLE_GET_SIZE(args));
for (Py_ssize_t i = 0; i < NB_TUPLE_GET_SIZE(args); i++) {
if (i != 1) {
PyTuple_SET_ITEM(args_new, i, NB_TUPLE_GET_ITEM(args, i));
} else {
PyTuple_SET_ITEM(args_new, i, bases_new);
}
}

type_data *t_b = nb_type_data((PyTypeObject *) NB_TUPLE_GET_ITEM(bases_new, 0));
if (t_b->flags & (uint32_t) type_flags::is_final) {
PyErr_Format(PyExc_TypeError, "The type '%s' prohibits subclassing!",
t_b->name);
return -1;
}

int rv = NB_SLOT(PyType_Type, tp_init)(self, args, kwds);
int rv = NB_SLOT(PyType_Type, tp_init)(self, args_new, kwds);
if (rv)
return rv;

Expand Down
46 changes: 46 additions & 0 deletions tests/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,3 +941,49 @@ def my_init(self):
def test49_static_property_override():
assert t.StaticPropertyOverride.x == 42
assert t.StaticPropertyOverride2.x == 43

def test50_multiple_inheritance_nb_first_base():
"""Test multiple inheritance with a NC class first in the inheritance list."""
class Python:
def name(self):
return "PyClass"

class SausageDog(t.Animal, Python):
def __init__(self):
super().__init__()

def who(self):
return "SausageDog"

sd = SausageDog()
assert isinstance(sd, SausageDog)
assert isinstance(sd, t.Animal)
assert isinstance(sd, Python)
assert sd.name() == "Animal"
assert sd.who() == "SausageDog"

def test51_multiple_inheritance_py_first_base():
"""Test multiple inheritance with a Python class first in the inheritance list."""
class Python:
def name(self):
return "Python"

class SausageDog(Python, t.Animal):
def __init__(self):
super().__init__()

def who(self):
return "SausageDog"

sd = SausageDog()
assert isinstance(sd, SausageDog)
assert isinstance(sd, t.Animal)
assert isinstance(sd, Python)
assert sd.name() == "Python"
assert sd.who() == "SausageDog"

def test52_multiple_inheritance_checks():
"""Test checks to prevent multiple nb class inheritance."""
with pytest.raises(TypeError):
class A(t.Animal, t.Foo):
pass
Loading