 import pytest
 import numpy as np
+import copy
 
 import doubleml as dml
-from sklearn.linear_model import LinearRegression
+from doubleml import DoubleMLIRM, DoubleMLData
+from doubleml.datasets import make_irm_data
+from sklearn.linear_model import LinearRegression, LogisticRegression
 
 from ._utils_doubleml_sensitivity_manual import doubleml_sensitivity_manual, \
     doubleml_sensitivity_benchmark_manual
 
 
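+# Parametrized benchmarking sets: each run benchmarks a single covariate.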
+@pytest.fixture(scope="module", params=[["X1"], ["X2"], ["X3"]])
+def benchmarking_set(request):
+    return request.param
+
+
 @pytest.fixture(scope='module',
                 params=[1, 3])
 def n_rep(request):
@@ -99,3 +107,56 @@ def test_dml_sensitivity_benchmark(dml_sensitivity_multitreat_fixture):
     assert all(dml_sensitivity_multitreat_fixture['benchmark'].index ==
                dml_sensitivity_multitreat_fixture['d_cols'])
     assert dml_sensitivity_multitreat_fixture['benchmark'].equals(dml_sensitivity_multitreat_fixture['benchmark_manual'])
+
+
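+# Compare sensitivity_benchmark with its default internal refit of the short model
+# against the same benchmark computed from externally supplied nuisance predictions.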
+@pytest.fixture(scope="module")
+def test_dml_benchmark_fixture(benchmarking_set, n_rep):
+    random_state = 42
+    x, y, d = make_irm_data(n_obs=50, dim_x=5, theta=0, return_type="np.array")
+
+    classifier_class = LogisticRegression
+    regressor_class = LinearRegression
+
+    np.random.seed(3141)
+    dml_data = DoubleMLData.from_arrays(x=x, y=y, d=d)
+    x_list_long = copy.deepcopy(dml_data.x_cols)
+    dml_int = DoubleMLIRM(dml_data,
+                          ml_m=classifier_class(random_state=random_state),
+                          ml_g=regressor_class(),
+                          n_folds=2,
+                          n_rep=n_rep)
+    dml_int.fit(store_predictions=True)
+    dml_int.sensitivity_analysis()
+    dml_ext = copy.deepcopy(dml_int)
+    df_bm = dml_int.sensitivity_benchmark(benchmarking_set=benchmarking_set)
+
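+    # Fit the short model without the benchmarked covariates to obtain its
+    # cross-fitted nuisance predictions.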
+    np.random.seed(3141)
+    dml_data_short = DoubleMLData.from_arrays(x=x, y=y, d=d)
+    dml_data_short.x_cols = [x for x in x_list_long if x not in benchmarking_set]
+    dml_short = DoubleMLIRM(dml_data_short,
+                            ml_m=classifier_class(random_state=random_state),
+                            ml_g=regressor_class(),
+                            n_folds=2,
+                            n_rep=n_rep)
+    dml_short.fit(store_predictions=True)
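+    # Reuse the short model's predictions as external predictions for treatment "d".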
+    fit_args = {"external_predictions": {"d": {"ml_m": dml_short.predictions["ml_m"][:, :, 0],
+                                               "ml_g0": dml_short.predictions["ml_g0"][:, :, 0],
+                                               "ml_g1": dml_short.predictions["ml_g1"][:, :, 0],
+                                               }
+                                         },
+                }
+    dml_ext.sensitivity_analysis()
+    df_bm_ext = dml_ext.sensitivity_benchmark(benchmarking_set=benchmarking_set, fit_args=fit_args)
+
+    res_dict = {"default_benchmark": df_bm,
+                "external_benchmark": df_bm_ext}
+
+    return res_dict
+
+
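+# The default benchmark and the external-predictions benchmark should agree numerically.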
+@pytest.mark.ci
+def test_dml_sensitivity_external_predictions(test_dml_benchmark_fixture):
+    assert np.allclose(test_dml_benchmark_fixture["default_benchmark"],
+                       test_dml_benchmark_fixture["external_benchmark"],
+                       rtol=1e-9,
+                       atol=1e-4)