|
/*
    nanobind/eigen/sparse.h: type casters for sparse Eigen matrices

    Copyright (c) 2023 Henri Menke and Wenzel Jakob

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/
| 9 | + |
| 10 | +#pragma once |
| 11 | + |
| 12 | +#include <nanobind/ndarray.h> |
| 13 | +#include <nanobind/eigen/dense.h> |
| 14 | +#include <Eigen/SparseCore> |
| 15 | + |
| 16 | +#include <memory> |
| 17 | +#include <type_traits> |
| 18 | +#include <utility> |
| 19 | + |
| 20 | +NAMESPACE_BEGIN(NB_NAMESPACE) |
| 21 | + |
| 22 | +NAMESPACE_BEGIN(detail) |
| 23 | + |
/// Detect Eigen::SparseMatrix
///
/// True only for a *plain* sparse matrix type: the second condition filters
/// out view types (Eigen::Map / Eigen::Ref of a sparse matrix), which derive
/// from SparseMapBase and are handled by the dedicated casters further below.
template <typename T> constexpr bool is_eigen_sparse_matrix_v =
    is_eigen_sparse_v<T> &&
    !std::is_base_of_v<Eigen::SparseMapBase<T, Eigen::ReadOnlyAccessors>, T>;
| 28 | + |
| 29 | + |
| 30 | +/// Caster for Eigen::SparseMatrix |
| 31 | +template <typename T> struct type_caster<T, enable_if_t<is_eigen_sparse_matrix_v<T>>> { |
| 32 | + using Scalar = typename T::Scalar; |
| 33 | + using StorageIndex = typename T::StorageIndex; |
| 34 | + using Index = typename T::Index; |
| 35 | + using SparseMap = Eigen::Map<T>; |
| 36 | + |
| 37 | + static_assert(std::is_same_v<T, Eigen::SparseMatrix<Scalar, T::Options, StorageIndex>>, |
| 38 | + "nanobind: Eigen sparse caster only implemented for matrices"); |
| 39 | + |
| 40 | + static constexpr bool row_major = T::IsRowMajor; |
| 41 | + |
| 42 | + using ScalarNDArray = ndarray<numpy, Scalar, shape<any>>; |
| 43 | + using StorageIndexNDArray = ndarray<numpy, StorageIndex, shape<any>>; |
| 44 | + |
| 45 | + using ScalarCaster = make_caster<ScalarNDArray>; |
| 46 | + using StorageIndexCaster = make_caster<StorageIndexNDArray>; |
| 47 | + |
| 48 | + NB_TYPE_CASTER(T, const_name<row_major>("scipy.sparse.csr_matrix[", |
| 49 | + "scipy.sparse.csc_matrix[") |
| 50 | + + make_caster<Scalar>::Name + const_name("]")); |
| 51 | + |
| 52 | + ScalarCaster data_caster; |
| 53 | + StorageIndexCaster indices_caster, indptr_caster; |
| 54 | + |
| 55 | + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { |
| 56 | + object obj = borrow(src); |
| 57 | + try { |
| 58 | + object matrix_type = module_::import_("scipy.sparse").attr(row_major ? "csr_matrix" : "csc_matrix"); |
| 59 | + if (!obj.type().is(matrix_type)) |
| 60 | + obj = matrix_type(obj); |
| 61 | + } catch (const python_error &) { |
| 62 | + return false; |
| 63 | + } |
| 64 | + |
| 65 | + if (object data_o = obj.attr("data"); !data_caster.from_python(data_o, flags, cleanup)) |
| 66 | + return false; |
| 67 | + ScalarNDArray& values = data_caster.value; |
| 68 | + |
| 69 | + if (object indices_o = obj.attr("indices"); !indices_caster.from_python(indices_o, flags, cleanup)) |
| 70 | + return false; |
| 71 | + StorageIndexNDArray& inner_indices = indices_caster.value; |
| 72 | + |
| 73 | + if (object indptr_o = obj.attr("indptr"); !indptr_caster.from_python(indptr_o, flags, cleanup)) |
| 74 | + return false; |
| 75 | + StorageIndexNDArray& outer_indices = indptr_caster.value; |
| 76 | + |
| 77 | + object shape_o = obj.attr("shape"), nnz_o = obj.attr("nnz"); |
| 78 | + Index rows, cols, nnz; |
| 79 | + try { |
| 80 | + if (len(shape_o) != 2) |
| 81 | + return false; |
| 82 | + rows = cast<Index>(shape_o[0]); |
| 83 | + cols = cast<Index>(shape_o[1]); |
| 84 | + nnz = cast<Index>(nnz_o); |
| 85 | + } catch (const python_error &) { |
| 86 | + return false; |
| 87 | + } |
| 88 | + |
| 89 | + value = SparseMap(rows, cols, nnz, outer_indices.data(), inner_indices.data(), values.data()); |
| 90 | + |
| 91 | + return true; |
| 92 | + } |
| 93 | + |
| 94 | + static handle from_cpp(T &&v, rv_policy policy, cleanup_list *cleanup) noexcept { |
| 95 | + if (policy == rv_policy::automatic || |
| 96 | + policy == rv_policy::automatic_reference) |
| 97 | + policy = rv_policy::move; |
| 98 | + |
| 99 | + return from_cpp((const T &) v, policy, cleanup); |
| 100 | + } |
| 101 | + |
| 102 | + static handle from_cpp(const T &v, rv_policy policy, cleanup_list *) noexcept { |
| 103 | + if (!v.isCompressed()) { |
| 104 | + PyErr_SetString(PyExc_ValueError, |
| 105 | + "nanobind: unable to return an Eigen sparse matrix that is not in a compressed format. " |
| 106 | + "Please call `.makeCompressed()` before returning the value on the C++ end."); |
| 107 | + return handle(); |
| 108 | + } |
| 109 | + |
| 110 | + object matrix_type; |
| 111 | + try { |
| 112 | + matrix_type = module_::import_("scipy.sparse").attr(row_major ? "csr_matrix" : "csc_matrix"); |
| 113 | + } catch (python_error &e) { |
| 114 | + e.restore(); |
| 115 | + return handle(); |
| 116 | + } |
| 117 | + |
| 118 | + const Index rows = v.rows(); |
| 119 | + const Index cols = v.cols(); |
| 120 | + const size_t data_shape[] = { (size_t)v.nonZeros() }; |
| 121 | + const size_t outer_indices_shape[] = { (size_t)((row_major ? rows : cols) + 1) }; |
| 122 | + |
| 123 | + T *src = std::addressof(const_cast<T &>(v)); |
| 124 | + object owner; |
| 125 | + if (policy == rv_policy::move) { |
| 126 | + src = new T(std::move(v)); |
| 127 | + owner = capsule(src, [](void *p) noexcept { delete (T *) p; }); |
| 128 | + } |
| 129 | + |
| 130 | + ScalarNDArray data(src->valuePtr(), 1, data_shape, owner); |
| 131 | + StorageIndexNDArray outer_indices(src->outerIndexPtr(), 1, outer_indices_shape, owner); |
| 132 | + StorageIndexNDArray inner_indices(src->innerIndexPtr(), 1, data_shape, owner); |
| 133 | + |
| 134 | + try { |
| 135 | + return matrix_type(make_tuple( |
| 136 | + std::move(data), std::move(inner_indices), std::move(outer_indices)), |
| 137 | + make_tuple(rows, cols)) |
| 138 | + .release(); |
| 139 | + } catch (python_error &e) { |
| 140 | + e.restore(); |
| 141 | + return handle(); |
| 142 | + } |
| 143 | + } |
| 144 | +}; |
| 145 | + |
| 146 | + |
/// Caster for Eigen::Map<Eigen::SparseMatrix>
///
/// Zero-copy mapping of a scipy sparse matrix (and the reverse direction) is
/// not supported: both conversion functions are explicitly deleted, so using
/// Eigen::Map of a sparse matrix in a bound function signature produces a
/// compile-time error rather than a silent copy. Only the type name is
/// borrowed from the value caster for diagnostics/signatures.
template <typename T>
struct type_caster<Eigen::Map<T>, enable_if_t<is_eigen_sparse_matrix_v<T>>> {
    using Map = Eigen::Map<T>;
    using SparseMatrixCaster = type_caster<T>;
    static constexpr auto Name = SparseMatrixCaster::Name;
    template <typename T_> using Cast = Map;

    // Deleted: borrowing scipy's buffers into an Eigen::Map is unsupported.
    bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept = delete;

    // Deleted: returning a sparse Map to Python is unsupported.
    static handle from_cpp(const Map &v, rv_policy policy, cleanup_list *cleanup) noexcept = delete;
};
| 159 | + |
| 160 | + |
/// Caster for Eigen::Ref<Eigen::SparseMatrix>
///
/// Like the Map caster above, sparse Eigen::Ref arguments are unsupported:
/// both conversion functions are deleted so that a bound signature using a
/// sparse Ref fails at compile time instead of copying at runtime. The name
/// is forwarded from the corresponding Map caster for diagnostics.
template <typename T, int Options>
struct type_caster<Eigen::Ref<T, Options>, enable_if_t<is_eigen_sparse_matrix_v<T>>> {
    using Ref = Eigen::Ref<T, Options>;
    using Map = Eigen::Map<T, Options>;
    using MapCaster = make_caster<Map>;
    static constexpr auto Name = MapCaster::Name;
    template <typename T_> using Cast = Ref;

    // Deleted: constructing a sparse Ref from a Python object is unsupported.
    bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept = delete;

    // Deleted: returning a sparse Ref to Python is unsupported.
    static handle from_cpp(const Ref &v, rv_policy policy, cleanup_list *cleanup) noexcept = delete;
};
| 174 | + |
| 175 | +NAMESPACE_END(detail) |
| 176 | + |
| 177 | +NAMESPACE_END(NB_NAMESPACE) |
0 commit comments