Skip to content

Commit e268dca

Browse files
authored
Merge pull request #110 from DoubleML/m-check-smpls
improve exception handling for externally provided sample splitting
2 parents a78b745 + eb3ab2e commit e268dca

File tree

3 files changed

+84
-16
lines changed

3 files changed

+84
-16
lines changed

doubleml/_helper.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,27 +34,32 @@ def _check_is_partition(smpls, n_obs):
3434
return True
3535

3636

37-
def _check_all_smpls(all_smpls, n_obs):
37+
def _check_all_smpls(all_smpls, n_obs, check_intersect=False):
3838
all_smpls_checked = list()
3939
for smpl in all_smpls:
40-
this_smpl_checked = list()
41-
for tpl in smpl:
42-
this_smpl_checked.append(_check_smpl_split_tpl(tpl, n_obs))
43-
all_smpls_checked.append(this_smpl_checked)
40+
all_smpls_checked.append(_check_smpl_split(smpl, n_obs, check_intersect))
4441
return all_smpls_checked
4542

4643

47-
def _check_smpl_split_tpl(smpl, n_obs):
48-
train_index = np.sort(np.array(smpl[0]))
49-
test_index = np.sort(np.array(smpl[1]))
44+
def _check_smpl_split(smpl, n_obs, check_intersect=False):
45+
smpl_checked = list()
46+
for tpl in smpl:
47+
smpl_checked.append(_check_smpl_split_tpl(tpl, n_obs, check_intersect))
48+
return smpl_checked
49+
50+
51+
def _check_smpl_split_tpl(tpl, n_obs, check_intersect=False):
52+
train_index = np.sort(np.array(tpl[0]))
53+
test_index = np.sort(np.array(tpl[1]))
5054

5155
if not issubclass(train_index.dtype.type, np.integer):
5256
raise TypeError('Invalid sample split. Train indices must be of type integer.')
5357
if not issubclass(test_index.dtype.type, np.integer):
5458
raise TypeError('Invalid sample split. Test indices must be of type integer.')
5559

56-
if set(train_index) & set(test_index):
57-
raise ValueError('Invalid sample split. Intersection of train and test indices is not empty.')
60+
if check_intersect:
61+
if set(train_index) & set(test_index):
62+
raise ValueError('Invalid sample split. Intersection of train and test indices is not empty.')
5863

5964
if len(np.unique(train_index)) != len(train_index):
6065
raise ValueError('Invalid sample split. Train indices contain non-unique entries.')

doubleml/double_ml.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from .double_ml_data import DoubleMLData
1414
from .double_ml_resampling import DoubleMLResampling
15-
from ._helper import _check_is_partition, _check_all_smpls, _draw_weights
15+
from ._helper import _check_is_partition, _check_all_smpls, _check_smpl_split, _check_smpl_split_tpl, _draw_weights
1616

1717

