Skip to content

Commit 77fca18

Browse files
Merge pull request #14 from basf/adapt_own_time_elapsed_from_sklearn
utils: add own _print_elapsed_time
2 parents f3c7007 + b4cb32a commit 77fca18

File tree

3 files changed

+122
-5
lines changed

3 files changed

+122
-5
lines changed

molpipeline/pipeline/_skl_pipeline.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from typing import Any, Iterable, List, Literal, Optional, Tuple, TypeVar, Union
66

7+
78
try:
89
from typing import Self # type: ignore[attr-defined]
910
except ImportError:
@@ -17,7 +18,7 @@
1718
from sklearn.base import clone
1819
from sklearn.pipeline import Pipeline as _Pipeline
1920
from sklearn.pipeline import _final_estimator_has, _fit_transform_one
20-
from sklearn.utils import Bunch, _print_elapsed_time
21+
from sklearn.utils import Bunch
2122
from sklearn.utils.metadata_routing import (
2223
_routing_enabled, # pylint: disable=protected-access
2324
)
@@ -32,6 +33,7 @@
3233
PostPredictionTransformation,
3334
PostPredictionWrapper,
3435
)
36+
from molpipeline.utils.logging import print_elapsed_time
3537
from molpipeline.utils.molpipeline_types import (
3638
AnyElement,
3739
AnyPredictor,
@@ -240,7 +242,7 @@ def _fit(
240242
for step in self._iter(with_final=False, filter_passthrough=False):
241243
step_idx, name, transformer = step
242244
if transformer is None or transformer == "passthrough":
243-
with _print_elapsed_time("Pipeline", self._log_message(step_idx)):
245+
with print_elapsed_time("Pipeline", self._log_message(step_idx)):
244246
continue
245247

246248
if hasattr(memory, "location") and memory.location is None:
@@ -457,7 +459,7 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self:
457459
"""
458460
routed_params = self._check_method_params(method="fit", props=fit_params)
459461
Xt, yt = self._fit(X, y, routed_params) # pylint: disable=invalid-name
460-
with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
462+
with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
461463
if self._final_estimator != "passthrough":
462464
if is_empty(Xt):
463465
logger.warning(
@@ -528,7 +530,7 @@ def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any:
528530
routed_params = self._check_method_params(method="fit_transform", props=params)
529531
iter_input, iter_label = self._fit(X, y, routed_params)
530532
last_step = self._final_estimator
531-
with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
533+
with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
532534
if last_step == "passthrough":
533535
pass
534536
elif is_empty(iter_input):
@@ -648,7 +650,7 @@ def fit_predict(self, X: Any, y: Any = None, **params: Any) -> Any:
648650
) # pylint: disable=invalid-name
649651

650652
params_last_step = routed_params[self.steps[-1][0]]
651-
with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
653+
with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
652654
if self._final_estimator == "passthrough":
653655
y_pred = iter_input
654656
elif is_empty(iter_input):

molpipeline/utils/logging.py

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Logging helper functions."""
2+
3+
from __future__ import annotations
4+
5+
import timeit
6+
from contextlib import contextmanager
7+
from typing import Generator
8+
9+
from loguru import logger
10+
11+
12+
def _message_with_time(source: str, message: str, time: float) -> str:
13+
"""Create one line message for logging purposes.
14+
15+
Adapted from sklearn's function to stay consistent with the logging style:
16+
https://github.com/scikit-learn/scikit-learn/blob/e16a6ddebd527e886fc22105710ee20ce255f9f0/sklearn/utils/_user_interface.py
17+
18+
Parameters
19+
----------
20+
source : str
21+
String indicating the source or the reference of the message.
22+
message : str
23+
Short message.
24+
time : float
25+
Time in seconds.
26+
27+
Returns
28+
-------
29+
str
30+
Message with elapsed time.
31+
"""
32+
start_message = f"[{source}] "
33+
34+
# adapted from joblib.logger.short_format_time without the Windows -.1s
35+
# adjustment
36+
if time > 60:
37+
time_str = f"{(time / 60):4.1f}min"
38+
else:
39+
time_str = f" {time:5.1f}s"
40+
41+
end_message = f" {message}, total={time_str}"
42+
dots_len = 70 - len(start_message) - len(end_message)
43+
return f"{start_message}{dots_len * '.'}{end_message}"
44+
45+
46+
@contextmanager
47+
def print_elapsed_time(
48+
source: str, message: str | None = None, use_logger: bool = False
49+
) -> Generator[None, None, None]:
50+
"""Log elapsed time to stdout when the context is exited.
51+
52+
Adapted from sklearn's function to stay consistent with the logging style:
53+
https://github.com/scikit-learn/scikit-learn/blob/e16a6ddebd527e886fc22105710ee20ce255f9f0/sklearn/utils/_user_interface.py
54+
55+
Parameters
56+
----------
57+
source : str
58+
String indicating the source or the reference of the message.
59+
message : str, default=None
60+
Short message. If None, nothing will be printed.
61+
use_logger : bool, default=False
62+
If True, the message will be logged using the logger.
63+
64+
Returns
65+
-------
66+
context_manager
67+
Prints elapsed time upon exit if verbose.
68+
"""
69+
if message is None:
70+
yield
71+
else:
72+
start = timeit.default_timer()
73+
yield
74+
message_to_print = _message_with_time(
75+
source, message, timeit.default_timer() - start
76+
)
77+
78+
if use_logger:
79+
logger.info(message_to_print)
80+
else:
81+
print(message_to_print)

tests/test_utils/test_logging.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Test logging utils."""
2+
3+
import io
4+
import unittest
5+
from contextlib import redirect_stdout
6+
7+
from molpipeline.utils.logging import print_elapsed_time
8+
9+
10+
class LoggingUtilsTest(unittest.TestCase):
    """Unit tests for the logging helpers in molpipeline.utils.logging."""

    def test__print_elapsed_time(self) -> None:
        """Test message logging with timings work as expected."""

        # when message is None nothing should be printed
        stream1 = io.StringIO()
        with redirect_stdout(stream1), print_elapsed_time(
            "source", message=None, use_logger=False
        ):
            pass
        output1 = stream1.getvalue()
        self.assertEqual(output1, "")

        # message should be printed in the expected sklearn format; only the
        # prefix is compared because the trailing elapsed time is not fixed
        stream2 = io.StringIO()
        with redirect_stdout(stream2), print_elapsed_time(
            "source", message="my message", use_logger=False
        ):
            pass
        output2 = stream2.getvalue()
        self.assertTrue(
            output2.startswith(
                "[source] ................................... my message, total="
            )
        )

0 commit comments

Comments
 (0)