Skip to content

Commit a3868df

Browse files
committed
rename nuisance parts; deactivate IV-type score; #7
1 parent 1f97643 commit a3868df

File tree

4 files changed

+33
-31
lines changed

4 files changed

+33
-31
lines changed

doubleml_serverless/double_ml_pliv_aws_lambda.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,12 @@ def _ml_nuisance_aws_lambda(self, cv_params):
4747

4848
payload = self._dml_data.get_payload()
4949

50-
payload_ml_g = payload.copy()
50+
payload_ml_l = payload.copy()
5151
payload_ml_m = payload.copy()
5252
payload_ml_r = payload.copy()
5353

54-
_attach_learner(payload_ml_g,
55-
'ml_g', self.learner['ml_g'],
54+
_attach_learner(payload_ml_l,
55+
'ml_l', self.learner['ml_l'],
5656
self._dml_data.y_col, self._dml_data.x_cols)
5757

5858
_attach_learner(payload_ml_m,
@@ -63,7 +63,7 @@ def _ml_nuisance_aws_lambda(self, cv_params):
6363
'ml_r', self.learner['ml_r'],
6464
self._dml_data.d_cols[0], self._dml_data.x_cols)
6565

66-
payloads = _attach_smpls([payload_ml_g, payload_ml_m, payload_ml_r],
66+
payloads = _attach_smpls([payload_ml_l, payload_ml_m, payload_ml_r],
6767
[self.smpls, self.smpls, self.smpls],
6868
self.n_folds,
6969
self.n_rep,
@@ -80,9 +80,10 @@ def _ml_nuisance_aws_lambda(self, cv_params):
8080
# compute score elements
8181
self._psi_a[:, i_rep, self._i_treat], self._psi_b[:, i_rep, self._i_treat] = \
8282
self._score_elements(y, z, d,
83-
preds['ml_g'][:, i_rep],
83+
preds['ml_l'][:, i_rep],
8484
preds['ml_m'][:, i_rep],
8585
preds['ml_r'][:, i_rep],
86+
None,
8687
self.smpls[i_rep])
8788

8889
return

doubleml_serverless/double_ml_plr_aws_lambda.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,18 @@ def _ml_nuisance_aws_lambda(self, cv_params):
4141

4242
payload = self._dml_data.get_payload()
4343

44-
payload_ml_g = payload.copy()
44+
payload_ml_l = payload.copy()
4545
payload_ml_m = payload.copy()
4646

47-
_attach_learner(payload_ml_g,
48-
'ml_g', self.learner['ml_g'],
47+
_attach_learner(payload_ml_l,
48+
'ml_l', self.learner['ml_l'],
4949
self._dml_data.y_col, self._dml_data.x_cols)
5050

5151
_attach_learner(payload_ml_m,
5252
'ml_m', self.learner['ml_m'],
5353
self._dml_data.d_cols[0], self._dml_data.x_cols)
5454

55-
payloads = _attach_smpls([payload_ml_g, payload_ml_m],
55+
payloads = _attach_smpls([payload_ml_l, payload_ml_m],
5656
[self.smpls, self.smpls],
5757
self.n_folds,
5858
self.n_rep,
@@ -69,8 +69,9 @@ def _ml_nuisance_aws_lambda(self, cv_params):
6969
# compute score elements
7070
self._psi_a[:, i_rep, self._i_treat], self._psi_b[:, i_rep, self._i_treat] = \
7171
self._score_elements(y, d,
72-
preds['ml_g'][:, i_rep],
72+
preds['ml_l'][:, i_rep],
7373
preds['ml_m'][:, i_rep],
74+
None,
7475
self.smpls[i_rep])
7576

7677
return

doubleml_serverless/tests/test_pliv.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,16 @@ def dml_pliv_fixture(generate_data_pliv, idx, learner, score, dml_procedure):
5858
x_cols = data.columns[data.columns.str.startswith('X')].tolist()
5959

6060
# Set machine learning methods for m & g
61-
ml_g = clone(learner)
61+
ml_l = clone(learner)
6262
ml_m = clone(learner)
6363
ml_r = clone(learner)
6464

6565
np.random.seed(3141)
6666
dml_data_json = dml_lambda.DoubleMLDataJson(data, 'y', ['d'], x_cols, 'Z1')
6767
dml_pliv_lambda = DoubleMLPLIVServerlessLocal('local', 'local',
6868
dml_data_json,
69-
ml_g, ml_m, ml_r,
70-
n_folds,
69+
ml_l, ml_m, ml_r,
70+
n_folds=n_folds,
7171
score=score,
7272
dml_procedure=dml_procedure)
7373

@@ -76,8 +76,8 @@ def dml_pliv_fixture(generate_data_pliv, idx, learner, score, dml_procedure):
7676
np.random.seed(3141)
7777
dml_data = dml.DoubleMLData(data, 'y', ['d'], x_cols, 'Z1')
7878
dml_pliv = dml.DoubleMLPLIV(dml_data,
79-
ml_g, ml_m, ml_r,
80-
n_folds,
79+
ml_l, ml_m, ml_r,
80+
n_folds=n_folds,
8181
score=score,
8282
dml_procedure=dml_procedure)
8383

@@ -140,7 +140,7 @@ def dml_pliv_scaling_fixture(generate_data_pliv, idx, learner, score, dml_proced
140140
x_cols = data.columns[data.columns.str.startswith('X')].tolist()
141141

142142
# Set machine learning methods for m & g
143-
ml_g = clone(learner)
143+
ml_l = clone(learner)
144144
ml_m = clone(learner)
145145
ml_r = clone(learner)
146146

@@ -149,8 +149,8 @@ def dml_pliv_scaling_fixture(generate_data_pliv, idx, learner, score, dml_proced
149149
np.random.seed(3141)
150150
dml_pliv_folds = DoubleMLPLIVServerlessLocal('local', 'local',
151151
dml_data_json,
152-
ml_g, ml_m, ml_r,
153-
n_folds,
152+
ml_l, ml_m, ml_r,
153+
n_folds=n_folds,
154154
score=score,
155155
dml_procedure=dml_procedure)
156156

@@ -159,8 +159,8 @@ def dml_pliv_scaling_fixture(generate_data_pliv, idx, learner, score, dml_proced
159159
np.random.seed(3141)
160160
dml_pliv_reps = DoubleMLPLIVServerlessLocal('local', 'local',
161161
dml_data_json,
162-
ml_g, ml_m, ml_r,
163-
n_folds,
162+
ml_l, ml_m, ml_r,
163+
n_folds=n_folds,
164164
score=score,
165165
dml_procedure=dml_procedure)
166166

doubleml_serverless/tests/test_plr.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def learner(request):
3232

3333

3434
@pytest.fixture(scope='module',
35-
params=['IV-type', 'partialling out'])
35+
params=['partialling out'])
3636
def score(request):
3737
return request.param
3838

@@ -58,15 +58,15 @@ def dml_plr_fixture(generate_data_plr, idx, learner, score, dml_procedure):
5858
x_cols = data.columns[data.columns.str.startswith('X')].tolist()
5959

6060
# Set machine learning methods for m & g
61-
ml_g = clone(learner)
61+
ml_l = clone(learner)
6262
ml_m = clone(learner)
6363

6464
np.random.seed(3141)
6565
dml_data_json = dml_lambda.DoubleMLDataJson(data, 'y', ['d'], x_cols)
6666
dml_plr_lambda = DoubleMLPLRServerlessLocal('local', 'local',
6767
dml_data_json,
68-
ml_g, ml_m,
69-
n_folds,
68+
ml_l, ml_m,
69+
n_folds=n_folds,
7070
score=score,
7171
dml_procedure=dml_procedure)
7272

@@ -75,8 +75,8 @@ def dml_plr_fixture(generate_data_plr, idx, learner, score, dml_procedure):
7575
np.random.seed(3141)
7676
dml_data = dml.DoubleMLData(data, 'y', ['d'], x_cols)
7777
dml_plr = dml.DoubleMLPLR(dml_data,
78-
ml_g, ml_m,
79-
n_folds,
78+
ml_l, ml_m,
79+
n_folds=n_folds,
8080
score=score,
8181
dml_procedure=dml_procedure)
8282

@@ -139,16 +139,16 @@ def dml_plr_scaling_fixture(generate_data_plr, idx, learner, score, dml_procedur
139139
x_cols = data.columns[data.columns.str.startswith('X')].tolist()
140140

141141
# Set machine learning methods for m & g
142-
ml_g = clone(learner)
142+
ml_l = clone(learner)
143143
ml_m = clone(learner)
144144

145145
dml_data_json = dml_lambda.DoubleMLDataJson(data, 'y', ['d'], x_cols)
146146

147147
np.random.seed(3141)
148148
dml_plr_folds = DoubleMLPLRServerlessLocal('local', 'local',
149149
dml_data_json,
150-
ml_g, ml_m,
151-
n_folds,
150+
ml_l, ml_m,
151+
n_folds=n_folds,
152152
score=score,
153153
dml_procedure=dml_procedure)
154154

@@ -157,8 +157,8 @@ def dml_plr_scaling_fixture(generate_data_plr, idx, learner, score, dml_procedur
157157
np.random.seed(3141)
158158
dml_plr_reps = DoubleMLPLRServerlessLocal('local', 'local',
159159
dml_data_json,
160-
ml_g, ml_m,
161-
n_folds,
160+
ml_l, ml_m,
161+
n_folds=n_folds,
162162
score=score,
163163
dml_procedure=dml_procedure)
164164

0 commit comments

Comments
 (0)