Skip to content

Commit ec189d6

Browse files
committed
Enable internal bootstrap on all estimators
1 parent dbf6b5e commit ec189d6

8 files changed

+278
-48
lines changed

econml/cate_estimator.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,63 @@
55

66
import abc
77
import numpy as np
8+
from .bootstrap import BootstrapEstimator
9+
from .inference import BootstrapOptions
810
from .utilities import tensordot, ndim, reshape, shape
911

1012

11-
class BaseCateEstimator:
12-
"""Base class for all CATE estimators in this package."""
13+
class BaseCateEstimator(metaclass=abc.ABCMeta):
14+
"""
15+
Base class for all CATE estimators in this package.
16+
17+
Parameters
18+
----------
19+
inference: string, inference method, or None
20+
Method for performing inference. All estimators support 'bootstrap'
21+
(or an instance of `BootstrapOptions`), some support other methods as well.
22+
23+
"""
24+
25+
_inference_options = {'bootstrap': BootstrapOptions()}
26+
_bootstrap_whitelist = {'effect', 'marginal_effect'}
27+
28+
@abc.abstractmethod
29+
def __init__(self, inference):
30+
"""
31+
Initialize the estimator.
32+
33+
All subclass overrides should finish by calling this method on the superclass,
34+
since it enables bootstrapping.
35+
36+
"""
37+
if inference in self._inference_options:
38+
inference = self._inference_options[inference]
39+
if isinstance(inference, BootstrapOptions):
40+
# Note that fit (and other methods) check for the presence of a _bootstrap attribute
41+
# to determine whether to delegate to that object or not;
42+
# the clones wrapped inside the BootstrapEstimator will not have that attribute since
43+
# it's assigned *after* creating the estimator
44+
self._bootstrap = BootstrapEstimator(self, inference.n_bootstrap_samples, inference.n_jobs)
45+
self._inference = inference
46+
47+
def __getattr__(self, name):
48+
suffix = '_interval'
49+
if name.endswith(suffix) and name[: - len(suffix)] in self._bootstrap_whitelist:
50+
if hasattr(self, '_bootstrap'):
51+
return getattr(self._bootstrap, name)
52+
else:
53+
raise AttributeError('\'%s\' object does not support attribute \'%s\'; '
54+
'consider passing inference=\'bootstrap\' when initializing'
55+
% (type(self).__name__, name))
56+
else:
57+
raise AttributeError('\'%s\' object has no attribute \'%s\''
58+
% (type(self).__name__, name))
1359

1460
@abc.abstractmethod
15-
def fit(self, Y, T, X=None, W=None, Z=None):
61+
def _fit_impl(self, Y, T, X=None, W=None, Z=None):
62+
pass
63+
64+
def fit(self, *args, **kwargs):
1665
"""
1766
Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·).
1867
@@ -37,7 +86,9 @@ def fit(self, Y, T, X=None, W=None, Z=None):
3786
self
3887
3988
"""
40-
pass
89+
if hasattr(self, '_bootstrap'):
90+
self._bootstrap.fit(*args, **kwargs)
91+
return self._fit_impl(*args, **kwargs)
4192

4293
@abc.abstractmethod
4394
def effect(self, X=None, T0=0, T1=1):
@@ -92,7 +143,20 @@ def marginal_effect(self, T, X=None):
92143

93144

94145
class LinearCateEstimator(BaseCateEstimator):
95-
"""Base class for all CATE estimators with linear treatment effects in this package."""
146+
"""
147+
Base class for all CATE estimators with linear treatment effects in this package.
148+
149+
Parameters
150+
----------
151+
inference: string, inference method, or None
152+
Method for performing inference. All estimators support 'bootstrap'
153+
(or an instance of `BootstrapOptions`), some support other methods as well.
154+
155+
"""
156+
157+
@abc.abstractmethod
158+
def __init__(self, inference):
159+
super().__init__(inference=inference)
96160

97161
@abc.abstractmethod
98162
def const_marginal_effect(self, X=None):

