
Commit 8307f67: Fix scaling behavior
Parent: 10baa77

2 files changed: 135 additions, 20 deletions

econml/solutions/causal_analysis/_causal_analysis.py (64 additions, 20 deletions)
@@ -8,10 +8,11 @@
 
 import joblib
 import lightgbm as lgb
+from numba.core.utils import erase_traceback
 import numpy as np
 from numpy.lib.function_base import iterable
 import pandas as pd
-from sklearn.base import TransformerMixin
+from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.compose import ColumnTransformer
 from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier, RandomForestRegressor
 from sklearn.linear_model import Lasso, LassoCV, LogisticRegression, LogisticRegressionCV
@@ -172,7 +173,7 @@ def _first_stage_clf(X, y, *, make_regressor=False, automl=True, min_count=None,
     else:
         model = LogisticRegressionCV(
             cv=min(5, min_count), max_iter=1000, Cs=cs, random_state=random_state).fit(X, y)
-        est = LogisticRegression(C=model.C_[0], random_state=random_state)
+        est = LogisticRegression(C=model.C_[0], max_iter=1000, random_state=random_state)
    if make_regressor:
        return _RegressionWrapper(est)
    else:
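
This change matters because the CV search above runs with max_iter=1000 while sklearn's LogisticRegression defaults to max_iter=100, so the refit at the selected C could stop before converging. A minimal sketch of the failure mode (illustrative only; whether the warning actually fires depends on the data drawn):

    import warnings

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    # near-separable data plus weak regularization tends to need many solver iterations
    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 50))
    y = (X @ rng.normal(size=50) > 0).astype(int)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        LogisticRegression(C=1e4, max_iter=100).fit(X, y)   # old behavior: may emit ConvergenceWarning
    print([w.category.__name__ for w in caught])

    LogisticRegression(C=1e4, max_iter=1000).fit(X, y)      # fixed: same budget as the CV search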
@@ -192,8 +193,6 @@ def _final_stage(*, random_state=None, verbose=0):
 
 # simplification of sklearn's ColumnTransformer that encodes categoricals and passes through selected other columns
 # but also supports get_feature_names with expected signature
-
-
 class _ColumnTransformer(TransformerMixin):
     def __init__(self, categorical, passthrough):
         self.categorical = categorical
@@ -208,22 +207,16 @@ def fit(self, X):
                                                  handle_unknown='ignore').fit(cat_cols)
         else:
             self.has_cats = False
-        cont_cols = _safe_indexing(X, self.passthrough, axis=1)
-        if cont_cols.shape[1] > 0:
-            self.has_conts = True
-            self.scaler = StandardScaler().fit(cont_cols)
-        else:
-            self.has_conts = False
         self.d_x = X.shape[1]
         return self
 
     def transform(self, X):
         rest = _safe_indexing(X, self.passthrough, axis=1)
-        if self.has_conts:
-            rest = self.scaler.transform(rest)
         if self.has_cats:
             cats = self.one_hot_encoder.transform(_safe_indexing(X, self.categorical, axis=1))
-            return np.hstack((cats, rest))
+            # NOTE: we rely on the passthrough columns coming first in the concatenated X;W
+            # when we pipeline scaling with our first stage models later, so the order here is important
+            return np.hstack((rest, cats))
         else:
             return rest
 
@@ -234,11 +227,32 @@ def get_feature_names(self, names=None):
         if self.has_cats:
             cats = self.one_hot_encoder.get_feature_names(
                 _safe_indexing(names, self.categorical, axis=0))
-            return np.concatenate((cats, rest))
+            return np.concatenate((rest, cats))
         else:
             return rest
 
 
+# Wrapper to make sure that we get a deep copy of the contents instead of clone returning an untrained copy
+class _Wrapper:
+    def __init__(self, item):
+        self.item = item
+
+
+class _FrozenTransformer(TransformerMixin, BaseEstimator):
+    def __init__(self, wrapper):
+        self.wrapper = wrapper
+
+    def fit(self, X, y):
+        return self
+
+    def transform(self, X):
+        return self.wrapper.item.transform(X)
+
+
+def _freeze(transformer):
+    return _FrozenTransformer(_Wrapper(transformer))
+
+
 # Convert python objects to (possibly nested) types that can easily be represented as literals
 def _sanitize(obj):
     if obj is None or isinstance(obj, (bool, int, str, float)):
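
Why the extra _Wrapper indirection in the new classes: sklearn.base.clone rebuilds any estimator-valued constructor parameter from scratch, which would hand back an untrained transformer, but it deep-copies non-estimator parameters, fitted state included. Stashing the fitted transformer inside a plain object therefore lets it survive cloning, while any model pipelined after it is still reset and refit as usual. A standalone sketch of that behavior (mirroring the classes above, with fit accepting y=None for pipeline compatibility):

    import numpy as np
    from sklearn.base import BaseEstimator, TransformerMixin, clone
    from sklearn.linear_model import LinearRegression
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler

    class _Wrapper:
        def __init__(self, item):
            self.item = item

    class _FrozenTransformer(TransformerMixin, BaseEstimator):
        def __init__(self, wrapper):
            self.wrapper = wrapper

        def fit(self, X, y=None):
            return self  # no-op: the wrapped transformer keeps its original fit

        def transform(self, X):
            return self.wrapper.item.transform(X)

    X = np.random.normal(loc=10, scale=3, size=(100, 2))
    y = X.sum(axis=1)
    scaler = StandardScaler().fit(X)
    pipe = Pipeline([('scale', _FrozenTransformer(_Wrapper(scaler))),
                     ('model', LinearRegression())])

    cloned = clone(pipe)   # the model step comes back unfitted, as cross-fitting requires
    cloned.fit(X, y)       # refitting the clone refits only the model step in effect
    np.testing.assert_allclose(  # the scaler kept its original statistics through the clone
        cloned.named_steps['scale'].transform(X), scaler.transform(X))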
@@ -310,6 +324,13 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
     else:
         cats = 'auto'  # just leave the setting at the default otherwise
 
+    # the transformation logic here is somewhat tricky; we always need to encode the categorical columns,
+    # whether they end up in X or in W. However, for the continuous columns, we want to scale them all
+    # when running the first stage models, but don't want to scale the X columns when running the final model,
+    # since then our coefficients will have odd units and our trees will also have decisions using those units.
+    #
+    # we achieve this by pipelining the X scaling with the Y and T models (with fixed scaling, not refitting)
+
     hinds = heterogeneity_inds[feat_ind]
     WX_transformer = ColumnTransformer([('encode', OneHotEncoder(drop='first', sparse=False),
                                          [ind for ind in categorical_inds
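
To make the units concern in the comment above concrete (a toy illustration, separate from the diff): standardizing a regressor rescales its fitted coefficient by that regressor's standard deviation, so a final-stage model trained on scaled X would report effects per standard deviation of X rather than per user-facing unit.

    import numpy as np
    from sklearn.linear_model import LinearRegression
    from sklearn.preprocessing import StandardScaler

    rng = np.random.default_rng(0)
    x = rng.normal(scale=1000, size=(500, 1))   # feature expressed in large user-facing units
    y = 2.0 * x[:, 0] + rng.normal(size=500)

    raw_coef = LinearRegression().fit(x, y).coef_[0]
    std_coef = LinearRegression().fit(StandardScaler().fit_transform(x), y).coef_[0]

    print(raw_coef)   # ~2.0: y per unit of x, the units a user expects
    print(std_coef)   # ~2000.0: y per standard deviation of x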
@@ -322,11 +343,14 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
                                         ('drop', 'drop', hinds),
                                         ('drop_feat', 'drop', feat_ind)],
                                        remainder=StandardScaler())
+
+    X_cont_inds = [ind for ind in hinds
+                   if ind != feat_ind and ind not in categorical_inds]
+
     # Use _ColumnTransformer instead of ColumnTransformer so we can get feature names
     X_transformer = _ColumnTransformer([ind for ind in categorical_inds
                                         if ind != feat_ind and ind in hinds],
-                                       [ind for ind in hinds
-                                        if ind != feat_ind and ind not in categorical_inds])
+                                       X_cont_inds)
 
     # Controls are all other columns of X
     WX = WX_transformer.fit_transform(X)
@@ -340,6 +364,20 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
 
     W = W_transformer.fit_transform(X)
     X_xf = X_transformer.fit_transform(X)
+
+    # HACK: this is slightly ugly because we rely on the fact that DML passes [X;W] to the first stage models
+    # and so we can just peel the first columns off of that combined array for rescaling in the pipeline
+    # TODO: consider adding an API to DML that allows for better understanding of how the nuisance inputs are
+    # built, such as model_y_feature_names, model_t_feature_names, model_y_transformer, etc., so that this
+    # becomes a valid approach to handling this
+    X_scaler = ColumnTransformer([('scale', StandardScaler(),
+                                   list(range(len(X_cont_inds))))],
+                                 remainder='passthrough').fit(np.hstack([X_xf, W])).named_transformers_['scale']
+
+    X_scaler_fixed = ColumnTransformer([('scale', _freeze(X_scaler),
+                                         list(range(len(X_cont_inds))))],
+                                       remainder='passthrough')
+
     if W.shape[1] == 0:
         # array checking routines don't accept 0-width arrays
         W = None
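
The index trick above only works because of two ordering guarantees: _ColumnTransformer.transform now emits the continuous passthrough columns first, and DML stacks X ahead of W when calling the nuisance models, so list(range(len(X_cont_inds))) always addresses the continuous X columns of the combined array. A toy sketch of the peeling itself, with made-up dimensions:

    import numpy as np
    from sklearn.compose import ColumnTransformer
    from sklearn.preprocessing import StandardScaler

    rng = np.random.default_rng(0)
    n_cont_x = 2   # hypothetical count of continuous X columns
    XW = np.hstack([rng.normal(scale=50, size=(100, n_cont_x)),   # continuous X, leading columns
                    rng.normal(size=(100, 3))])                   # W columns, passed through

    scale_first = ColumnTransformer([('scale', StandardScaler(), list(range(n_cont_x)))],
                                    remainder='passthrough')
    out = scale_first.fit_transform(XW)

    # only the leading n_cont_x columns are standardized; the trailing W columns are untouched
    assert np.allclose(out[:, :n_cont_x].std(axis=0), 1, atol=0.1)
    assert np.array_equal(out[:, n_cont_x:], XW[:, n_cont_x:])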
@@ -358,14 +396,20 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
                                        random_state=random_state,
                                        verbose=verbose))
 
+    pipelined_model_t = Pipeline([('scale', X_scaler_fixed),
+                                  ('model', model_t)])
+
+    pipelined_model_y = Pipeline([('scale', X_scaler_fixed),
+                                  ('model', model_y)])
+
     if X_xf is None and h_model == 'forest':
         warnings.warn(f"Using a linear model instead of a forest model for feature '{name}' "
                       "because forests don't support models with no heterogeneity indices")
         h_model = 'linear'
 
     if h_model == 'linear':
-        est = LinearDML(model_y=model_y,
-                        model_t=model_t,
+        est = LinearDML(model_y=pipelined_model_y,
+                        model_t=pipelined_model_t,
                         discrete_treatment=discrete_treatment,
                         fit_cate_intercept=True,
                         linear_first_stages=False,
@@ -374,8 +418,8 @@ def _process_feature(name, feat_ind, verbose, categorical_inds, categories, hete
                         cv=cv,
                         mc_iters=mc_iters)
     elif h_model == 'forest':
-        est = CausalForestDML(model_y=model_y,
-                              model_t=model_t,
+        est = CausalForestDML(model_y=pipelined_model_y,
+                              model_t=pipelined_model_t,
                               discrete_treatment=discrete_treatment,
                               n_estimators=4000,
                               min_var_leaf_on_val=True,

econml/tests/test_causal_analysis.py (71 additions, 0 deletions)
@@ -778,3 +778,74 @@ def test_invalid_inds(self):
         self.assertEqual(ca.trained_feature_indices_, [0, 1, 2, 3])  # can't handle last two
         self.assertEqual(ca.untrained_feature_indices_, [(4, 'cat_limit'),
                                                          (5, 'cat_limit')])
+
+    # Add tests that guarantee that the reliance on DML feature order is not broken, such as:
+    # create a transformer that zeros out all variables after the first n_x variables, so it zeros out W,
+    # and pass an example where W is irrelevant and X is a confounder.
+    # As long as DML doesn't change the order of the inputs, things should be good. Otherwise X would be
+    # zeroed out and the test will fail.
+    def test_scaling_transforms(self):
+        # shouldn't matter if X is scaled much larger or much smaller than W, we should still get good estimates
+        n = 2000
+        X = np.random.normal(size=(n, 5))
+        W = np.random.normal(size=(n, 5))
+        W[:, 0] = 1  # make one of the columns a constant
+        xt, wt, xy, wy, theta = [np.random.normal(size=sz) for sz in [(5, 1), (5, 1), (5, 1), (5, 1), (1, 1)]]
+        T = X @ xt + W @ wt + np.random.normal(size=(n, 1))
+        Y = X @ xy + W @ wy + T @ theta
+        arr1 = np.hstack([X, W, T])
+        # rescaling X shouldn't affect the first stage models because they normalize the inputs
+        arr2 = np.hstack([1000 * X, W, T])
+        for hmodel in ['linear', 'forest']:
+            inds = [-1]  # we just care about T
+            cats = []
+            hinds = list(range(X.shape[1]))
+            ca = CausalAnalysis(inds, cats, hinds, heterogeneity_model=hmodel, random_state=123)
+            ca.fit(arr1, Y)
+            eff1 = ca.global_causal_effect()
+
+            ca.fit(arr2, Y)
+            eff2 = ca.global_causal_effect()
+
+            np.testing.assert_allclose(eff1.point.values, eff2.point.values, rtol=1e-5)
+            np.testing.assert_allclose(eff1.ci_lower.values, eff2.ci_lower.values, rtol=1e-5)
+            np.testing.assert_allclose(eff1.ci_upper.values, eff2.ci_upper.values, rtol=1e-5)
+
+            np.testing.assert_allclose(eff1.point.values, theta.flatten(), rtol=1e-2)
+
+        # to recover individual coefficients with linear models, we need to be more careful in how we set up X
+        # to avoid cross terms
+        X = np.zeros(shape=(n, 5))
+        X[range(X.shape[0]), np.random.choice(5, size=n)] = 1
+        xt, wt, xy, wy, theta = [np.random.normal(size=sz) for sz in [(5, 1), (5, 1), (5, 1), (5, 1), (5, 1)]]
+        T = X @ xt + W @ wt + np.random.normal(size=(n, 1))
+        Y = X @ xy + W @ wy + T * (X @ theta)
+        arr1 = np.hstack([X, W, T])
+        arr2 = np.hstack([1000 * X, W, T])
+        for hmodel in ['linear', 'forest']:
+            inds = [-1]  # we just care about T
+            cats = []
+            hinds = list(range(X.shape[1]))
+            ca = CausalAnalysis(inds, cats, hinds, heterogeneity_model=hmodel, random_state=123)
+            ca.fit(arr1, Y)
+            eff1 = ca.global_causal_effect()
+            loc1 = ca.local_causal_effect(
+                np.hstack([np.eye(X.shape[1]), np.zeros((X.shape[1], arr1.shape[1] - X.shape[1]))]))
+            ca.fit(arr2, Y)
+            eff2 = ca.global_causal_effect()
+            loc2 = ca.local_causal_effect(
+                # scale by 1000 to match the input to this model:
+                # the scale of X does matter for the final model, which keeps results in user-denominated units
+                1000 * np.hstack([np.eye(X.shape[1]), np.zeros((X.shape[1], arr1.shape[1] - X.shape[1]))]))
+
+            # rescaling X still shouldn't affect the first stage models
+            np.testing.assert_allclose(eff1.point.values, eff2.point.values, rtol=1e-5)
+            np.testing.assert_allclose(eff1.ci_lower.values, eff2.ci_lower.values, rtol=1e-5)
+            np.testing.assert_allclose(eff1.ci_upper.values, eff2.ci_upper.values, rtol=1e-5)
+
+            np.testing.assert_allclose(loc1.point.values, loc2.point.values, rtol=1e-2)
+
+            # TODO: we don't recover the correct values with enough accuracy to enable this assertion;
+            #       is there a different way to verify that we are learning the correct coefficients?
+            # np.testing.assert_allclose(loc1.point.values, theta.flatten(), rtol=1e-1)
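
As for the guard test requested in the comment at the top of test_scaling_transforms, one possible shape (a hypothetical sketch; ZeroAfterN and its use are inventions here, not part of this commit): a transformer that zeros every column after the first n_x would blind the first-stage models to W, so on data where X is the only confounder the estimate stays accurate only while DML keeps passing [X;W] in that order.

    import numpy as np
    from sklearn.base import BaseEstimator, TransformerMixin

    class ZeroAfterN(BaseEstimator, TransformerMixin):
        """Hypothetical test helper: keep the first n columns, zero out the rest."""

        def __init__(self, n):
            self.n = n

        def fit(self, X, y=None):
            return self

        def transform(self, X):
            out = np.array(X, dtype=float, copy=True)
            out[:, self.n:] = 0   # with [X;W] ordering this erases exactly the W columns
            return out

Pipelining ZeroAfterN(n_x) ahead of model_y and model_t on a dataset where W is pure noise and X confounds both T and Y would then assert that the recovered effect still matches theta; if DML ever reordered its nuisance inputs, the confounder itself would be zeroed out and the assertion would fail.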
