@@ -517,15 +517,17 @@ cdef class Tree:
517
517
518
518
node_ndarray = d[' nodes' ]
519
519
value_ndarray = d[' values' ]
520
-
521
- value_shape = (node_ndarray.shape[0 ], self .n_outputs)
520
+
522
521
if (node_ndarray.ndim != 1 or
523
522
node_ndarray.dtype != NODE_DTYPE or
524
- not node_ndarray.flags.c_contiguous or
525
- value_ndarray.shape != value_shape or
523
+ not node_ndarray.flags.c_contiguous):
524
+ raise ValueError (' Did not recognise loaded array layout for `node_ndarray`' )
525
+
526
+ value_shape = (node_ndarray.shape[0 ], self .n_outputs, self .max_n_classes)
527
+ if (value_ndarray.shape != value_shape or
526
528
not value_ndarray.flags.c_contiguous or
527
529
value_ndarray.dtype != np.float64):
528
- raise ValueError (' Did not recognise loaded array layout' )
530
+ raise ValueError (' Did not recognise loaded array layout for `value_ndarray` ' )
529
531
530
532
self .capacity = node_ndarray.shape[0 ]
531
533
if self ._resize_c(self .capacity) != 0 :
@@ -541,15 +543,15 @@ cdef class Tree:
541
543
if (jac_ndarray.shape != jac_shape or
542
544
not jac_ndarray.flags.c_contiguous or
543
545
jac_ndarray.dtype != np.float64):
544
- raise ValueError (' Did not recognise loaded array layout' )
546
+ raise ValueError (' Did not recognise loaded array layout for `jac_ndarray` ' )
545
547
jac = memcpy(self .jac, (< np.ndarray> jac_ndarray).data,
546
548
self .capacity * self .jac_stride * sizeof(double ))
547
549
precond_ndarray = d[' precond' ]
548
550
precond_shape = (node_ndarray.shape[0 ], self .n_outputs)
549
551
if (precond_ndarray.shape != precond_shape or
550
552
not precond_ndarray.flags.c_contiguous or
551
553
precond_ndarray.dtype != np.float64):
552
- raise ValueError (' Did not recognise loaded array layout' )
554
+ raise ValueError (' Did not recognise loaded array layout for `precond_ndarray` ' )
553
555
precond = memcpy(self .precond, (< np.ndarray> precond_ndarray).data,
554
556
self .capacity * self .precond_stride * sizeof(double ))
555
557
@@ -917,7 +919,7 @@ cdef class Tree:
917
919
cdef np.npy_intp shape[3 ]
918
920
shape[0 ] = < np.npy_intp> self .node_count
919
921
shape[1 ] = < np.npy_intp> self .n_outputs
920
- shape[2 ] = 1
922
+ shape[2 ] = < np.npy_intp > self .max_n_classes
921
923
cdef np.ndarray arr
922
924
arr = np.PyArray_SimpleNewFromData(3 , shape, np.NPY_DOUBLE, self .value)
923
925
Py_INCREF(self )
0 commit comments