econml/deepiv.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,13 +268,19 @@ class DeepIVEstimator(BaseCateEstimator):
268268
second_stage_options : dictionary, optional
269269
The keyword arguments to pass to Keras's `fit` method when training the second stage model.
270270
Defaults to `{"epochs": 100}`.
271+
272+
inference: string, inference method, or None
273+
Method for performing inference. This estimator supports 'bootstrap'
274+
(or an instance of `BootstrapOptions`).
275+
271276
"""
272277

273278
def __init__(self, n_components, m, h,
274279
n_samples, use_upper_bound_loss=False, n_gradient_samples=0,
275280
optimizer='adam',
276281
first_stage_options={"epochs": 100},
277-
second_stage_options={"epochs": 100}):
282+
second_stage_options={"epochs": 100},
283+
inference=None):
278284
self._n_components = n_components
279285
self._m = m
280286
self._h = h
@@ -284,8 +290,9 @@ def __init__(self, n_components, m, h,
284290
self._optimizer = optimizer
285291
self._first_stage_options = first_stage_options
286292
self._second_stage_options = second_stage_options
293+
super().__init__(inference=inference)
287294

288-
def fit(self, Y, T, X, Z):
295+
def _fit_impl(self, Y, T, X, Z):
289296
"""Estimate the counterfactual model from data.
290297
291298
That is, estimate functions τ(·, ·, ·), ∂τ(·, ·).

econml/dml.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,23 @@ class _RLearner(LinearCateEstimator):
5050
If RandomState instance, random_state is the random number generator;
5151
If None, the random number generator is the RandomState instance used
5252
by `np.random`.
53+
54+
inference: string, inference method, or None
55+
Method for performing inference. This estimator supports 'bootstrap'
56+
(or an instance of `BootstrapOptions`).
5357
"""
5458

5559
def __init__(self, model_y, model_t, model_final,
56-
discrete_treatment, n_splits, random_state):
60+
discrete_treatment, n_splits, random_state, inference):
5761
self._models_y = [clone(model_y, safe=False) for _ in range(n_splits)]
5862
self._models_t = [clone(model_t, safe=False) for _ in range(n_splits)]
5963
self._model_final = clone(model_final, safe=False)
6064
self._n_splits = n_splits
6165
self._discrete_treatment = discrete_treatment
6266
self._random_state = check_random_state(random_state)
67+
super().__init__(inference=inference)
6368

64-
def fit(self, Y, T, X=None, W=None):
69+
def _fit_impl(self, Y, T, X=None, W=None):
6570
if X is None:
6671
X = np.ones((shape(Y)[0], 1))
6772
if W is None:
@@ -203,6 +208,10 @@ class _DMLCateEstimatorBase(_RLearner):
203208
If RandomState instance, random_state is the random number generator;
204209
If None, the random number generator is the RandomState instance used
205210
by `np.random`.
211+
212+
inference: string, inference method, or None
213+
Method for performing inference. This estimator supports 'bootstrap'
214+
(or an instance of `BootstrapOptions`).
206215
"""
207216

208217
def __init__(self,
@@ -211,7 +220,8 @@ def __init__(self,
211220
sparseLinear,
212221
discrete_treatment,
213222
n_splits,
214-
random_state):
223+
random_state,
224+
inference):
215225

216226
class FirstStageWrapper:
217227
def __init__(self, model, is_Y):
@@ -274,7 +284,8 @@ def coef_(self):
274284
model_final=FinalWrapper(),
275285
discrete_treatment=discrete_treatment,
276286
n_splits=n_splits,
277-
random_state=random_state)
287+
random_state=random_state,
288+
inference=inference)
278289

279290
@property
280291
def coef_(self):
@@ -321,22 +332,28 @@ class DMLCateEstimator(_DMLCateEstimatorBase):
321332
If RandomState instance, random_state is the random number generator;
322333
If None, the random number generator is the RandomState instance used
323334
by `np.random`.
335+
336+
inference: string, inference method, or None
337+
Method for performing inference. This estimator supports 'bootstrap'
338+
(or an instance of `BootstrapOptions`).
324339
"""
325340

326341
def __init__(self,
327342
model_y, model_t, model_final=LinearRegression(fit_intercept=False),
328343
featurizer=PolynomialFeatures(degree=1, include_bias=True),
329344
discrete_treatment=False,
330345
n_splits=2,
331-
random_state=None):
346+
random_state=None,
347+
inference=None):
332348
super().__init__(model_y=model_y,
333349
model_t=model_t,
334350
model_final=model_final,
335351
featurizer=featurizer,
336352
sparseLinear=False,
337353
discrete_treatment=discrete_treatment,
338354
n_splits=n_splits,
339-
random_state=random_state)
355+
random_state=random_state,
356+
inference=inference)
340357

341358

342359
class SparseLinearDMLCateEstimator(_DMLCateEstimatorBase):
@@ -376,22 +393,28 @@ class SparseLinearDMLCateEstimator(_DMLCateEstimatorBase):
376393
If RandomState instance, random_state is the random number generator;
377394
If None, the random number generator is the RandomState instance used
378395
by `np.random`.
396+
397+
inference: string, inference method, or None
398+
Method for performing inference. This estimator supports 'bootstrap'
399+
(or an instance of `BootstrapOptions`).
379400
"""
380401

