2222class DoubleML (SampleSplittingMixin , ABC ):
2323 """Double Machine Learning."""
2424
25- def __init__ (self , obj_dml_data , n_folds , n_rep , score , draw_sample_splitting ):
25+ def __init__ (self , obj_dml_data , n_folds , n_rep , score , draw_sample_splitting , double_sample_splitting = False ):
2626 # check and pick up obj_dml_data
2727 if not isinstance (obj_dml_data , DoubleMLBaseData ):
2828 raise TypeError (
@@ -34,18 +34,10 @@ def __init__(self, obj_dml_data, n_folds, n_rep, score, draw_sample_splitting):
3434 if obj_dml_data .n_cluster_vars > 2 :
3535 raise NotImplementedError ("Multi-way (n_ways > 2) clustering not yet implemented." )
3636 self ._is_cluster_data = True
37- self ._is_panel_data = False
38- if isinstance (obj_dml_data , DoubleMLPanelData ):
39- self ._is_panel_data = True
40- self ._is_did_data = False
41- if isinstance (obj_dml_data , DoubleMLDIDData ):
42- self ._is_did_data = True
43- self ._is_ssm_data = False
44- if isinstance (obj_dml_data , DoubleMLSSMData ):
45- self ._is_ssm_data = True
46- self ._is_rdd_data = False
47- if isinstance (obj_dml_data , DoubleMLRDDData ):
48- self ._is_rdd_data = True
37+ self ._is_panel_data = isinstance (obj_dml_data , DoubleMLPanelData )
38+ self ._is_did_data = isinstance (obj_dml_data , DoubleMLDIDData )
39+ self ._is_ssm_data = isinstance (obj_dml_data , DoubleMLSSMData )
40+ self ._is_rdd_data = isinstance (obj_dml_data , DoubleMLRDDData )
4941
5042 self ._dml_data = obj_dml_data
5143 self ._n_obs = self ._dml_data .n_obs
@@ -108,6 +100,9 @@ def __init__(self, obj_dml_data, n_folds, n_rep, score, draw_sample_splitting):
108100 self ._smpls = None
109101 self ._smpls_cluster = None
110102 self ._n_obs_sample_splitting = self .n_obs
103+ self ._double_sample_splitting = double_sample_splitting
104+ if self ._double_sample_splitting :
105+ self ._smpls_inner = None
111106 if draw_sample_splitting :
112107 self .draw_sample_splitting ()
113108 self ._score_dim = (self ._dml_data .n_obs , self .n_rep , self ._dml_data .n_coefs )
@@ -359,6 +354,21 @@ def smpls(self):
359354 raise ValueError (err_msg )
360355 return self ._smpls
361356
@property
def smpls_inner(self):
    """
    The inner sample partition used with double sample splitting.

    Returns
    -------
    The stored inner sample splitting (``self._smpls_inner``), returned as is.

    Raises
    ------
    ValueError
        If double sample splitting is not enabled for this object, or if
        no inner sample splitting has been drawn or set yet.
    """
    # Check the flag first: ``_smpls_inner`` is only initialised when double
    # sample splitting is enabled, so it must not be touched otherwise.
    if not self._double_sample_splitting:
        raise ValueError("smpls_inner is only available for double sample splitting.")
    if self._smpls_inner is None:
        # The message names the real method, ``draw_sample_splitting`` (the
        # previous text contained a space instead of the underscore).
        err_msg = (
            "Sample splitting not specified. Either draw samples via .draw_sample_splitting() "
            "or set external samples via .set_sample_splitting()."
        )
        raise ValueError(err_msg)
    return self._smpls_inner
371+
362372 @property
363373 def smpls_cluster (self ):
364374 """
@@ -507,6 +517,18 @@ def summary(self):
507517 def __smpls (self ):
508518 return self ._smpls [self ._i_rep ]
509519
@property
def __smpls__inner(self):
    """
    Inner sample partition of the repetition selected by ``self._i_rep``.

    Raises
    ------
    ValueError
        If double sample splitting is not enabled for this object, or if
        no inner sample splitting has been drawn or set yet.
    """
    # NOTE(review): the name has a doubled underscore before ``inner``, unlike
    # the sibling ``__smpls_cluster``; kept unchanged because code outside
    # this block may reference it under this name.
    if not self._double_sample_splitting:
        raise ValueError("smpls_inner is only available for double sample splitting.")
    if self._smpls_inner is None:
        # The message names the real method, ``draw_sample_splitting`` (the
        # previous text contained a space instead of the underscore).
        err_msg = (
            "Sample splitting not specified. Either draw samples via .draw_sample_splitting() "
            "or set external samples via .set_sample_splitting()."
        )
        raise ValueError(err_msg)
    return self._smpls_inner[self._i_rep]
531+
510532 @property
511533 def __smpls_cluster (self ):
512534 return self ._smpls_cluster [self ._i_rep ]
@@ -1081,7 +1103,10 @@ def _initalize_fit(self, store_predictions, store_models):
10811103
10821104 def _fit_nuisance_and_score_elements (self , n_jobs_cv , store_predictions , external_predictions , store_models ):
10831105 ext_prediction_dict = _set_external_predictions (
1084- external_predictions , learners = self .params_names , treatment = self ._dml_data .d_cols [self ._i_treat ], i_rep = self ._i_rep
1106+ external_predictions ,
1107+ learners = self .params_names ,
1108+ treatment = self ._dml_data .d_cols [self ._i_treat ],
1109+ i_rep = self ._i_rep ,
10851110 )
10861111
10871112 # ml estimation of nuisance models and computation of score elements
@@ -1230,7 +1255,7 @@ def evaluate_learners(self, learners=None, metric=_rmse):
12301255 >>> def mae(y_true, y_pred):
12311256 ... subset = np.logical_not(np.isnan(y_true))
12321257 ... return mean_absolute_error(y_true[subset], y_pred[subset])
1233- >>> dml_irm_obj.evaluate_learners(metric=mae)
1258+ >>> dml_irm_obj.evaluate_learners(metric=mae) # doctest: +SKIP
12341259 {'ml_g0': array([[0.88173585]]), 'ml_g1': array([[0.83854057]]), 'ml_m': array([[0.35871235]])}
12351260 """
12361261 # if no learners are provided try to evaluate all learners
@@ -1249,12 +1274,19 @@ def evaluate_learners(self, learners=None, metric=_rmse):
12491274 for learner in learners :
12501275 for rep in range (self .n_rep ):
12511276 for coef_idx in range (self ._dml_data .n_coefs ):
1252- res = metric (
1253- y_pred = self .predictions [learner ][:, rep , coef_idx ].reshape (1 , - 1 ),
1254- y_true = self .nuisance_targets [learner ][:, rep , coef_idx ].reshape (1 , - 1 ),
1255- )
1256- if not np .isfinite (res ):
1257- raise ValueError (f"Evaluation from learner { str (learner )} is not finite." )
1277+ targets = self .nuisance_targets [learner ][:, rep , coef_idx ].reshape (1 , - 1 )
1278+
1279+ if np .all (np .isnan (targets )):
1280+ res = np .nan
1281+ else :
1282+ predictions = self .predictions [learner ][:, rep , coef_idx ].reshape (1 , - 1 )
1283+ res = metric (
1284+ y_pred = predictions ,
1285+ y_true = targets ,
1286+ )
1287+ if not np .isfinite (res ):
1288+ raise ValueError (f"Evaluation from learner { str (learner )} is not finite." )
1289+
12581290 dist [learner ][rep , coef_idx ] = res
12591291 return dist
12601292 else :
0 commit comments