1818
class DoubleML(ABC):
@@ -1010,6 +1010,7 @@ def set_sample_splitting(self, all_smpls):
10101010
if not len(all_smpls) == 2:
10111011
raise ValueError('Invalid partition provided. '
10121012
'Tuple for train_ind and test_ind must consist of exactly two elements.')
1013+
all_smpls = _check_smpl_split_tpl(all_smpls, self._dml_data.n_obs)
10131014
if (_check_is_partition([all_smpls], self._dml_data.n_obs) &
10141015
_check_is_partition([(all_smpls[1], all_smpls[0])], self._dml_data.n_obs)):
10151016
self._n_rep = 1
@@ -1020,7 +1021,7 @@ def set_sample_splitting(self, all_smpls):
10201021
self._n_rep = 1
10211022
self._n_folds = 2
10221023
self._apply_cross_fitting = False
1023-
self._smpls = _check_all_smpls([[all_smpls]], self._dml_data.n_obs)
1024+
self._smpls = _check_all_smpls([[all_smpls]], self._dml_data.n_obs, check_intersect=True)
10241025
else:
10251026
if not isinstance(all_smpls, list):
10261027
raise TypeError('all_smpls must be of list or tuple type. '
@@ -1031,6 +1032,7 @@ def set_sample_splitting(self, all_smpls):
10311032
raise ValueError('Invalid partition provided. '
10321033
'All tuples for train_ind and test_ind must consist of exactly two elements.')
10331034
self._n_rep = 1
1035+
all_smpls = _check_smpl_split(all_smpls, self._dml_data.n_obs)
10341036
if _check_is_partition(all_smpls, self._dml_data.n_obs):
10351037
if ((len(all_smpls) == 1) &
10361038
_check_is_partition([(all_smpls[0][1], all_smpls[0][0])], self._dml_data.n_obs)):
@@ -1040,14 +1042,14 @@ def set_sample_splitting(self, all_smpls):
10401042
else:
10411043
self._n_folds = len(all_smpls)
10421044
self._apply_cross_fitting = True
1043-
self._smpls = _check_all_smpls([all_smpls], self._dml_data.n_obs)
1045+
self._smpls = _check_all_smpls([all_smpls], self._dml_data.n_obs, check_intersect=True)
10441046
else:
10451047
if not len(all_smpls) == 1:
10461048
raise ValueError('Invalid partition provided. '
10471049
'Tuples for more than one fold provided that don\'t form a partition.')
10481050
self._n_folds = 2
10491051
self._apply_cross_fitting = False
1050-
self._smpls = _check_all_smpls([all_smpls], self._dml_data.n_obs)
1052+
self._smpls = _check_all_smpls([all_smpls], self._dml_data.n_obs, check_intersect=True)
10511053
else:
10521054
all_list = all([isinstance(smpl, list) for smpl in all_smpls])
10531055
if not all_list:
@@ -1065,6 +1067,7 @@ def set_sample_splitting(self, all_smpls):
10651067
if not np.all(n_folds_each_smpl == n_folds_each_smpl[0]):
10661068
raise ValueError('Invalid partition provided. '
10671069
'Different number of folds for repeated sample splitting.')
1070+
all_smpls = _check_all_smpls(all_smpls, self._dml_data.n_obs)
10681071
smpls_are_partitions = [_check_is_partition(smpl, self._dml_data.n_obs) for smpl in all_smpls]
10691072

10701073
if all(smpls_are_partitions):
@@ -1078,7 +1081,7 @@ def set_sample_splitting(self, all_smpls):
10781081
self._n_rep = len(all_smpls)
10791082
self._n_folds = n_folds_each_smpl[0]
10801083
self._apply_cross_fitting = True
1081-
self._smpls = _check_all_smpls(all_smpls, self._dml_data.n_obs)
1084+
self._smpls = _check_all_smpls(all_smpls, self._dml_data.n_obs, check_intersect=True)
10821085
else:
10831086
if not n_folds_each_smpl[0] == 1:
10841087
raise ValueError('Invalid partition provided. '
@@ -1087,7 +1090,7 @@ def set_sample_splitting(self, all_smpls):
10871090
self._n_rep = len(all_smpls)
10881091
self._n_folds = 2
10891092
self._apply_cross_fitting = False
1090-
self._smpls = _check_all_smpls(all_smpls, self._dml_data.n_obs)
1093+
self._smpls = _check_all_smpls(all_smpls, self._dml_data.n_obs, check_intersect=True)
10911094

10921095
self._psi, self._psi_a, self._psi_b, \
10931096
self._coef, self._se, self._all_coef, self._all_se, self._all_dml1_coef = self._initialize_arrays()

doubleml/tests/test_doubleml_set_sample_splitting.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,63 @@ def test_doubleml_draw_vs_set():
215215
n_folds=2, n_rep=4, apply_cross_fitting=False)
216216
dml_plr_set.set_sample_splitting(dml_plr_drawn.smpls)
217217
_assert_resampling_pars(dml_plr_drawn, dml_plr_set)
218+
219+
220+
@pytest.mark.ci
221+
def test_doubleml_set_sample_splitting_invalid_sets():
222+
# sample splitting with two folds and repeated cross-fitting with n_rep = 2
223+
smpls = [[([0, 1.2, 2, 3, 4], [5, 6, 7, 8, 9]),
224+
([5, 6, 7, 8, 9], [0, 1, 2, 3, 4])],
225+
[([0, 2, 4, 6, 8], [1, 3, 5, 7, 9]),
226+
([1, 3, 5, 7, 9], [0, 2, 4, 6, 8])]]
227+
msg = 'Invalid sample split. Train indices must be of type integer.'
228+
with pytest.raises(TypeError, match=msg):
229+
dml_plr.set_sample_splitting(smpls)
230+
231+
smpls = [[([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]),
232+
([5, 6, 7, 8, 9], [0, 1, 2, 3, 4])],
233+
[([0, 2, 4, 6, 8], [1, 3.5, 5, 7, 9]),
234+
([1, 3, 5, 7, 9], [0, 2, 4, 6, 8])]]
235+
msg = 'Invalid sample split. Test indices must be of type integer.'
236+
with pytest.raises(TypeError, match=msg):
237+
dml_plr.set_sample_splitting(smpls)
238+
239+
smpls = [[([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]),
240+
([5, 6, 7, 8, 9], [0, 1, 2, 3, 4])],
241+
[([0, 2, 3, 4, 6, 8], [1, 3, 5, 7, 9]),
242+
([1, 5, 7, 9], [0, 2, 4, 6, 8])]]
243+
msg = 'Invalid sample split. Intersection of train and test indices is not empty.'
244+
with pytest.raises(ValueError, match=msg):
245+
dml_plr.set_sample_splitting(smpls)
246+
247+
smpls = [[([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]),
248+
([5, 6, 7, 7, 8, 9], [0, 1, 2, 3, 4])],
249+
[([0, 2, 4, 4, 6, 8], [1, 3, 5, 7, 9]),
250+
([1, 3, 5, 7, 9], [0, 2, 4, 6, 8])]]
251+
msg = 'Invalid sample split. Train indices contain non-unique entries.'
252+
with pytest.raises(ValueError, match=msg):
253+
dml_plr.set_sample_splitting(smpls)
254+
255+
smpls = [[([0, 1, 2, 3, 4], [5, 5, 6, 7, 8, 9]),
256+
([5, 6, 7, 8, 9], [0, 1, 2, 3, 4])],
257+
[([0, 2, 4, 6, 8], [1, 3, 5, 7, 9]),
258+
([1, 3, 5, 7, 9], [0, 2, 4, 6, 8])]]
259+
msg = 'Invalid sample split. Test indices contain non-unique entries.'
260+
with pytest.raises(ValueError, match=msg):
261+
dml_plr.set_sample_splitting(smpls)
262+
263+
smpls = [[([0, 1, 2, 3, 20], [5, 6, 7, 8, 9]),
264+
([5, 6, 7, 8, 9], [0, 1, 2, 3, 4])],
265+
[([0, 2, 4, 6, 8], [1, 3, 5, 7, 9]),
266+
([1, 3, 5, 7, 9], [0, 2, 4, 6, 8])]]
267+
msg = r'Invalid sample split. Train indices must be in \[0, n_obs\).'
268+
with pytest.raises(ValueError, match=msg):
269+
dml_plr.set_sample_splitting(smpls)
270+
271+
smpls = [[([0, 1, 2, 3, 4], [5, -6, 7, 8, 9]),
272+
([5, 6, 7, 8, 9], [0, 1, 2, 3, 4])],
273+
[([0, 2, 4, 6, 8], [1, 3, 5, 7, 9]),
274+
([1, 3, 5, 7, 9], [0, 2, 4, 6, 8])]]
275+
msg = r'Invalid sample split. Test indices must be in \[0, n_obs\).'
276+
with pytest.raises(ValueError, match=msg):
277+
dml_plr.set_sample_splitting(smpls)

0 commit comments

Comments
 (0)