381402
def __init__(self,
382403
linear_model_y=LassoCV(), linear_model_t=LassoCV(), model_final=LinearRegression(fit_intercept=False),
383404
featurizer=PolynomialFeatures(degree=1, include_bias=True),
384405
discrete_treatment=False,
385406
n_splits=2,
386-
random_state=None):
407+
random_state=None,
408+
inference=None):
387409
super().__init__(model_y=linear_model_y,
388410
model_t=linear_model_t,
389411
model_final=model_final,
390412
featurizer=featurizer,
391413
sparseLinear=True,
392414
discrete_treatment=discrete_treatment,
393415
n_splits=n_splits,
394-
random_state=random_state)
416+
random_state=random_state,
417+
inference=inference)
395418

396419

397420
class KernelDMLCateEstimator(DMLCateEstimator):
@@ -421,15 +444,19 @@ class KernelDMLCateEstimator(DMLCateEstimator):
421444
n_splits: int, optional (default is 2)
422445
The number of splits to use when fitting the first-stage models.
423446
424-
random_state: int, RandomState instance or None, optional (default=None)
447+
random_state: int, RandomState instance or None, optional (default=None)
425448
If int, random_state is the seed used by the random number generator;
426449
If RandomState instance, random_state is the random number generator;
427450
If None, the random number generator is the RandomState instance used
428451
by `np.random`.
429-
"""
452+
453+
inference: string, inference method, or None
454+
Method for performing inference. This estimator supports 'bootstrap'
455+
(or an instance of `BootstrapOptions`).
456+
"""
430457

431458
def __init__(self, model_y, model_t, model_final=LinearRegression(fit_intercept=False),
432-
dim=20, bw=1.0, n_splits=2, random_state=None):
459+
dim=20, bw=1.0, n_splits=2, random_state=None, inference=None):
433460
class RandomFeatures(TransformerMixin):
434461
def fit(innerself, X):
435462
innerself.omegas = self._random_state.normal(0, 1 / bw, size=(shape(X)[1], dim))
@@ -440,4 +467,5 @@ def transform(innerself, X):
440467
return np.sqrt(2 / dim) * np.cos(np.matmul(X, innerself.omegas) + innerself.biases)
441468

442469
super().__init__(model_y=model_y, model_t=model_t, model_final=model_final,
443-
featurizer=RandomFeatures(), n_splits=n_splits, random_state=random_state)
470+
featurizer=RandomFeatures(), n_splits=n_splits, random_state=random_state,
471+
inference=inference)

econml/inference.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
"""Options for performing inference in estimators."""
5+
6+
7+
class BootstrapOptions:
8+
"""
9+
Wrapper storing bootstrap options.
10+
11+
This class can be used for inference with any CATE estimator.
12+
13+
Parameters
14+
----------
15+
n_bootstrap_samples : int, optional (default 100)
16+
How many draws to perform.
17+
18+
n_jobs : int, optional (default -1)
19+
The maximum number of concurrently running jobs, as in joblib.Parallel.
20+
21+
"""
22+
23+
def __init__(self, n_bootstrap_samples=100, n_jobs=-1):
24+
self._n_bootstrap_samples = n_bootstrap_samples
25+
self._n_jobs = n_jobs
26+
27+
@property
28+
def n_bootstrap_samples(self):
29+
"""Get how many draws to perform."""
30+
return self._n_bootstrap_samples
31+
32+
@property
33+
def n_jobs(self):
34+
"""Get the maximum number of concurrently running jobs, as in joblib.Parallel."""
35+
return self._n_jobs

0 commit comments

Comments
 (0)