 import joblib
 import lightgbm as lgb
+from numba.core.utils import erase_traceback
 import numpy as np
 from numpy.lib.function_base import iterable
 import pandas as pd
-from sklearn.base import TransformerMixin
+from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.compose import ColumnTransformer
 from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier, RandomForestRegressor
 from sklearn.linear_model import Lasso, LassoCV, LogisticRegression, LogisticRegressionCV
@@ -172,7 +173,7 @@ def _first_stage_clf(X, y, *, make_regressor=False, automl=True, min_count=None,
     else:
         model = LogisticRegressionCV(
             cv=min(5, min_count), max_iter=1000, Cs=cs, random_state=random_state).fit(X, y)
-        est = LogisticRegression(C=model.C_[0], random_state=random_state)
+        est = LogisticRegression(C=model.C_[0], max_iter=1000, random_state=random_state)
     if make_regressor:
         return _RegressionWrapper(est)
     else:
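The `max_iter=1000` addition matters because `LogisticRegressionCV` searches for `C` under a 1000-iteration budget, while the refit `LogisticRegression` previously fell back to sklearn's default of 100 iterations and could fail to converge on the very same data. A minimal standalone sketch of this select-then-refit pattern (synthetic data and names, not from the PR):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = (X[:, 0] + 0.1 * rng.normal(size=200) > 0).astype(int)

# search for a good C under a 1000-iteration budget...
cv_model = LogisticRegressionCV(cv=5, Cs=10, max_iter=1000, random_state=0).fit(X, y)

# ...then refit a plain model with the chosen C and the *same* budget,
# so the refit converges whenever the CV search did
est = LogisticRegression(C=cv_model.C_[0], max_iter=1000, random_state=0).fit(X, y)
```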
@@ -192,8 +193,6 @@ def _final_stage(*, random_state=None, verbose=0):
 
 # simplification of sklearn's ColumnTransformer that encodes categoricals and passes through selected other columns
 # but also supports get_feature_names with expected signature
-
-
 class _ColumnTransformer(TransformerMixin):
     def __init__(self, categorical, passthrough):
         self.categorical = categorical
@@ -208,22 +207,16 @@ def fit(self, X):
                                              handle_unknown='ignore').fit(cat_cols)
         else:
             self.has_cats = False
-        cont_cols = _safe_indexing(X, self.passthrough, axis=1)
-        if cont_cols.shape[1] > 0:
-            self.has_conts = True
-            self.scaler = StandardScaler().fit(cont_cols)
-        else:
-            self.has_conts = False
         self.d_x = X.shape[1]
         return self
 
     def transform(self, X):
         rest = _safe_indexing(X, self.passthrough, axis=1)
-        if self.has_conts:
-            rest = self.scaler.transform(rest)
         if self.has_cats:
             cats = self.one_hot_encoder.transform(_safe_indexing(X, self.categorical, axis=1))
-            return np.hstack((cats, rest))
+            # NOTE: we rely on the passthrough columns coming first in the concatenated [X;W]
+            # when we pipeline scaling with our first stage models later, so the order here is important
+            return np.hstack((rest, cats))
         else:
             return rest
@@ -234,11 +227,32 @@ def get_feature_names(self, names=None):
         if self.has_cats:
             cats = self.one_hot_encoder.get_feature_names(
                 _safe_indexing(names, self.categorical, axis=0))
-            return np.concatenate((cats, rest))
+            return np.concatenate((rest, cats))
         else:
             return rest
 
 
+# Wrapper to make sure that we get a deep copy of the contents instead of clone returning an untrained copy
+class _Wrapper:
+    def __init__(self, item):
+        self.item = item
+
+
+class _FrozenTransformer(TransformerMixin, BaseEstimator):
+    def __init__(self, wrapper):
+        self.wrapper = wrapper
+
+    def fit(self, X, y):
+        return self
+
+    def transform(self, X):
+        return self.wrapper.item.transform(X)
+
+
+def _freeze(transformer):
+    return _FrozenTransformer(_Wrapper(transformer))
+
+
 # Convert python objects to (possibly nested) types that can easily be represented as literals
 def _sanitize(obj):
     if obj is None or isinstance(obj, (bool, int, str, float)):
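The `_Wrapper` indirection works because `sklearn.base.clone` returns an *unfitted* copy of anything exposing `get_params`, but falls back to `copy.deepcopy` for plain objects; hiding the fitted transformer inside `_Wrapper` therefore preserves its fitted state across cloning. A small illustration of that behavior (my own sketch, assuming the `_freeze` helpers above are in scope):

```python
import numpy as np
from sklearn.base import clone
from sklearn.preprocessing import StandardScaler

X = np.array([[0.0], [2.0], [4.0]])
scaler = StandardScaler().fit(X)

frozen = _freeze(scaler)  # _FrozenTransformer(_Wrapper(scaler))
cloned = clone(frozen)    # _Wrapper has no get_params, so clone deep-copies it
                          # and the inner scaler arrives still fitted

# no NotFittedError: the clone transforms with the original fitted statistics
print(cloned.transform(X))  # approximately [[-1.22], [0.], [1.22]]
```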
@@ -310,6 +324,13 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
     else:
         cats = 'auto'  # just leave the setting at the default otherwise
 
+    # the transformation logic here is somewhat tricky: we always need to encode the categorical columns,
+    # whether they end up in X or in W. For the continuous columns, however, we want to scale them all
+    # when running the first stage models, but we don't want to scale the X columns when running the final
+    # model, since then our coefficients would be in odd units and our trees would also split on those units.
+    #
+    # we achieve this by pipelining the X scaling with the Y and T models (with fixed scaling, not refitting)
+
     hinds = heterogeneity_inds[feat_ind]
     WX_transformer = ColumnTransformer([('encode', OneHotEncoder(drop='first', sparse=False),
                                          [ind for ind in categorical_inds
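To make the comment above concrete, here is a sketch (synthetic data, not PR code) of the intended asymmetry: the first stage models consume standardized continuous columns through a frozen scaler, while the final stage sees them in their original units:

```python
import numpy as np
from sklearn.linear_model import LassoCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

X = np.random.normal(loc=10.0, scale=3.0, size=(100, 2))

# fit the scaler once, then freeze it (using the _freeze helper above) so that
# pipelines cloned and refit downstream reuse these exact statistics
scaler = StandardScaler().fit(X)
model_y = Pipeline([('scale', _freeze(scaler)), ('model', LassoCV())])

# model_y (and an analogous model_t) would be handed to DML as first stage
# models, while the final CATE model is fit against the unscaled X, keeping
# its coefficients and tree splits in interpretable units
```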
@@ -322,11 +343,14 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
                                         ('drop', 'drop', hinds),
                                         ('drop_feat', 'drop', feat_ind)],
                                        remainder=StandardScaler())
+
+    X_cont_inds = [ind for ind in hinds
+                   if ind != feat_ind and ind not in categorical_inds]
+
     # Use _ColumnTransformer instead of ColumnTransformer so we can get feature names
     X_transformer = _ColumnTransformer([ind for ind in categorical_inds
                                         if ind != feat_ind and ind in hinds],
-                                       [ind for ind in hinds
-                                        if ind != feat_ind and ind not in categorical_inds])
+                                       X_cont_inds)
 
     # Controls are all other columns of X
     WX = WX_transformer.fit_transform(X)
@@ -340,6 +364,20 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
 
     W = W_transformer.fit_transform(X)
     X_xf = X_transformer.fit_transform(X)
+
+    # HACK: this is slightly ugly because we rely on the fact that DML passes [X;W] to the first stage models
+    # and so we can just peel the first columns off of that combined array for rescaling in the pipeline
+    # TODO: consider adding an API to DML that allows for better understanding of how the nuisance inputs are
+    #       built, such as model_y_feature_names, model_t_feature_names, model_y_transformer, etc., so that
+    #       this becomes a valid approach to handling this
+    X_scaler = ColumnTransformer([('scale', StandardScaler(),
+                                   list(range(len(X_cont_inds))))],
+                                 remainder='passthrough').fit(np.hstack([X_xf, W])).named_transformers_['scale']
+
+    X_scaler_fixed = ColumnTransformer([('scale', _freeze(X_scaler),
+                                         list(range(len(X_cont_inds))))],
+                                       remainder='passthrough')
+
     if W.shape[1] == 0:
         # array checking routines don't accept 0-width arrays
         W = None
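The `ColumnTransformer` above relies on the [X;W] layout named in the HACK comment: the continuous X columns occupy the leading positions of the combined array (which is also why `_ColumnTransformer.transform` puts the passthrough columns first), so scaling `range(len(X_cont_inds))` touches exactly those columns. A standalone sketch with made-up shapes:

```python
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X_cont = rng.normal(loc=5.0, size=(100, 2))  # continuous X columns come first
W = rng.normal(size=(100, 3))                # controls are passed through untouched

scale_first = ColumnTransformer([('scale', StandardScaler(), [0, 1])],
                                remainder='passthrough')
XW = scale_first.fit_transform(np.hstack([X_cont, W]))

print(XW[:, :2].mean(axis=0))     # ~[0, 0]: leading columns are standardized
print(np.allclose(XW[:, 2:], W))  # True: the W columns are unchanged
```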
@@ -358,14 +396,20 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
                                                random_state=random_state,
                                                verbose=verbose))
 
+    pipelined_model_t = Pipeline([('scale', X_scaler_fixed),
+                                  ('model', model_t)])
+
+    pipelined_model_y = Pipeline([('scale', X_scaler_fixed),
+                                  ('model', model_y)])
+
     if X_xf is None and h_model == 'forest':
         warnings.warn(f"Using a linear model instead of a forest model for feature '{name}' "
                       "because forests don't support models with no heterogeneity indices")
         h_model = 'linear'
 
     if h_model == 'linear':
-        est = LinearDML(model_y=model_y,
-                        model_t=model_t,
+        est = LinearDML(model_y=pipelined_model_y,
+                        model_t=pipelined_model_t,
                         discrete_treatment=discrete_treatment,
                         fit_cate_intercept=True,
                         linear_first_stages=False,
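One more note on why the frozen scaler is needed inside these pipelines: DML's cross-fitting clones the first stage models for each fold, and `clone` of a `Pipeline` recurses into its steps, so only the `'model'` step comes back unfitted while the frozen `'scale'` step keeps its statistics. A sketch of a hypothetical fold (assuming the objects from the diff, a non-empty `W`, and an outcome vector `y`):

```python
import numpy as np
from sklearn.base import clone

# roughly what DML's cross-fitting does internally for each fold (simplified)
fold_model = clone(pipelined_model_y)    # 'model' step: fresh and unfitted;
                                         # 'scale' step: frozen scaler survives
fold_model.fit(np.hstack([X_xf, W]), y)  # _FrozenTransformer.fit is a no-op,
                                         # so only the final model is trained
```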
@@ -374,8 +418,8 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
                         cv=cv,
                         mc_iters=mc_iters)
     elif h_model == 'forest':
-        est = CausalForestDML(model_y=model_y,
-                              model_t=model_t,
+        est = CausalForestDML(model_y=pipelined_model_y,
+                              model_t=pipelined_model_t,
                               discrete_treatment=discrete_treatment,
                               n_estimators=4000,
                               min_var_leaf_on_val=True,