Skip to content

Commit a66c0b2

Browse files
authored
fixed pickling of grf (#384)
* Fixed pickling of GRF. * Changed it to use self.max_values (`self.max_n_classes` in the diff below) for better consistency with sklearn's logic and easier future extensibility.
1 parent 6686c56 commit a66c0b2

File tree

2 files changed

+28
-8
lines changed

2 files changed

+28
-8
lines changed

econml/tests/test_grf_python.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import pandas as pd
1010
import pytest
11+
import joblib
1112
from econml.grf import RegressionForest, CausalForest, CausalIVForest, MultiOutputGRF
1213
from econml.utilities import cross_product
1314
from copy import deepcopy
@@ -688,3 +689,20 @@ def test_multioutput(self,):
688689
np.testing.assert_allclose(imps[0, :], imps[1, :])
689690

690691
return
692+
693+
def test_pickling(self,):
694+
695+
n_features = 2
696+
n = 10
697+
random_state = 123
698+
X, y, _ = self._get_regression_data(n, n_features, random_state)
699+
700+
forest = RegressionForest(n_estimators=4, warm_start=True, random_state=123).fit(X, y)
701+
forest.n_estimators = 8
702+
forest.fit(X, y)
703+
pred1 = forest.predict(X)
704+
705+
joblib.dump(forest, 'forest.jbl')
706+
loaded_forest = joblib.load('forest.jbl')
707+
np.testing.assert_equal(loaded_forest.n_estimators, forest.n_estimators)
708+
np.testing.assert_allclose(loaded_forest.predict(X), pred1)

econml/tree/_tree.pyx

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -517,15 +517,17 @@ cdef class Tree:
517517

518518
node_ndarray = d['nodes']
519519
value_ndarray = d['values']
520-
521-
value_shape = (node_ndarray.shape[0], self.n_outputs)
520+
522521
if (node_ndarray.ndim != 1 or
523522
node_ndarray.dtype != NODE_DTYPE or
524-
not node_ndarray.flags.c_contiguous or
525-
value_ndarray.shape != value_shape or
523+
not node_ndarray.flags.c_contiguous):
524+
raise ValueError('Did not recognise loaded array layout for `node_ndarray`')
525+
526+
value_shape = (node_ndarray.shape[0], self.n_outputs, self.max_n_classes)
527+
if (value_ndarray.shape != value_shape or
526528
not value_ndarray.flags.c_contiguous or
527529
value_ndarray.dtype != np.float64):
528-
raise ValueError('Did not recognise loaded array layout')
530+
raise ValueError('Did not recognise loaded array layout for `value_ndarray`')
529531

530532
self.capacity = node_ndarray.shape[0]
531533
if self._resize_c(self.capacity) != 0:
@@ -541,15 +543,15 @@ cdef class Tree:
541543
if (jac_ndarray.shape != jac_shape or
542544
not jac_ndarray.flags.c_contiguous or
543545
jac_ndarray.dtype != np.float64):
544-
raise ValueError('Did not recognise loaded array layout')
546+
raise ValueError('Did not recognise loaded array layout for `jac_ndarray`')
545547
jac = memcpy(self.jac, (<np.ndarray> jac_ndarray).data,
546548
self.capacity * self.jac_stride * sizeof(double))
547549
precond_ndarray = d['precond']
548550
precond_shape = (node_ndarray.shape[0], self.n_outputs)
549551
if (precond_ndarray.shape != precond_shape or
550552
not precond_ndarray.flags.c_contiguous or
551553
precond_ndarray.dtype != np.float64):
552-
raise ValueError('Did not recognise loaded array layout')
554+
raise ValueError('Did not recognise loaded array layout for `precond_ndarray`')
553555
precond = memcpy(self.precond, (<np.ndarray> precond_ndarray).data,
554556
self.capacity * self.precond_stride * sizeof(double))
555557

@@ -917,7 +919,7 @@ cdef class Tree:
917919
cdef np.npy_intp shape[3]
918920
shape[0] = <np.npy_intp> self.node_count
919921
shape[1] = <np.npy_intp> self.n_outputs
920-
shape[2] = 1
922+
shape[2] = <np.npy_intp> self.max_n_classes
921923
cdef np.ndarray arr
922924
arr = np.PyArray_SimpleNewFromData(3, shape, np.NPY_DOUBLE, self.value)
923925
Py_INCREF(self)

0 commit comments

Comments
 (0)