Skip to content

Commit db1e254

Browse files
authored
Optimizing NormalInferenceResults confidence interval method speed (#879)
* Fixed normal inference results confidence interval unnecessary loop Signed-off-by: gdaiha <[email protected]> Signed-off-by: Gabriel Daiha <[email protected]>
1 parent 2065758 commit db1e254

File tree

2 files changed

+13
-30
lines changed

2 files changed

+13
-30
lines changed

econml/inference/_inference.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,14 +1066,9 @@ def conf_int(self, alpha=0.05):
10661066
"""
10671067
if self.stderr is None:
10681068
raise AttributeError("Only point estimates are available!")
1069-
if np.isscalar(self.point_estimate):
1069+
else:
10701070
return _safe_norm_ppf(alpha / 2, loc=self.point_estimate, scale=self.stderr), \
10711071
_safe_norm_ppf(1 - alpha / 2, loc=self.point_estimate, scale=self.stderr)
1072-
else:
1073-
return np.array([_safe_norm_ppf(alpha / 2, loc=p, scale=err)
1074-
for p, err in zip(self.point_estimate, self.stderr)]), \
1075-
np.array([_safe_norm_ppf(1 - alpha / 2, loc=p, scale=err)
1076-
for p, err in zip(self.point_estimate, self.stderr)])
10771072

10781073
def pvalue(self, value=0):
10791074
"""
@@ -1398,14 +1393,8 @@ def conf_int_mean(self, *, alpha=None):
13981393
alpha = self.alpha if alpha is None else alpha
13991394
mean_point = self.mean_point
14001395
stderr_mean = self.stderr_mean
1401-
if np.isscalar(mean_point):
1402-
return (_safe_norm_ppf(alpha / 2, loc=mean_point, scale=stderr_mean),
1403-
_safe_norm_ppf(1 - alpha / 2, loc=mean_point, scale=stderr_mean))
1404-
else:
1405-
return np.array([_safe_norm_ppf(alpha / 2, loc=p, scale=err)
1406-
for p, err in zip(mean_point, stderr_mean)]), \
1407-
np.array([_safe_norm_ppf(1 - alpha / 2, loc=p, scale=err)
1408-
for p, err in zip(mean_point, stderr_mean)])
1396+
return (_safe_norm_ppf(alpha / 2, loc=mean_point, scale=stderr_mean),
1397+
_safe_norm_ppf(1 - alpha / 2, loc=mean_point, scale=stderr_mean))
14091398

14101399
@property
14111400
def std_point(self):

econml/sklearn_extensions/linear_model.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1627,10 +1627,8 @@ def coef__interval(self, alpha=0.05):
16271627
coef__interval : {tuple ((p, d) array, (p,d) array), tuple ((d,) array, (d,) array)}
16281628
The lower and upper bounds of the confidence interval of the coefficients
16291629
"""
1630-
return np.array([_safe_norm_ppf(alpha / 2, loc=p, scale=err)
1631-
for p, err in zip(self.coef_, self.coef_stderr_)]), \
1632-
np.array([_safe_norm_ppf(1 - alpha / 2, loc=p, scale=err)
1633-
for p, err in zip(self.coef_, self.coef_stderr_)])
1630+
return (_safe_norm_ppf(alpha / 2, loc=self.coef_, scale=self.coef_stderr_),
1631+
_safe_norm_ppf(1 - alpha / 2, loc=self.coef_, scale=self.coef_stderr_))
16341632

16351633
def intercept__interval(self, alpha=0.05):
16361634
"""
@@ -1651,14 +1649,8 @@ def intercept__interval(self, alpha=0.05):
16511649
return (0 if self._n_out == 0 else np.zeros(self._n_out)), \
16521650
(0 if self._n_out == 0 else np.zeros(self._n_out))
16531651

1654-
if self._n_out == 0:
1655-
return _safe_norm_ppf(alpha / 2, loc=self.intercept_, scale=self.intercept_stderr_), \
1656-
_safe_norm_ppf(1 - alpha / 2, loc=self.intercept_, scale=self.intercept_stderr_)
1657-
else:
1658-
return np.array([_safe_norm_ppf(alpha / 2, loc=p, scale=err)
1659-
for p, err in zip(self.intercept_, self.intercept_stderr_)]), \
1660-
np.array([_safe_norm_ppf(1 - alpha / 2, loc=p, scale=err)
1661-
for p, err in zip(self.intercept_, self.intercept_stderr_)])
1652+
return (_safe_norm_ppf(alpha / 2, loc=self.intercept_, scale=self.intercept_stderr_),
1653+
_safe_norm_ppf(1 - alpha / 2, loc=self.intercept_, scale=self.intercept_stderr_))
16621654

16631655
def predict_interval(self, X, alpha=0.05):
16641656
"""
@@ -1677,10 +1669,12 @@ def predict_interval(self, X, alpha=0.05):
16771669
prediction_intervals : {tuple ((n,) array, (n,) array), tuple ((n,p) array, (n,p) array)}
16781670
The lower and upper bounds of the confidence intervals of the predicted mean outcomes
16791671
"""
1680-
return np.array([_safe_norm_ppf(alpha / 2, loc=p, scale=err)
1681-
for p, err in zip(self.predict(X), self.prediction_stderr(X))]), \
1682-
np.array([_safe_norm_ppf(1 - alpha / 2, loc=p, scale=err)
1683-
for p, err in zip(self.predict(X), self.prediction_stderr(X))])
1672+
1673+
pred = self.predict(X)
1674+
pred_stderr = self.prediction_stderr(X)
1675+
1676+
return (_safe_norm_ppf(alpha / 2, loc=pred, scale=pred_stderr),
1677+
_safe_norm_ppf(1 - alpha / 2, loc=pred, scale=pred_stderr))
16841678

16851679

16861680
class StatsModelsLinearRegression(_StatsModelsWrapper):

0 commit comments

Comments
 (0)