Skip to content

Commit b92a1c7

Browse files
committed
fix: Use modern numpy/cython functions to manage changing array objects
They are read-only in modern numpy Signed-off-by: bluedrink9 <[email protected]>
1 parent 2c0667f commit b92a1c7

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

econml/tree/_tree.pyx

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ cdef extern from "numpy/arrayobject.h":
3838
np.npy_intp* strides,
3939
void* data, int flags, object obj)
4040

41+
cdef extern from "numpy/arrayobject.h":
42+
int PyArray_SetBaseObject(np.ndarray arr, PyObject *obj) except -1
43+
4144
# =============================================================================
4245
# Types and constants
4346
# =============================================================================
@@ -179,10 +182,10 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
179182
sample_weight_ptr = <DOUBLE_t*> sample_weight.data
180183

181184
# Initial capacity
182-
cdef int init_capacity
185+
cdef SIZE_t init_capacity
183186

184187
if tree.max_depth <= 10:
185-
init_capacity = (2 ** (tree.max_depth + 1)) - 1
188+
init_capacity = <SIZE_t>((2 ** (tree.max_depth + 1)) - 1)
186189
else:
187190
init_capacity = 2047
188191

@@ -923,8 +926,7 @@ cdef class Tree:
923926
shape[2] = <np.npy_intp> self.max_n_classes
924927
cdef np.ndarray arr
925928
arr = np.PyArray_SimpleNewFromData(3, shape, np.NPY_DOUBLE, self.value)
926-
Py_INCREF(self)
927-
arr.base = <PyObject*> self
929+
PyArray_SetBaseObject(arr, <PyObject*> self)
928930
return arr
929931

930932
cdef np.ndarray _get_jac_ndarray(self):
@@ -938,7 +940,7 @@ cdef class Tree:
938940
cdef np.ndarray arr
939941
arr = np.PyArray_SimpleNewFromData(2, shape, np.NPY_DOUBLE, self.jac)
940942
Py_INCREF(self)
941-
arr.base = <PyObject*> self
943+
PyArray_SetBaseObject(arr, <PyObject*> self)
942944
return arr
943945

944946
cdef np.ndarray _get_precond_ndarray(self):
@@ -952,7 +954,7 @@ cdef class Tree:
952954
cdef np.ndarray arr
953955
arr = np.PyArray_SimpleNewFromData(2, shape, np.NPY_DOUBLE, self.precond)
954956
Py_INCREF(self)
955-
arr.base = <PyObject*> self
957+
PyArray_SetBaseObject(arr, <PyObject*> self)
956958
return arr
957959

958960
cdef np.ndarray _get_node_ndarray(self):
@@ -970,7 +972,7 @@ cdef class Tree:
970972
arr = PyArray_NewFromDescr(<PyTypeObject *> np.ndarray,
971973
<np.dtype> NODE_DTYPE, 1, shape,
972974
strides, <void*> self.nodes,
973-
np.NPY_DEFAULT, None)
975+
np.NPY_ARRAY_DEFAULT, None)
974976
Py_INCREF(self)
975-
arr.base = <PyObject*> self
977+
PyArray_SetBaseObject(arr, <PyObject*> self)
976978
return arr

0 commit comments

Comments
 (0)