@@ -50,18 +50,23 @@ class _RLearner(LinearCateEstimator):
50
50
If RandomState instance, random_state is the random number generator;
51
51
If None, the random number generator is the RandomState instance used
52
52
by `np.random`.
53
+
54
+ inference: string, inference method, or None
55
+ Method for performing inference. This estimator supports 'bootstrap'
56
+ (or an instance of `BootstrapOptions`)
53
57
"""
54
58
55
59
def __init__ (self , model_y , model_t , model_final ,
56
- discrete_treatment , n_splits , random_state ):
60
+ discrete_treatment , n_splits , random_state , inference ):
57
61
self ._models_y = [clone (model_y , safe = False ) for _ in range (n_splits )]
58
62
self ._models_t = [clone (model_t , safe = False ) for _ in range (n_splits )]
59
63
self ._model_final = clone (model_final , safe = False )
60
64
self ._n_splits = n_splits
61
65
self ._discrete_treatment = discrete_treatment
62
66
self ._random_state = check_random_state (random_state )
67
+ super ().__init__ (inference = inference )
63
68
64
- def fit (self , Y , T , X = None , W = None ):
69
+ def _fit_impl (self , Y , T , X = None , W = None ):
65
70
if X is None :
66
71
X = np .ones ((shape (Y )[0 ], 1 ))
67
72
if W is None :
@@ -203,6 +208,10 @@ class _DMLCateEstimatorBase(_RLearner):
203
208
If RandomState instance, random_state is the random number generator;
204
209
If None, the random number generator is the RandomState instance used
205
210
by `np.random`.
211
+
212
+ inference: string, inference method, or None
213
+ Method for performing inference. This estimator supports 'bootstrap'
214
+ (or an instance of `BootstrapOptions`).
206
215
"""
207
216
208
217
def __init__ (self ,
@@ -211,7 +220,8 @@ def __init__(self,
211
220
sparseLinear ,
212
221
discrete_treatment ,
213
222
n_splits ,
214
- random_state ):
223
+ random_state ,
224
+ inference ):
215
225
216
226
class FirstStageWrapper :
217
227
def __init__ (self , model , is_Y ):
@@ -274,7 +284,8 @@ def coef_(self):
274
284
model_final = FinalWrapper (),
275
285
discrete_treatment = discrete_treatment ,
276
286
n_splits = n_splits ,
277
- random_state = random_state )
287
+ random_state = random_state ,
288
+ inference = inference )
278
289
279
290
@property
280
291
def coef_ (self ):
@@ -321,22 +332,28 @@ class DMLCateEstimator(_DMLCateEstimatorBase):
321
332
If RandomState instance, random_state is the random number generator;
322
333
If None, the random number generator is the RandomState instance used
323
334
by `np.random`.
335
+
336
+ inference: string, inference method, or None
337
+ Method for performing inference. This estimator supports 'bootstrap'
338
+ (or an instance of `BootstrapOptions`)
324
339
"""
325
340
326
341
def __init__ (self ,
327
342
model_y , model_t , model_final = LinearRegression (fit_intercept = False ),
328
343
featurizer = PolynomialFeatures (degree = 1 , include_bias = True ),
329
344
discrete_treatment = False ,
330
345
n_splits = 2 ,
331
- random_state = None ):
346
+ random_state = None ,
347
+ inference = None ):
332
348
super ().__init__ (model_y = model_y ,
333
349
model_t = model_t ,
334
350
model_final = model_final ,
335
351
featurizer = featurizer ,
336
352
sparseLinear = False ,
337
353
discrete_treatment = discrete_treatment ,
338
354
n_splits = n_splits ,
339
- random_state = random_state )
355
+ random_state = random_state ,
356
+ inference = inference )
340
357
341
358
342
359
class SparseLinearDMLCateEstimator (_DMLCateEstimatorBase ):
@@ -376,22 +393,28 @@ class SparseLinearDMLCateEstimator(_DMLCateEstimatorBase):
376
393
If RandomState instance, random_state is the random number generator;
377
394
If None, the random number generator is the RandomState instance used
378
395
by `np.random`.
396
+
397
+ inference: string, inference method, or None
398
+ Method for performing inference. This estimator supports 'bootstrap'
399
+ (or an instance of `BootstrapOptions`)
379
400
"""
380
401
381
402
def __init__ (self ,
382
403
linear_model_y = LassoCV (), linear_model_t = LassoCV (), model_final = LinearRegression (fit_intercept = False ),
383
404
featurizer = PolynomialFeatures (degree = 1 , include_bias = True ),
384
405
discrete_treatment = False ,
385
406
n_splits = 2 ,
386
- random_state = None ):
407
+ random_state = None ,
408
+ inference = None ):
387
409
super ().__init__ (model_y = linear_model_y ,
388
410
model_t = linear_model_t ,
389
411
model_final = model_final ,
390
412
featurizer = featurizer ,
391
413
sparseLinear = True ,
392
414
discrete_treatment = discrete_treatment ,
393
415
n_splits = n_splits ,
394
- random_state = random_state )
416
+ random_state = random_state ,
417
+ inference = inference )
395
418
396
419
397
420
class KernelDMLCateEstimator (DMLCateEstimator ):
@@ -421,15 +444,19 @@ class KernelDMLCateEstimator(DMLCateEstimator):
421
444
n_splits: int, optional (default is 2)
422
445
The number of splits to use when fitting the first-stage models.
423
446
424
- random_state: int, RandomState instance or None, optional (default=None)
447
+ random_state: int, RandomState instance or None, optional (default=None)
425
448
If int, random_state is the seed used by the random number generator;
426
449
If RandomState instance, random_state is the random number generator;
427
450
If None, the random number generator is the RandomState instance used
428
451
by `np.random`.
429
- """
452
+
453
+ inference: string, inference method, or None
454
+ Method for performing inference. This estimator supports 'bootstrap'
455
+ (or an instance of `BootstrapOptions`)
456
+ """
430
457
431
458
def __init__ (self , model_y , model_t , model_final = LinearRegression (fit_intercept = False ),
432
- dim = 20 , bw = 1.0 , n_splits = 2 , random_state = None ):
459
+ dim = 20 , bw = 1.0 , n_splits = 2 , random_state = None , inference = None ):
433
460
class RandomFeatures (TransformerMixin ):
434
461
def fit (innerself , X ):
435
462
innerself .omegas = self ._random_state .normal (0 , 1 / bw , size = (shape (X )[1 ], dim ))
@@ -440,4 +467,5 @@ def transform(innerself, X):
440
467
return np .sqrt (2 / dim ) * np .cos (np .matmul (X , innerself .omegas ) + innerself .biases )
441
468
442
469
super ().__init__ (model_y = model_y , model_t = model_t , model_final = model_final ,
443
- featurizer = RandomFeatures (), n_splits = n_splits , random_state = random_state )
470
+ featurizer = RandomFeatures (), n_splits = n_splits , random_state = random_state ,
471
+ inference = inference )
0 commit comments