Commit e67d934

nb::ndarray<...> implementation refactored (#721)
This PR refactors the ``nb::ndarray`` implementation to remove certain redundancies. In particular, there were duplicate code paths to process ``ndarray`` template parameters at compile time and at runtime, which are now merged. Significant edits to the documentation are intended to make nd-array bindings more approachable to newcomers.

Finally, the refactor was also an opportunity to realize two usability improvements:

1. The constructor to return new nd-arrays from C++ now considers all template arguments:

   - **Memory order**: ``c_contig``, ``f_contig``.
   - **Shape**: ``nb::shape<3, 4, 5>``, etc.
   - **Device type**: ``nb::device::cpu``, ``nb::device::cuda``, etc.
   - **Framework**: ``nb::numpy``, ``nb::pytorch``, etc.
   - **Data type**: ``uint64_t``, ``std::complex<double>``, etc.

   Previously, only the **framework** and **data type** annotations were taken into account when returning nd-arrays, while all of them were examined when *accepting* arrays during overload resolution. This inconsistency was a repeated source of confusion among users.

   To give an example, the following now works out of the box without the need to redundantly specify the shape and strides to the ``Array`` constructor below:

   ```cpp
   using Array = nb::ndarray<float, nb::numpy, nb::shape<4, 4>, nb::f_contig>;

   struct Matrix4f {
       float m[4][4];
       Array data() { return Array(m); }
   };

   nb::class_<Matrix4f>(m, "Matrix4f")
       .def("data", &Matrix4f::data, nb::rv_policy::reference_internal);
   ```

2. A new nd-array ``.cast()`` method forces the immediate creation of a Python object with the specified target framework and return value policy, while preserving the type signature in return values. This is useful to return temporaries (e.g. stack-allocated memory) from functions.

There are two minor but potentially breaking changes:

1. The ndarray type caster now interprets the ``rv_policy::automatic_reference`` return value policy analogously to the ``rv_policy::automatic``, which means that it references a memory region when the user specifies an ``owner``, and it otherwise copies. This makes it safe to use the ``nb::cast()`` and ``nb::ndarray::cast()`` functions that use this policy as a default.

2. The ``nb::any_contig`` memory order annotation, which previously did nothing, now accepts C- or F-contiguous arrays and rejects non-contiguous ones.

In both of these cases, the prior convention seems like it would cause bugs/breakage in practice. If nobody depends on this behavior, it should be OK to fix these without a major version bump.

A small change without compatibility implications: the `owner` argument has a default `{}` argument again. I think this is reasonably safe because the `automatic_*` return value policies in nanobind will copy the input array if there isn't an owner. This effectively reverts a change from commit 937a1df.
1 parent 7f7e0c0 commit e67d934

13 files changed: +1163 -773 lines changed


docs/api_extra.rst

Lines changed: 125 additions & 32 deletions
@@ -644,8 +644,8 @@ N-dimensional array type
644644
------------------------
645645

646646
The following type can be used to exchange n-dimensional arrays with frameworks
647-
like NumPy, PyTorch, Tensorflow, JAX, CuPy, and others. It requires an additional
648-
include directive:
647+
like NumPy, PyTorch, Tensorflow, JAX, CuPy, and others. It requires an
648+
additional include directive:
649649

650650
.. code-block:: cpp
651651
@@ -664,11 +664,36 @@ section <ndarrays>`.
664664

665665
.. cpp:class:: template <typename... Args> ndarray
666666

667+
.. cpp:type:: Scalar
668+
669+
The scalar type underlying the array (or ``void`` if not specified)
670+
667671
.. cpp:var:: static constexpr bool ReadOnly
668672

669-
A constant static boolean that is true if the array's data is read-only.
670-
This is determined by the class template arguments, not by any dynamic
671-
properties of the referenced array.
673+
A ``constexpr`` Boolean value that is ``true`` if the ndarray template
674+
arguments (`Args... <Args>`) include the ``nb::ro`` annotation or a
675+
``const``-qualified scalar type.
676+
677+
.. cpp:var:: static constexpr char Order
678+
679+
A ``constexpr`` character value set based on the ndarray template
680+
arguments (`Args... <Args>`). It equals
681+
682+
- ``'C'`` if :cpp:class:`c_contig` is specified,
683+
- ``'F'`` if :cpp:class:`f_contig` is specified,
684+
- ``'A'`` if :cpp:class:`any_contig` is specified,
685+
- ``'\0'`` otherwise.
686+
687+
.. cpp:var:: static constexpr int DeviceType
688+
689+
A ``constexpr`` integer value set to the device type ID extracted from
690+
the ndarray template arguments (`Args... <Args>`), or
691+
:cpp:struct:`device::none::value <device::none>` when none was specified.
692+
693+
.. cpp:type:: VoidPtr = std::conditional_t<ReadOnly, const void *, void *>
694+
695+
A potentially ``const``-qualified ``void*`` pointer type used by some
696+
of the ``ndarray`` constructors.
672697

673698
.. cpp:function:: ndarray() = default
674699

@@ -677,8 +702,8 @@ section <ndarrays>`.
677702
.. cpp:function:: template <typename... Args2> explicit ndarray(const ndarray<Args2...> &other)
678703

679704
Reinterpreting constructor that wraps an existing nd-array (parameterized
680-
by `Args`) into a new ndarray (parameterized by `Args2`). No copy or
681-
conversion is made.
705+
by `Args... <Args>`) into a new ndarray (parameterized by `Args2...
706+
<Args2>`). No copy or conversion is made.
682707

683708
Dropping parameters is always safe. For example, a function that
684709
returns different array types could call it to convert ``ndarray<T>`` to
@@ -708,37 +733,87 @@ section <ndarrays>`.
708733
Move assignment operator. Steals the referenced array without changing reference counts.
709734
Decreases the reference count of the previously referenced array and potentially destroys it.
710735

711-
.. cpp:function:: ndarray(void * data, size_t ndim, const size_t * shape, handle owner = nanobind::handle(), const int64_t * strides = nullptr, dlpack::dtype dtype = nanobind::dtype<Scalar>(), int32_t device_type = device::cpu::value, int32_t device_id = 0)
736+
.. _ndarray_dynamic_constructor:
737+
738+
.. cpp:function:: ndarray(VoidPtr data, const std::initializer_list<size_t> shape = { }, handle owner = { }, std::initializer_list<int64_t> strides = { }, dlpack::dtype dtype = nanobind::dtype<Scalar>(), int32_t device_type = DeviceType, int32_t device_id = 0, char order = Order)
712739

713-
Create an array wrapping an existing memory allocation. The following
714-
parameters can be specified:
740+
Create an array wrapping an existing memory allocation.
715741

716-
- `data`: pointer address of the memory region. When the ndarray is
717-
parameterized by a constant scalar type to indicate read-only access, a
718-
const pointer must be passed instead.
742+
Only the `data` parameter is strictly required, while some other
743+
parameters can be inferred from static :cpp:class:`nb::ndarray\<...\>
744+
<ndarray>` template parameters.
719745

720-
- `ndim`: the number of dimensions.
746+
The parameters have the following meaning:
721747

722-
- `shape`: specifies the size along each axis. The referenced array must
723-
must have `ndim` entries.
748+
- `data`: a CPU/GPU/... pointer to the memory region storing the array
749+
data.
750+
751+
When the array is parameterized by a ``const`` scalar type, or when it
752+
has a :cpp:class:`nb::ro <ro>` read-only annotation, a ``const``
753+
pointer can be passed here.
754+
755+
- `shape`: an initializer list that simultaneously specifies the number
756+
of dimensions and the size along each axis. If left at its default
757+
``{}``, the :cpp:class:`nb::shape <nanobind::shape>` template parameter
758+
will take precedence (if present).
724759

725760
- `owner`: if provided, the array will hold a reference to this object
726-
until it is destructed.
761+
until its destruction. This makes it possible to create zero-copy views
762+
into other data structures, while guaranteeing the memory safety of
763+
array accesses.
764+
765+
- `strides`: an initializer list explaining the layout of the data in
766+
memory. Each entry denotes the number of elements to jump over to
767+
advance to the next item along the associated axis.
727768

728-
- `strides` is optional; a value of ``nullptr`` implies C-style strides.
769+
`strides` must either have the same size as `shape` or be empty. In the
770+
latter case, strides are automatically computed according to the
771+
`order` parameter.
729772

730-
- `dtype` describes the data type (floating point, signed/unsigned
731-
integer) and bit depth.
773+
Note that strides in nanobind express *element counts* rather than
774+
*byte counts*. This convention differs from other frameworks (e.g.,
775+
NumPy) and is a consequence of the underlying `DLPack
776+
<https://github.com/dmlc/dlpack>`_ protocol.
732777

733-
- The `device_type` and `device_id` indicate the device and address
734-
space associated with the pointer `value`.
778+
- `dtype` describes the numeric data type of array elements (e.g.,
779+
floating point, signed/unsigned integer) and their bit depth.
735780

736-
.. cpp:function:: ndarray(void * data, const std::initializer_list<size_t> shape, handle owner = nanobind::handle(), std::initializer_list<int64_t> strides = { }, dlpack::dtype dtype = nanobind::dtype<Scalar>(), int32_t device_type = device::cpu::value, int32_t device_id = 0)
781+
You can use the :cpp:func:`nb::dtype\<T\>() <nanobind::dtype>` function to obtain the right
782+
value for a given type.
737783

738-
Alternative form of the above constructor, which accepts the ``shape``
739-
and ``strides`` arguments using a ``std::initializer_list``. It
740-
automatically infers the value of ``ndim`` based on the size of
741-
``shape``.
784+
- `device_type` and `device_id` specify where the array data is stored.
785+
The `device_type` must be an enumerant like
786+
:cpp:class:`nb::device::cuda::value <device::cuda>`, while the meaning
787+
of the device ID is unspecified and platform-dependent.
788+
789+
Note that the `device_id` is set to ``0`` by default and cannot be
790+
inferred by nanobind. If your extension creates arrays on multiple
791+
different compute accelerators, you *must* provide this parameter.
792+
793+
- The `order` parameter denotes the coefficient order in memory and is only
794+
relevant when `strides` is empty. Specify ``'C'`` for C-style or ``'F'``
795+
for Fortran-style. When this parameter is not explicitly specified, the
796+
implementation uses the order specified as an ndarray template
797+
argument, or C-style order as a fallback.
798+
799+
Both ``strides`` and ``shape`` will be copied by the constructor, hence
800+
the targets of these initializer lists do not need to remain valid
801+
following the constructor call.
802+
803+
.. warning::
804+
805+
The Python *global interpreter lock* (GIL) must be held when calling
806+
this function.
807+
808+
.. cpp:function:: ndarray(VoidPtr data, size_t ndim, const size_t * shape, handle owner, const int64_t * strides = nullptr, dlpack::dtype dtype = nanobind::dtype<Scalar>(), int device_type = DeviceType, int device_id = 0, char order = Order)
809+
810+
Alternative form of the above constructor, which accepts the `shape`
811+
and `strides` arguments using pointers instead of initializer lists.
812+
The number of dimensions must be specified via the `ndim` parameter
813+
in this case.
814+
815+
See the previous constructor for details; the remaining behavior is
816+
identical.
742817

743818
.. cpp:function:: dlpack::dtype dtype() const
744819

@@ -788,13 +863,13 @@ section <ndarrays>`.
788863

789864
Check whether the array is in a valid state.
790865

791-
.. cpp:function:: int32_t device_type() const
866+
.. cpp:function:: int device_type() const
792867

793868
ID denoting the type of device hosting the array. This will match the
794869
``value`` field of a device class, such as :cpp:class:`device::cpu::value
795870
<device::cpu>` or :cpp:class:`device::cuda::value <device::cuda>`.
796871

797-
.. cpp:function:: int32_t device_id() const
872+
.. cpp:function:: int device_id() const
798873

799874
In a multi-device/GPU setup, this function returns the ID of the device
800875
storing the array.
@@ -804,15 +879,18 @@ section <ndarrays>`.
804879
Return a pointer to the array data.
805880
If :cpp:var:`ReadOnly` is true, a pointer-to-const is returned.
806881

807-
.. cpp:function:: template <typename... Ts> auto& operator()(Ts... indices)
882+
.. cpp:function:: template <typename... Args2> auto& operator()(Args2... indices)
808883

809884
Return a reference to the element stored at the provided index/indices.
810885
If :cpp:var:`ReadOnly` is true, a reference-to-const is returned.
811-
Note that ``sizeof(Ts)`` must match :cpp:func:`ndim()`.
886+
Note that ``sizeof...(Args2)`` must match :cpp:func:`ndim()`.
812887

813888
This accessor is only available when the scalar type and array dimension
814889
were specified as template parameters.
815890

891+
This function should only be used when the array storage is accessible
892+
through the CPU's virtual memory address space.
893+
816894
.. cpp:function:: template <typename... Extra> auto view()
817895

818896
Returns an nd-array view that is optimized for fast array access on the
@@ -824,6 +902,18 @@ section <ndarrays>`.
824902
``shape()``, ``stride()``, and ``operator()`` following the conventions
825903
of the `ndarray` type.
826904

905+
.. cpp:function:: auto cast(rv_policy policy = rv_policy::automatic_reference, handle parent = {})
906+
907+
The expression ``array.cast(policy, parent)`` is almost equivalent to
908+
:cpp:func:`nb::cast(array, policy, parent) <cast>`.
909+
910+
The main difference is that the return type of :cpp:func:`nb::cast
911+
<cast>` is :cpp:class:`nb::object <object>`, which renders as a rather
912+
non-descriptive ``object`` in Python bindings. The ``.cast()`` method
913+
returns a custom wrapper type that still derives from
914+
:cpp:class:`nb::object <object>`, but whose type signature in bindings
915+
reproduces that of the original nd-array.
916+
827917
Data types
828918
^^^^^^^^^^
829919

@@ -947,7 +1037,10 @@ Contiguity
9471037

9481038
.. cpp:class:: any_contig
9491039

950-
Don't place any demands on array contiguity (the default).
1040+
Accept both C- and F-contiguous arrays.
1041+
1042+
If you prefer not to require contiguity, simply do not provide any of the
1043+
``*_contig`` template parameters listed above.
9511044

9521045
Device type
9531046
+++++++++++
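
The constructor documentation in the diff above states that strides are expressed in *element counts* rather than byte counts, and that empty `strides` are computed automatically from the `order` parameter. The following is a minimal sketch of that arithmetic in plain C++, assuming a densely packed array. The helpers `element_strides`, `byte_strides`, and `element_offset` are hypothetical names introduced for illustration; they are not part of nanobind and do not reproduce its internal implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of nanobind): strides of a densely packed
// array, expressed in elements. With 'C' the last axis varies fastest,
// with 'F' the first axis varies fastest.
std::vector<int64_t> element_strides(const std::vector<size_t> &shape,
                                     char order) {
    assert(order == 'C' || order == 'F');
    std::vector<int64_t> strides(shape.size());
    int64_t step = 1;
    for (size_t k = 0; k < shape.size(); ++k) {
        // Visit axes from fastest-varying to slowest-varying.
        size_t i = (order == 'F') ? k : shape.size() - 1 - k;
        strides[i] = step;
        step *= static_cast<int64_t>(shape[i]);
    }
    return strides;
}

// Hypothetical helper: convert element strides into byte strides, which is
// the convention NumPy uses when reporting strides.
std::vector<int64_t> byte_strides(std::vector<int64_t> strides,
                                  size_t itemsize) {
    for (int64_t &s : strides)
        s *= static_cast<int64_t>(itemsize);
    return strides;
}

// Hypothetical helper: offset (in elements) of a multi-dimensional index
// relative to the start of the array.
int64_t element_offset(const std::vector<size_t> &index,
                       const std::vector<int64_t> &strides) {
    assert(index.size() == strides.size());
    int64_t offset = 0;
    for (size_t i = 0; i < index.size(); ++i)
        offset += static_cast<int64_t>(index[i]) * strides[i];
    return offset;
}
```

For a 4x4 array of 4-byte floats, this sketch yields element strides `{4, 1}` in C order and `{1, 4}` in Fortran order, while the corresponding byte strides are `{16, 4}` and `{4, 16}`.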

docs/changelog.rst

Lines changed: 71 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,15 @@ It also has a separate ABI version that is *not* subject to semantic
1010
versioning.
1111

1212
The ABI version is relevant whenever a type binding from one extension module
13-
should be visible in another (also nanobind-based) extension module. In this
13+
should be visible in another nanobind-based extension module. In this
1414
case, both modules must use the same nanobind ABI version, or they will be
1515
isolated from each other. Releases that don't explicitly mention an ABI version
1616
below inherit that of the preceding release.
1717

1818
Version 2.2.0 (TBA)
1919
-------------------
2020

21-
- The NVIDIA CUDA compiler (``nvcc``) is now explicitly supported and included
22-
in nanobind's CI test suite.
23-
24-
- nanobind has always used `PEP 590 vector calls
21+
* nanobind has always used `PEP 590 vector calls
2522
<https://www.python.org/dev/peps/pep-0590>`__ to efficiently dispatch calls
2623
to function and method bindings, but it lacked the ability to do so for
2724
constructors (e.g., ``MyType(arg1, arg2, ...)``).
@@ -41,6 +38,63 @@ Version 2.2.0 (TBA)
4138
with :cpp:class:`nb::is_arithmetic() <is_flag>` creates enumerations deriving
4239
from :py:class:`enum.IntFlag`.
4340

41+
* A refactor of :cpp:class:`nb::ndarray\<...\> <ndarray>` was an opportunity to
42+
realize two usability improvements:
43+
44+
1. The constructor used to return new nd-arrays from C++ now considers
45+
all template arguments:
46+
47+
- **Memory order**: :cpp:class:`c_contig`, :cpp:class:`f_contig`.
48+
- **Shape**: :cpp:class:`nb::shape\<3, 4, 5\> <shape>`, etc.
49+
- **Device type**: :cpp:class:`nb::device::cpu <device::cpu>`, :cpp:class:`nb::device::cuda <device::cuda>`, etc.
50+
- **Framework**: :cpp:class:`nb::numpy <numpy>`, :cpp:class:`nb::pytorch <pytorch>`, etc.
51+
- **Data type**: ``uint64_t``, ``std::complex<double>``, etc.
52+
53+
Previously, only the **framework** and **data type** annotations were
54+
taken into account when returning nd-arrays, while all of them were
55+
examined when *accepting* arrays during overload resolution. This
56+
inconsistency was a repeated source of confusion among users.
57+
58+
To give an example, the following now works out of the box without the
59+
need to redundantly specify the shape and strides to the ``Array``
60+
constructor below:
61+
62+
.. code-block:: cpp
63+
64+
using Array = nb::ndarray<float, nb::numpy, nb::shape<4, 4>, nb::f_contig>;
65+
66+
struct Matrix4f {
67+
float m[4][4];
68+
Array data() { return Array(m); }
69+
};
70+
71+
nb::class_<Matrix4f>(m, "Matrix4f")
72+
.def("data", &Matrix4f::data, nb::rv_policy::reference_internal);
73+
74+
2. A new nd-array :cpp:func:`.cast() <ndarray::cast>` method forces the
75+
immediate creation of a Python object with the specified target framework
76+
and return value policy, while preserving the type signature in return
77+
values. This is useful to :ref:`return temporaries (e.g. stack-allocated
78+
memory) <ndarray-temporaries>` from functions.
79+
80+
There are two minor but potentially breaking changes:
81+
82+
1. The ndarray type caster now interprets the
83+
:cpp:enumerator:`nb::rv_policy::automatic_reference
84+
<rv_policy::automatic_reference>` return value policy analogously to the
85+
:cpp:enumerator:`nb::rv_policy::automatic <rv_policy::automatic>`, which
86+
means that it references a memory region when the user specifies an
87+
``owner``, and it otherwise copies. This makes it safe to use the
88+
:cpp:func:`nb::cast() <cast>` and :cpp:func:`nb::ndarray::cast()
89+
<ndarray::cast>` functions that use this policy as a default.
90+
91+
2. The :cpp:class:`nb::any_contig <any_contig>` memory order annotation,
92+
which previously did nothing, now accepts C- or F-contiguous arrays and
93+
rejects non-contiguous ones.
94+
95+
- The NVIDIA CUDA compiler (``nvcc``) is now explicitly supported and included
96+
in nanobind's CI test suite.
97+
4498
* Added support for return value policy customization to the type casters of
4599
``Eigen::Ref<...>`` and ``Eigen::Map<...>`` (commit `67316e
46100
<https://github.com/wjakob/nanobind/commit/67316eb88955a15e8e89a57ce9a53d8d66263287>`__).
@@ -61,25 +115,21 @@ Version 2.2.0 (TBA)
61115
* Fixed implicit conversion of complex nd-arrays. (issue `#709
62116
<https://github.com/wjakob/nanobind/issues/709>`__)
63117

64-
* Minor fixes and improvements (PR `#696
65-
<https://github.com/wjakob/nanobind/pull/696>`__, `#693
66-
<https://github.com/wjakob/nanobind/pull/693>`__, `#675
67-
<https://github.com/wjakob/nanobind/pull/675>`__, commit `75d259
68-
<https://github.com/wjakob/nanobind/commit/75d259c7c16db9586e5cd3aa4715e09a25e76d83>`__).
69-
70118
* Casting via :cpp:func:`nb::cast <cast>` can now specify an owner object for
71119
use with the :cpp:enumerator:`nb::rv_policy::reference_internal
72-
<rv_policy::reference_internal>` return value policy (PR `#667
73-
<https://github.com/wjakob/nanobind/pull/667>`__) #667
120+
<rv_policy::reference_internal>` return value policy (PR `#667
121+
<https://github.com/wjakob/nanobind/pull/667>`__).
74122

75-
* The ``std::optional<T>`` type caster is now implemented so that it can also
76-
accommodate other frameworks such as Boost, Abseil, etc. (PR `#675
77-
<https://github.com/wjakob/nanobind/pull/675>`__)
123+
* The ``std::optional<T>`` type caster is now implemented in such a way that it
124+
can also accommodate non-STL frameworks, such as Boost, Abseil, etc. (PR
125+
`#675 <https://github.com/wjakob/nanobind/pull/675>`__)
78126

79127
* ABI version 15.
80128

81-
* Minor fixes and improvements.
82-
129+
* Minor fixes and improvements (PR `#696
130+
<https://github.com/wjakob/nanobind/pull/696>`__, `#693
131+
<https://github.com/wjakob/nanobind/pull/693>`__, commit `75d259
132+
<https://github.com/wjakob/nanobind/commit/75d259c7c16db9586e5cd3aa4715e09a25e76d83>`__).
83133

84134
Version 2.1.0 (Aug 11, 2024)
85135
----------------------------
@@ -572,15 +622,15 @@ New features
572622

573623
* Several :cpp:class:`nb::ndarray\<..\> <ndarray>` improvements:
574624

575-
1. CPU loops involving nanobind ndarrays weren't getting properly vectorized.
625+
1. CPU loops involving nanobind nd-arrays weren't getting properly vectorized.
576626
This release of nanobind adds *views*, which provide an efficient
577627
abstraction that enables better code generation. See the documentation
578628
section on :ref:`array views <ndarray-views>` for details.
579629
(commit `8f602e
580630
<https://github.com/wjakob/nanobind/commit/8f602e187b0634e1df13ba370352cf092e9042c0>`__).
581631

582632
2. Added support for nonstandard arithmetic types (e.g., ``__int128`` or
583-
``__fp16``) in ndarrays. See the :ref:`documentation section
633+
``__fp16``) in nd-arrays. See the :ref:`documentation section
584634
<ndarray-nonstandard>` for details. (commit `49eab2
585635
<https://github.com/wjakob/nanobind/commit/49eab2845530f84a1f029c5c1c5541ab3c1f9adc>`__).
586636

@@ -589,7 +639,7 @@ New features
589639
:cpp:class:`nb::ndim\<3\> <ndim>`. (commit `1350a5
590640
<https://github.com/wjakob/nanobind/commit/1350a5e15b28e80ffc2130a779f3b8c559ddb620>`__).
591641

592-
4. Added an explicit constructor that can be used to add or remove ndarray
642+
4. Added an explicit constructor that can be used to add or remove nd-array
593643
constraints. (commit `a1ac207
594644
<https://github.com/wjakob/nanobind/commit/a1ac207ab82206b8e50fe456f577c02270014fb3>`__).
595645
