@@ -109,7 +109,7 @@ def fit(self, Y, T, *, X, inference=None):
109
109
self .models [ind ].fit (X [T == ind ], Y [T == ind ])
110
110
111
111
def const_marginal_effect (self , X ):
112
- """Calculate the constant marignal treatment effect on a vector of features for each sample.
112
+ """Calculate the constant marginal treatment effect on a vector of features for each sample.
113
113
114
114
Parameters
115
115
----------
@@ -127,7 +127,11 @@ def const_marginal_effect(self, X):
127
127
X = check_array (X )
128
128
taus = []
129
129
for ind in range (self ._d_t [0 ]):
130
- taus .append (self .models [ind + 1 ].predict (X ) - self .models [0 ].predict (X ))
130
+ if hasattr (self .models [ind + 1 ], 'predict_proba' ):
131
+ taus .append (self .models [ind + 1 ].predict_proba (X )[:, 1 ] - self .models [0 ].predict_proba (X )[:, 1 ])
132
+ else :
133
+ taus .append (self .models [ind + 1 ].predict (X ) - self .models [0 ].predict (X ))
134
+
131
135
taus = np .column_stack (taus ).reshape ((- 1 ,) + self ._d_t + self ._d_y ) # shape as of m*d_t*d_y
132
136
if self ._d_y :
133
137
taus = transpose (taus , (0 , 2 , 1 )) # shape as of m*d_y*d_t
@@ -242,7 +246,12 @@ def const_marginal_effect(self, X=None):
242
246
X = check_array (X )
243
247
Xs , Ts = broadcast_unit_treatments (X , self ._d_t [0 ] + 1 )
244
248
feat_arr = np .concatenate ((Xs , Ts ), axis = 1 )
245
- prediction = self .overall_model .predict (feat_arr ).reshape ((- 1 , self ._d_t [0 ] + 1 ,) + self ._d_y )
249
+
250
+ if hasattr (self .overall_model , 'predict_proba' ):
251
+ prediction = self .overall_model .predict_proba (feat_arr )[:, 1 ].reshape ((- 1 , self ._d_t [0 ] + 1 ,) + self ._d_y )
252
+ else :
253
+ prediction = self .overall_model .predict (feat_arr ).reshape ((- 1 , self ._d_t [0 ] + 1 ,) + self ._d_y )
254
+
246
255
if self ._d_y :
247
256
prediction = transpose (prediction , (0 , 2 , 1 ))
248
257
taus = (prediction - np .repeat (prediction [:, :, 0 ], self ._d_t [0 ] + 1 ).reshape (prediction .shape ))[:, :, 1 :]
@@ -393,8 +402,14 @@ def const_marginal_effect(self, X):
393
402
taus = []
394
403
for ind in range (self ._d_t [0 ]):
395
404
propensity_scores = self .propensity_models [ind ].predict_proba (X )[:, 1 :]
396
- tau_hat = propensity_scores * self .cate_controls_models [ind ].predict (X ).reshape (m , - 1 ) \
397
- + (1 - propensity_scores ) * self .cate_treated_models [ind ].predict (X ).reshape (m , - 1 )
405
+
406
+ if hasattr (self .cate_controls_models [ind ], 'predict_proba' ):
407
+ tau_hat = propensity_scores * self .cate_controls_models [ind ].predict_proba (X )[:, 1 ].reshape (m , - 1 ) \
408
+ + (1 - propensity_scores ) * self .cate_treated_models [ind ].predict_proba (X )[:, 1 ].reshape (m , - 1 )
409
+ else :
410
+ tau_hat = propensity_scores * self .cate_controls_models [ind ].predict (X ).reshape (m , - 1 ) \
411
+ + (1 - propensity_scores ) * self .cate_treated_models [ind ].predict (X ).reshape (m , - 1 )
412
+
398
413
taus .append (tau_hat )
399
414
taus = np .column_stack (taus ).reshape ((- 1 ,) + self ._d_t + self ._d_y ) # shape as of m*d_t*d_y
400
415
if self ._d_y :
@@ -549,7 +564,10 @@ def const_marginal_effect(self, X):
549
564
X = check_array (X )
550
565
taus = []
551
566
for model in self .final_models :
552
- taus .append (model .predict (X ))
567
+ if hasattr (model , 'predict_proba' ):
568
+ taus .append (model .predict_proba (X )[:, 1 ])
569
+ else :
570
+ taus .append (model .predict (X ))
553
571
taus = np .column_stack (taus ).reshape ((- 1 ,) + self ._d_t + self ._d_y ) # shape as of m*d_t*d_y
554
572
if self ._d_y :
555
573
taus = transpose (taus , (0 , 2 , 1 )) # shape as of m*d_y*d_t
0 commit comments