Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure compatibility with sklearn and implement tags #124

Merged
merged 18 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions molpipeline/estimators/chemprop/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
import numpy.typing as npt
from loguru import logger
from sklearn.base import clone
from sklearn.utils._tags import (
ClassifierTags,
RegressorTags,
Tags,
)
from sklearn.utils.metaestimators import available_if

try:
Expand Down Expand Up @@ -124,6 +129,24 @@ def _is_classifier(self) -> bool:
"""
return self._is_binary_classifier() or self._is_multiclass_classifier()

def __sklearn_tags__(self) -> Tags:
    """Return the sklearn tags.

    Declares this estimator as a classifier or a regressor to sklearn's
    tag system, depending on the configured prediction task.

    Returns
    -------
    Tags
        The sklearn tags for the model.
    """
    tags = super().__sklearn_tags__()
    if self._is_classifier():
        # Binary or multiclass task (see _is_classifier above).
        tags.estimator_type = "classifier"
        tags.classifier_tags = ClassifierTags()
    else:
        tags.estimator_type = "regressor"
        tags.regressor_tags = RegressorTags()
    # NOTE(review): the scraped diff lost indentation, so it is ambiguous
    # whether this line belongs inside the `else` branch or at method level.
    # Placed at method level because a fit target is required for both
    # classifiers and regressors — confirm against the repository.
    tags.target_tags.required = True
    return tags

def _predict(
self, X: MoleculeDataset # pylint: disable=invalid-name
) -> npt.NDArray[np.float64]:
Expand Down
176 changes: 161 additions & 15 deletions molpipeline/pipeline/_skl_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Defines a pipeline is exposed to the user, accessible via pipeline."""

# pylint: disable=too-many-lines

from __future__ import annotations

from copy import deepcopy
from typing import Any, Iterable, List, Literal, Optional, Tuple, TypeVar, Union

try:
Expand All @@ -13,15 +16,15 @@
import numpy as np
import numpy.typing as npt
from loguru import logger
from sklearn.base import _fit_context # pylint: disable=protected-access
from sklearn.base import clone
from sklearn.base import _fit_context, clone
from sklearn.pipeline import Pipeline as _Pipeline
from sklearn.pipeline import _final_estimator_has, _fit_transform_one
from sklearn.utils import Bunch
from sklearn.utils._tags import Tags, get_tags
from sklearn.utils.metadata_routing import (
_routing_enabled, # pylint: disable=protected-access
)
from sklearn.utils.metadata_routing import (
MetadataRouter,
MethodMapping,
_routing_enabled,
process_routing,
)
from sklearn.utils.metaestimators import available_if
Expand Down Expand Up @@ -213,8 +216,8 @@ def _final_estimator(
# pylint: disable=too-many-locals,too-many-branches
def _fit(
self,
X: Any, # pylint: disable=invalid-name
y: Any = None, # pylint: disable=invalid-name
X: Any,
y: Any = None,
routed_params: dict[str, Any] | None = None,
raw_params: dict[str, Any] | None = None,
) -> tuple[Any, Any]:
Expand Down Expand Up @@ -482,7 +485,9 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self:
"All input rows were filtered out! Model is not fitted!"
)
else:
fit_params_last_step = routed_params[self.steps[-1][0]]
fit_params_last_step = routed_params[
self._non_post_processing_steps()[-1][0]
]
self._final_estimator.fit(Xt, yt, **fit_params_last_step["fit"])

