Skip to content

Commit 46dbbda

Browse files
committed
Add: support for predict_proba for estimators that support it
1 parent b0936ac commit 46dbbda

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

econml/metalearners/_metalearners.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def fit(self, Y, T, *, X, inference=None):
109109
self.models[ind].fit(X[T == ind], Y[T == ind])
110110

111111
def const_marginal_effect(self, X):
112-
"""Calculate the constant marignal treatment effect on a vector of features for each sample.
112+
"""Calculate the constant marginal treatment effect on a vector of features for each sample.
113113
114114
Parameters
115115
----------
@@ -127,7 +127,11 @@ def const_marginal_effect(self, X):
127127
X = check_array(X)
128128
taus = []
129129
for ind in range(self._d_t[0]):
130-
taus.append(self.models[ind + 1].predict(X) - self.models[0].predict(X))
130+
if hasattr(self.models[ind + 1], 'predict_proba'):
131+
taus.append(self.models[ind + 1].predict_proba(X)[:, 1] - self.models[0].predict_proba(X)[:, 1])
132+
else:
133+
taus.append(self.models[ind + 1].predict(X) - self.models[0].predict(X))
134+
131135
taus = np.column_stack(taus).reshape((-1,) + self._d_t + self._d_y) # shape as of m*d_t*d_y
132136
if self._d_y:
133137
taus = transpose(taus, (0, 2, 1)) # shape as of m*d_y*d_t
@@ -242,7 +246,12 @@ def const_marginal_effect(self, X=None):
242246
X = check_array(X)
243247
Xs, Ts = broadcast_unit_treatments(X, self._d_t[0] + 1)
244248
feat_arr = np.concatenate((Xs, Ts), axis=1)
245-
prediction = self.overall_model.predict(feat_arr).reshape((-1, self._d_t[0] + 1,) + self._d_y)
249+
250+
if hasattr(self.overall_model, 'predict_proba'):
251+
prediction = self.overall_model.predict_proba(feat_arr)[:, 1].reshape((-1, self._d_t[0] + 1,) + self._d_y)
252+
else:
253+
prediction = self.overall_model.predict(feat_arr).reshape((-1, self._d_t[0] + 1,) + self._d_y)
254+
246255
if self._d_y:
247256
prediction = transpose(prediction, (0, 2, 1))
248257
taus = (prediction - np.repeat(prediction[:, :, 0], self._d_t[0] + 1).reshape(prediction.shape))[:, :, 1:]
@@ -393,8 +402,14 @@ def const_marginal_effect(self, X):
393402
taus = []
394403
for ind in range(self._d_t[0]):
395404
propensity_scores = self.propensity_models[ind].predict_proba(X)[:, 1:]
396-
tau_hat = propensity_scores * self.cate_controls_models[ind].predict(X).reshape(m, -1) \
397-
+ (1 - propensity_scores) * self.cate_treated_models[ind].predict(X).reshape(m, -1)
405+
406+
if hasattr(self.cate_controls_models[ind], 'predict_proba'):
407+
tau_hat = propensity_scores * self.cate_controls_models[ind].predict_proba(X)[:, 1].reshape(m, -1) \
408+
+ (1 - propensity_scores) * self.cate_treated_models[ind].predict_proba(X)[:, 1].reshape(m, -1)
409+
else:
410+
tau_hat = propensity_scores * self.cate_controls_models[ind].predict(X).reshape(m, -1) \
411+
+ (1 - propensity_scores) * self.cate_treated_models[ind].predict(X).reshape(m, -1)
412+
398413
taus.append(tau_hat)
399414
taus = np.column_stack(taus).reshape((-1,) + self._d_t + self._d_y) # shape as of m*d_t*d_y
400415
if self._d_y:
@@ -549,7 +564,10 @@ def const_marginal_effect(self, X):
549564
X = check_array(X)
550565
taus = []
551566
for model in self.final_models:
552-
taus.append(model.predict(X))
567+
if hasattr(model, 'predict_proba'):
568+
taus.append(model.predict_proba(X)[:, 1])
569+
else:
570+
taus.append(model.predict(X))
553571
taus = np.column_stack(taus).reshape((-1,) + self._d_t + self._d_y) # shape as of m*d_t*d_y
554572
if self._d_y:
555573
taus = transpose(taus, (0, 2, 1)) # shape as of m*d_y*d_t

0 commit comments

Comments
 (0)