Skip to content

Commit ba9cc57

Browse files
authored
Merge pull request #239 from DoubleML/s-ext-pred-benchmark
Enable external predictions for short model in benchmark
2 parents: 8bc3d94 + 3769f81; commit ba9cc57

File tree

3 files changed, 84 insertions(+), 4 deletions(-)

doubleml/double_ml.py

+8 -2
Original file line number | Diff line number | Diff line change
@@ -1735,7 +1735,7 @@ def sensitivity_plot(self, idx_treatment=0, value='theta', include_scenario=True
17351735
fill=fill)
17361736
return fig
17371737

1738-
def sensitivity_benchmark(self, benchmarking_set):
1738+
def sensitivity_benchmark(self, benchmarking_set, fit_args=None):
17391739
"""
17401740
Computes a benchmark for a given set of features.
17411741
Returns a DataFrame containing the corresponding values for cf_y, cf_d, rho and the change in estimates.
@@ -1757,12 +1757,18 @@ def sensitivity_benchmark(self, benchmarking_set):
17571757
if not set(benchmarking_set) <= set(x_list_long):
17581758
raise ValueError(f"benchmarking_set must be a subset of features {str(self._dml_data.x_cols)}. "
17591759
f'{str(benchmarking_set)} was passed.')
1760+
if fit_args is not None and not isinstance(fit_args, dict):
1761+
raise TypeError('fit_args must be a dict. '
1762+
f'{str(fit_args)} of type {type(fit_args)} was passed.')
17601763

17611764
# refit short form of the model
17621765
x_list_short = [x for x in x_list_long if x not in benchmarking_set]
17631766
dml_short = copy.deepcopy(self)
17641767
dml_short._dml_data.x_cols = x_list_short
1765-
dml_short.fit()
1768+
if fit_args is not None:
1769+
dml_short.fit(**fit_args)
1770+
else:
1771+
dml_short.fit()
17661772

17671773
benchmark_dict = gain_statistics(dml_long=self, dml_short=dml_short)
17681774
df_benchmark = pd.DataFrame(benchmark_dict, index=self._dml_data.d_cols)

doubleml/tests/test_exceptions_ext_preds.py

+14 -1
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
11
import pytest
2-
from doubleml import DoubleMLCVAR, DoubleMLQTE, DoubleMLData
2+
from doubleml import DoubleMLCVAR, DoubleMLQTE, DoubleMLIRM, DoubleMLData
33
from doubleml.datasets import make_irm_data
44
from doubleml.utils import DMLDummyRegressor, DMLDummyClassifier
55

6+
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
7+
68
df_irm = make_irm_data(n_obs=10, dim_x=2, theta=0.5, return_type="DataFrame")
79
ext_predictions = {"d": {}}
810

@@ -21,3 +23,14 @@ def test_qte_external_prediction_exception():
2123
with pytest.raises(NotImplementedError, match=msg):
2224
qte = DoubleMLQTE(DoubleMLData(df_irm, "y", "d"), DMLDummyClassifier(), DMLDummyClassifier())
2325
qte.fit(external_predictions=ext_predictions)
26+
27+
28+
@pytest.mark.ci
29+
def test_sensitivity_benchmark_external_prediction_exception():
30+
msg = "fit_args must be a dict. "
31+
with pytest.raises(TypeError, match=msg):
32+
fit_args = []
33+
irm = DoubleMLIRM(DoubleMLData(df_irm, "y", "d"), RandomForestRegressor(), RandomForestClassifier())
34+
irm.fit()
35+
irm.sensitivity_analysis()
36+
irm.sensitivity_benchmark(benchmarking_set=["X1"], fit_args=fit_args)

doubleml/tests/test_sensitivity.py

+62 -1
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,21 @@
11
import pytest
22
import numpy as np
3+
import copy
34

45
import doubleml as dml
5-
from sklearn.linear_model import LinearRegression
6+
from doubleml import DoubleMLIRM, DoubleMLData
7+
from doubleml.datasets import make_irm_data
8+
from sklearn.linear_model import LinearRegression, LogisticRegression
69

710
from ._utils_doubleml_sensitivity_manual import doubleml_sensitivity_manual, \
811
doubleml_sensitivity_benchmark_manual
912

1013

14+
@pytest.fixture(scope="module", params=[["X1"], ["X2"], ["X3"]])
15+
def benchmarking_set(request):
16+
return request.param
17+
18+
1119
@pytest.fixture(scope='module',
1220
params=[1, 3])
1321
def n_rep(request):
@@ -99,3 +107,56 @@ def test_dml_sensitivity_benchmark(dml_sensitivity_multitreat_fixture):
99107
assert all(dml_sensitivity_multitreat_fixture['benchmark'].index ==
100108
dml_sensitivity_multitreat_fixture['d_cols'])
101109
assert dml_sensitivity_multitreat_fixture['benchmark'].equals(dml_sensitivity_multitreat_fixture['benchmark_manual'])
110+
111+
112+
@pytest.fixture(scope="module")
113+
def test_dml_benchmark_fixture(benchmarking_set, n_rep):
114+
random_state = 42
115+
x, y, d = make_irm_data(n_obs=50, dim_x=5, theta=0, return_type="np.array")
116+
117+
classifier_class = LogisticRegression
118+
regressor_class = LinearRegression
119+
120+
np.random.seed(3141)
121+
dml_data = DoubleMLData.from_arrays(x=x, y=y, d=d)
122+
x_list_long = copy.deepcopy(dml_data.x_cols)
123+
dml_int = DoubleMLIRM(dml_data,
124+
ml_m=classifier_class(random_state=random_state),
125+
ml_g=regressor_class(),
126+
n_folds=2,
127+
n_rep=n_rep)
128+
dml_int.fit(store_predictions=True)
129+
dml_int.sensitivity_analysis()
130+
dml_ext = copy.deepcopy(dml_int)
131+
df_bm = dml_int.sensitivity_benchmark(benchmarking_set=benchmarking_set)
132+
133+
np.random.seed(3141)
134+
dml_data_short = DoubleMLData.from_arrays(x=x, y=y, d=d)
135+
dml_data_short.x_cols = [x for x in x_list_long if x not in benchmarking_set]
136+
dml_short = DoubleMLIRM(dml_data_short,
137+
ml_m=classifier_class(random_state=random_state),
138+
ml_g=regressor_class(),
139+
n_folds=2,
140+
n_rep=n_rep)
141+
dml_short.fit(store_predictions=True)
142+
fit_args = {"external_predictions": {"d": {"ml_m": dml_short.predictions["ml_m"][:, :, 0],
143+
"ml_g0": dml_short.predictions["ml_g0"][:, :, 0],
144+
"ml_g1": dml_short.predictions["ml_g1"][:, :, 0],
145+
}
146+
},
147+
}
148+
dml_ext.sensitivity_analysis()
149+
df_bm_ext = dml_ext.sensitivity_benchmark(benchmarking_set=benchmarking_set, fit_args=fit_args)
150+
151+
res_dict = {"default_benchmark": df_bm,
152+
"external_benchmark": df_bm_ext}
153+
154+
return res_dict
155+
156+
157+
@pytest.mark.ci
158+
def test_dml_sensitivity_external_predictions(test_dml_benchmark_fixture):
159+
assert np.allclose(test_dml_benchmark_fixture["default_benchmark"],
160+
test_dml_benchmark_fixture["external_benchmark"],
161+
rtol=1e-9,
162+
atol=1e-4)

0 commit comments

Comments (0)