Skip to content

Commit 38bd8c8

Browse files
authored
Ensure compatibility with sklearn and implement tags (#124)
* add unit test to reproduce error * fix some metadata routing * isort + black * move compatibility tests to separate test class * ignore too many lines and remove ignores locally flagged `useless-suppression` annotations * remove ignores flagged `useless-suppression` annotations * add test for ChEMPROP Model, which fails * set tags for chemprop models * remove invalid smiles
1 parent f94c498 commit 38bd8c8

File tree

4 files changed

+354
-81
lines changed

4 files changed

+354
-81
lines changed

molpipeline/estimators/chemprop/models.py

+23
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
import numpy.typing as npt
1212
from loguru import logger
1313
from sklearn.base import clone
14+
from sklearn.utils._tags import (
15+
ClassifierTags,
16+
RegressorTags,
17+
Tags,
18+
)
1419
from sklearn.utils.metaestimators import available_if
1520

1621
try:
@@ -124,6 +129,24 @@ def _is_classifier(self) -> bool:
124129
"""
125130
return self._is_binary_classifier() or self._is_multiclass_classifier()
126131

132+
def __sklearn_tags__(self) -> Tags:
133+
"""Return the sklearn tags.
134+
135+
Returns
136+
-------
137+
Tags
138+
The sklearn tags for the model.
139+
"""
140+
tags = super().__sklearn_tags__()
141+
if self._is_classifier():
142+
tags.estimator_type = "classifier"
143+
tags.classifier_tags = ClassifierTags()
144+
else:
145+
tags.estimator_type = "regressor"
146+
tags.regressor_tags = RegressorTags()
147+
tags.target_tags.required = True
148+
return tags
149+
127150
def _predict(
128151
self, X: MoleculeDataset # pylint: disable=invalid-name
129152
) -> npt.NDArray[np.float64]:

molpipeline/pipeline/_skl_pipeline.py

+161-15
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
"""Defines a pipeline is exposed to the user, accessible via pipeline."""
22

3+
# pylint: disable=too-many-lines
4+
35
from __future__ import annotations
46

7+
from copy import deepcopy
58
from typing import Any, Iterable, List, Literal, Optional, Tuple, TypeVar, Union
69

710
try:
@@ -13,15 +16,15 @@
1316
import numpy as np
1417
import numpy.typing as npt
1518
from loguru import logger
16-
from sklearn.base import _fit_context # pylint: disable=protected-access
17-
from sklearn.base import clone
19+
from sklearn.base import _fit_context, clone
1820
from sklearn.pipeline import Pipeline as _Pipeline
1921
from sklearn.pipeline import _final_estimator_has, _fit_transform_one
2022
from sklearn.utils import Bunch
23+
from sklearn.utils._tags import Tags, get_tags
2124
from sklearn.utils.metadata_routing import (
22-
_routing_enabled, # pylint: disable=protected-access
23-
)
24-
from sklearn.utils.metadata_routing import (
25+
MetadataRouter,
26+
MethodMapping,
27+
_routing_enabled,
2528
process_routing,
2629
)
2730
from sklearn.utils.metaestimators import available_if
@@ -213,8 +216,8 @@ def _final_estimator(
213216
# pylint: disable=too-many-locals,too-many-branches
214217
def _fit(
215218
self,
216-
X: Any, # pylint: disable=invalid-name
217-
y: Any = None, # pylint: disable=invalid-name
219+
X: Any,
220+
y: Any = None,
218221
routed_params: dict[str, Any] | None = None,
219222
raw_params: dict[str, Any] | None = None,
220223
) -> tuple[Any, Any]:
@@ -482,7 +485,9 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self:
482485
"All input rows were filtered out! Model is not fitted!"
483486
)
484487
else:
485-
fit_params_last_step = routed_params[self.steps[-1][0]]
488+
fit_params_last_step = routed_params[
489+
self._non_post_processing_steps()[-1][0]
490+
]
486491
self._final_estimator.fit(Xt, yt, **fit_params_last_step["fit"])
487492

488493
return self
@@ -552,7 +557,9 @@ def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any:
552557
elif is_empty(iter_input):
553558
logger.warning("All input rows were filtered out! Model is not fitted!")
554559
else:
555-
last_step_params = routed_params[self.steps[-1][0]]
560+
last_step_params = routed_params[
561+
self._non_post_processing_steps()[-1][0]
562+
]
556563
if hasattr(last_step, "fit_transform"):
557564
iter_input = last_step.fit_transform(
558565
iter_input, iter_label, **last_step_params["fit_transform"]
@@ -615,7 +622,8 @@ def predict(self, X: Any, **params: Any) -> Any:
615622
elif hasattr(self._final_estimator, "predict"):
616623
if _routing_enabled():
617624
iter_input = self._final_estimator.predict(
618-
iter_input, **routed_params[self._final_estimator].predict
625+
iter_input,
626+
**routed_params[self._non_post_processing_steps()[-1][0]].predict,
619627
)
620628
else:
621629
iter_input = self._final_estimator.predict(iter_input, **params)
@@ -661,11 +669,9 @@ def fit_predict(self, X: Any, y: Any = None, **params: Any) -> Any:
661669
Result of calling `fit_predict` on the final estimator.
662670
"""
663671
routed_params = self._check_method_params(method="fit_predict", props=params)
664-
iter_input, iter_label = self._fit(
665-
X, y, routed_params
666-
) # pylint: disable=invalid-name
672+
iter_input, iter_label = self._fit(X, y, routed_params)
667673

668-
params_last_step = routed_params[self.steps[-1][0]]
674+
params_last_step = routed_params[self._non_post_processing_steps()[-1][0]]
669675
with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
670676
if self._final_estimator == "passthrough":
671677
y_pred = iter_input
@@ -724,7 +730,10 @@ def predict_proba(self, X: Any, **params: Any) -> Any:
724730
elif hasattr(self._final_estimator, "predict_proba"):
725731
if _routing_enabled():
726732
iter_input = self._final_estimator.predict_proba(
727-
iter_input, **routed_params[self.steps[-1][0]].predict_proba
733+
iter_input,
734+
**routed_params[
735+
self._non_post_processing_steps()[-1][0]
736+
].predict_proba,
728737
)
729738
else:
730739
iter_input = self._final_estimator.predict_proba(iter_input, **params)
@@ -854,3 +863,140 @@ def classes_(self) -> list[Any] | npt.NDArray[Any]:
854863
if hasattr(last_step, "classes_"):
855864
return last_step.classes_
856865
raise ValueError("Last step has no classes_ attribute.")
866+
867+
def __sklearn_tags__(self) -> Tags:
868+
"""Return the sklearn tags.
869+
870+
Note
871+
----
872+
This method is copied from the original sklearn implementation.
873+
Changes are marked with a comment.
874+
875+
Returns
876+
-------
877+
Tags
878+
The sklearn tags.
879+
"""
880+
tags = super().__sklearn_tags__()
881+
882+
if not self.steps:
883+
return tags
884+
885+
try:
886+
if self.steps[0][1] is not None and self.steps[0][1] != "passthrough":
887+
tags.input_tags.pairwise = get_tags(
888+
self.steps[0][1]
889+
).input_tags.pairwise
890+
# WARNING: the sparse tag can be incorrect.
891+
# Some Pipelines accepting sparse data are wrongly tagged sparse=False.
892+
# For example Pipeline([PCA(), estimator]) accepts sparse data
893+
# even if the estimator doesn't as PCA outputs a dense array.
894+
tags.input_tags.sparse = all(
895+
get_tags(step).input_tags.sparse
896+
for name, step in self.steps
897+
if step != "passthrough"
898+
)
899+
except (ValueError, AttributeError, TypeError):
900+
# This happens when the `steps` is not a list of (name, estimator)
901+
# tuples and `fit` is not called yet to validate the steps.
902+
pass
903+
904+
try:
905+
# Only the _final_estimator is changed from the original implementation is changed in the following 2 lines
906+
if (
907+
self._final_estimator is not None
908+
and self._final_estimator != "passthrough"
909+
):
910+
last_step_tags = get_tags(self._final_estimator)
911+
tags.estimator_type = last_step_tags.estimator_type
912+
tags.target_tags.multi_output = last_step_tags.target_tags.multi_output
913+
tags.classifier_tags = deepcopy(last_step_tags.classifier_tags)
914+
tags.regressor_tags = deepcopy(last_step_tags.regressor_tags)
915+
tags.transformer_tags = deepcopy(last_step_tags.transformer_tags)
916+
except (ValueError, AttributeError, TypeError):
917+
# This happens when the `steps` is not a list of (name, estimator)
918+
# tuples and `fit` is not called yet to validate the steps.
919+
pass
920+
921+
return tags
922+
923+
def get_metadata_routing(self) -> MetadataRouter:
924+
"""Get metadata routing of this object.
925+
926+
Please check :ref:`User Guide <metadata_routing>` on how the routing
927+
mechanism works.
928+
929+
Note
930+
----
931+
This method is copied from the original sklearn implementation.
932+
Changes are marked with a comment.
933+
934+
Returns
935+
-------
936+
MetadataRouter
937+
A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
938+
routing information.
939+
"""
940+
router = MetadataRouter(owner=self.__class__.__name__)
941+
942+
# first we add all steps except the last one
943+
for _, name, trans in self._iter(with_final=False, filter_passthrough=True):
944+
method_mapping = MethodMapping()
945+
# fit, fit_predict, and fit_transform call fit_transform if it
946+
# exists, or else fit and transform
947+
if hasattr(trans, "fit_transform"):
948+
(
949+
method_mapping.add(caller="fit", callee="fit_transform")
950+
.add(caller="fit_transform", callee="fit_transform")
951+
.add(caller="fit_predict", callee="fit_transform")
952+
)
953+
else:
954+
(
955+
method_mapping.add(caller="fit", callee="fit")
956+
.add(caller="fit", callee="transform")
957+
.add(caller="fit_transform", callee="fit")
958+
.add(caller="fit_transform", callee="transform")
959+
.add(caller="fit_predict", callee="fit")
960+
.add(caller="fit_predict", callee="transform")
961+
)
962+
963+
(
964+
method_mapping.add(caller="predict", callee="transform")
965+
.add(caller="predict", callee="transform")
966+
.add(caller="predict_proba", callee="transform")
967+
.add(caller="decision_function", callee="transform")
968+
.add(caller="predict_log_proba", callee="transform")
969+
.add(caller="transform", callee="transform")
970+
.add(caller="inverse_transform", callee="inverse_transform")
971+
.add(caller="score", callee="transform")
972+
)
973+
974+
router.add(method_mapping=method_mapping, **{name: trans})
975+
976+
# Only the _non_post_processing_steps is changed from the original implementation is changed in the following line
977+
final_name, final_est = self._non_post_processing_steps()[-1]
978+
if final_est is None or final_est == "passthrough":
979+
return router
980+
981+
# then we add the last step
982+
method_mapping = MethodMapping()
983+
if hasattr(final_est, "fit_transform"):
984+
method_mapping.add(caller="fit_transform", callee="fit_transform")
985+
else:
986+
method_mapping.add(caller="fit", callee="fit").add(
987+
caller="fit", callee="transform"
988+
)
989+
(
990+
method_mapping.add(caller="fit", callee="fit")
991+
.add(caller="predict", callee="predict")
992+
.add(caller="fit_predict", callee="fit_predict")
993+
.add(caller="predict_proba", callee="predict_proba")
994+
.add(caller="decision_function", callee="decision_function")
995+
.add(caller="predict_log_proba", callee="predict_log_proba")
996+
.add(caller="transform", callee="transform")
997+
.add(caller="inverse_transform", callee="inverse_transform")
998+
.add(caller="score", callee="score")
999+
)
1000+
1001+
router.add(method_mapping=method_mapping, **{final_name: final_est})
1002+
return router

0 commit comments

Comments
 (0)