return self
Expand Down Expand Up @@ -552,7 +557,9 @@ def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any:
elif is_empty(iter_input):
logger.warning("All input rows were filtered out! Model is not fitted!")
else:
last_step_params = routed_params[self.steps[-1][0]]
last_step_params = routed_params[
self._non_post_processing_steps()[-1][0]
]
if hasattr(last_step, "fit_transform"):
iter_input = last_step.fit_transform(
iter_input, iter_label, **last_step_params["fit_transform"]
Expand Down Expand Up @@ -615,7 +622,8 @@ def predict(self, X: Any, **params: Any) -> Any:
elif hasattr(self._final_estimator, "predict"):
if _routing_enabled():
iter_input = self._final_estimator.predict(
iter_input, **routed_params[self._final_estimator].predict
iter_input,
**routed_params[self._non_post_processing_steps()[-1][0]].predict,
)
else:
iter_input = self._final_estimator.predict(iter_input, **params)
Expand Down Expand Up @@ -661,11 +669,9 @@ def fit_predict(self, X: Any, y: Any = None, **params: Any) -> Any:
Result of calling `fit_predict` on the final estimator.
"""
routed_params = self._check_method_params(method="fit_predict", props=params)
iter_input, iter_label = self._fit(
X, y, routed_params
) # pylint: disable=invalid-name
iter_input, iter_label = self._fit(X, y, routed_params)

params_last_step = routed_params[self.steps[-1][0]]
params_last_step = routed_params[self._non_post_processing_steps()[-1][0]]
with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
if self._final_estimator == "passthrough":
y_pred = iter_input
Expand Down Expand Up @@ -724,7 +730,10 @@ def predict_proba(self, X: Any, **params: Any) -> Any:
elif hasattr(self._final_estimator, "predict_proba"):
if _routing_enabled():
iter_input = self._final_estimator.predict_proba(
iter_input, **routed_params[self.steps[-1][0]].predict_proba
iter_input,
**routed_params[
self._non_post_processing_steps()[-1][0]
].predict_proba,
)
else:
iter_input = self._final_estimator.predict_proba(iter_input, **params)
Expand Down Expand Up @@ -854,3 +863,140 @@ def classes_(self) -> list[Any] | npt.NDArray[Any]:
if hasattr(last_step, "classes_"):
return last_step.classes_
raise ValueError("Last step has no classes_ attribute.")

def __sklearn_tags__(self) -> Tags:
    """Return the sklearn tags.

    Note
    ----
    This method is copied from the original sklearn implementation.
    Changes are marked with a comment.

    Returns
    -------
    Tags
        The sklearn tags.
    """
    tags = super().__sklearn_tags__()

    if not self.steps:
        # An empty pipeline exposes only the default tags.
        return tags

    try:
        if self.steps[0][1] is not None and self.steps[0][1] != "passthrough":
            # The pipeline accepts pairwise input iff its first step does.
            tags.input_tags.pairwise = get_tags(
                self.steps[0][1]
            ).input_tags.pairwise
        # WARNING: the sparse tag can be incorrect.
        # Some Pipelines accepting sparse data are wrongly tagged sparse=False.
        # For example Pipeline([PCA(), estimator]) accepts sparse data
        # even if the estimator doesn't as PCA outputs a dense array.
        tags.input_tags.sparse = all(
            get_tags(step).input_tags.sparse
            for name, step in self.steps
            if step != "passthrough"
        )
    except (ValueError, AttributeError, TypeError):
        # This happens when the `steps` is not a list of (name, estimator)
        # tuples and `fit` is not called yet to validate the steps.
        pass

    try:
        # Changed from the original sklearn implementation in the following
        # two lines: `self._final_estimator` is used instead of indexing the
        # raw `steps` list, so post-processing steps are skipped.
        if (
            self._final_estimator is not None
            and self._final_estimator != "passthrough"
        ):
            # Mirror the (non-post-processing) final estimator's tags on the
            # pipeline so sklearn treats it as the same kind of estimator.
            last_step_tags = get_tags(self._final_estimator)
            tags.estimator_type = last_step_tags.estimator_type
            tags.target_tags.multi_output = last_step_tags.target_tags.multi_output
            tags.classifier_tags = deepcopy(last_step_tags.classifier_tags)
            tags.regressor_tags = deepcopy(last_step_tags.regressor_tags)
            tags.transformer_tags = deepcopy(last_step_tags.transformer_tags)
    except (ValueError, AttributeError, TypeError):
        # This happens when the `steps` is not a list of (name, estimator)
        # tuples and `fit` is not called yet to validate the steps.
        pass

    return tags

def get_metadata_routing(self) -> MetadataRouter:
    """Get metadata routing of this object.

    Please check :ref:`User Guide <metadata_routing>` on how the routing
    mechanism works.

    Note
    ----
    This method is copied from the original sklearn implementation.
    Changes are marked with a comment.

    Returns
    -------
    MetadataRouter
        A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
        routing information.
    """
    router = MetadataRouter(owner=self.__class__.__name__)

    # first we add all steps except the last one
    for _, name, trans in self._iter(with_final=False, filter_passthrough=True):
        method_mapping = MethodMapping()
        # fit, fit_predict, and fit_transform call fit_transform if it
        # exists, or else fit and transform
        if hasattr(trans, "fit_transform"):
            (
                method_mapping.add(caller="fit", callee="fit_transform")
                .add(caller="fit_transform", callee="fit_transform")
                .add(caller="fit_predict", callee="fit_transform")
            )
        else:
            (
                method_mapping.add(caller="fit", callee="fit")
                .add(caller="fit", callee="transform")
                .add(caller="fit_transform", callee="fit")
                .add(caller="fit_transform", callee="transform")
                .add(caller="fit_predict", callee="fit")
                .add(caller="fit_predict", callee="transform")
            )

        # All prediction/scoring entry points route through this step's
        # `transform` on already-fitted data.
        (
            method_mapping.add(caller="predict", callee="transform")
            # NOTE(review): the duplicated predict->transform entry below is
            # present in the upstream sklearn implementation; kept for
            # faithfulness, it is harmless.
            .add(caller="predict", callee="transform")
            .add(caller="predict_proba", callee="transform")
            .add(caller="decision_function", callee="transform")
            .add(caller="predict_log_proba", callee="transform")
            .add(caller="transform", callee="transform")
            .add(caller="inverse_transform", callee="inverse_transform")
            .add(caller="score", callee="transform")
        )

        router.add(method_mapping=method_mapping, **{name: trans})

    # Changed from the original sklearn implementation in the following line:
    # the final estimator is taken from `_non_post_processing_steps()` rather
    # than `steps[-1]`, so trailing post-processing steps are skipped.
    final_name, final_est = self._non_post_processing_steps()[-1]
    if final_est is None or final_est == "passthrough":
        return router

    # then we add the last step
    method_mapping = MethodMapping()
    if hasattr(final_est, "fit_transform"):
        method_mapping.add(caller="fit_transform", callee="fit_transform")
    else:
        method_mapping.add(caller="fit", callee="fit").add(
            caller="fit", callee="transform"
        )
    # The final estimator receives each caller method on itself directly.
    (
        method_mapping.add(caller="fit", callee="fit")
        .add(caller="predict", callee="predict")
        .add(caller="fit_predict", callee="fit_predict")
        .add(caller="predict_proba", callee="predict_proba")
        .add(caller="decision_function", callee="decision_function")
        .add(caller="predict_log_proba", callee="predict_log_proba")
        .add(caller="transform", callee="transform")
        .add(caller="inverse_transform", callee="inverse_transform")
        .add(caller="score", callee="score")
    )

    router.add(method_mapping=method_mapping, **{final_name: final_est})
    return router
Loading