|
4 | 4 |
|
5 | 5 | from typing import Any, Iterable, List, Literal, Optional, Tuple, TypeVar, Union
|
6 | 6 |
|
| 7 | + |
7 | 8 | try:
|
8 | 9 | from typing import Self # type: ignore[attr-defined]
|
9 | 10 | except ImportError:
|
|
17 | 18 | from sklearn.base import clone
|
18 | 19 | from sklearn.pipeline import Pipeline as _Pipeline
|
19 | 20 | from sklearn.pipeline import _final_estimator_has, _fit_transform_one
|
20 |
| -from sklearn.utils import Bunch, _print_elapsed_time |
| 21 | +from sklearn.utils import Bunch |
21 | 22 | from sklearn.utils.metadata_routing import (
|
22 | 23 | _routing_enabled, # pylint: disable=protected-access
|
23 | 24 | )
|
|
32 | 33 | PostPredictionTransformation,
|
33 | 34 | PostPredictionWrapper,
|
34 | 35 | )
|
| 36 | +from molpipeline.utils.logging import print_elapsed_time |
35 | 37 | from molpipeline.utils.molpipeline_types import (
|
36 | 38 | AnyElement,
|
37 | 39 | AnyPredictor,
|
@@ -240,7 +242,7 @@ def _fit(
|
240 | 242 | for step in self._iter(with_final=False, filter_passthrough=False):
|
241 | 243 | step_idx, name, transformer = step
|
242 | 244 | if transformer is None or transformer == "passthrough":
|
243 |
| - with _print_elapsed_time("Pipeline", self._log_message(step_idx)): |
| 245 | + with print_elapsed_time("Pipeline", self._log_message(step_idx)): |
244 | 246 | continue
|
245 | 247 |
|
246 | 248 | if hasattr(memory, "location") and memory.location is None:
|
@@ -457,7 +459,7 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self:
|
457 | 459 | """
|
458 | 460 | routed_params = self._check_method_params(method="fit", props=fit_params)
|
459 | 461 | Xt, yt = self._fit(X, y, routed_params) # pylint: disable=invalid-name
|
460 |
| - with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): |
| 462 | + with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): |
461 | 463 | if self._final_estimator != "passthrough":
|
462 | 464 | if is_empty(Xt):
|
463 | 465 | logger.warning(
|
@@ -528,7 +530,7 @@ def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any:
|
528 | 530 | routed_params = self._check_method_params(method="fit_transform", props=params)
|
529 | 531 | iter_input, iter_label = self._fit(X, y, routed_params)
|
530 | 532 | last_step = self._final_estimator
|
531 |
| - with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): |
| 533 | + with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): |
532 | 534 | if last_step == "passthrough":
|
533 | 535 | pass
|
534 | 536 | elif is_empty(iter_input):
|
@@ -648,7 +650,7 @@ def fit_predict(self, X: Any, y: Any = None, **params: Any) -> Any:
|
648 | 650 | ) # pylint: disable=invalid-name
|
649 | 651 |
|
650 | 652 | params_last_step = routed_params[self.steps[-1][0]]
|
651 |
| - with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): |
| 653 | + with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): |
652 | 654 | if self._final_estimator == "passthrough":
|
653 | 655 | y_pred = iter_input
|
654 | 656 | elif is_empty(iter_input):
|
|
0 commit comments