Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure compatibility with sklearn and implement tags #124

Merged
merged 18 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 143 additions & 5 deletions molpipeline/pipeline/_skl_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from copy import deepcopy
from typing import Any, Iterable, List, Literal, Optional, Tuple, TypeVar, Union

try:
Expand All @@ -18,10 +19,13 @@
from sklearn.pipeline import Pipeline as _Pipeline
from sklearn.pipeline import _final_estimator_has, _fit_transform_one
from sklearn.utils import Bunch
from sklearn.utils._tags import Tags, get_tags # pylint: disable=protected-access
from sklearn.utils.metadata_routing import (
_routing_enabled, # pylint: disable=protected-access
)
from sklearn.utils.metadata_routing import (
MetadataRouter,
MethodMapping,
process_routing,
)
from sklearn.utils.metaestimators import available_if
Expand Down Expand Up @@ -482,7 +486,9 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self:
"All input rows were filtered out! Model is not fitted!"
)
else:
fit_params_last_step = routed_params[self.steps[-1][0]]
fit_params_last_step = routed_params[
self._non_post_processing_steps()[-1][0]
]
self._final_estimator.fit(Xt, yt, **fit_params_last_step["fit"])

return self
Expand Down Expand Up @@ -552,7 +558,9 @@ def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any:
elif is_empty(iter_input):
logger.warning("All input rows were filtered out! Model is not fitted!")
else:
last_step_params = routed_params[self.steps[-1][0]]
last_step_params = routed_params[
self._non_post_processing_steps()[-1][0]
]
if hasattr(last_step, "fit_transform"):
iter_input = last_step.fit_transform(
iter_input, iter_label, **last_step_params["fit_transform"]
Expand Down Expand Up @@ -615,7 +623,8 @@ def predict(self, X: Any, **params: Any) -> Any:
elif hasattr(self._final_estimator, "predict"):
if _routing_enabled():
iter_input = self._final_estimator.predict(
iter_input, **routed_params[self._final_estimator].predict
iter_input,
**routed_params[self._non_post_processing_steps()[-1][0]].predict,
)
else:
iter_input = self._final_estimator.predict(iter_input, **params)
Expand Down Expand Up @@ -665,7 +674,7 @@ def fit_predict(self, X: Any, y: Any = None, **params: Any) -> Any:
X, y, routed_params
) # pylint: disable=invalid-name

params_last_step = routed_params[self.steps[-1][0]]
params_last_step = routed_params[self._non_post_processing_steps()[-1][0]]
with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
if self._final_estimator == "passthrough":
y_pred = iter_input
Expand Down Expand Up @@ -724,7 +733,10 @@ def predict_proba(self, X: Any, **params: Any) -> Any:
elif hasattr(self._final_estimator, "predict_proba"):
if _routing_enabled():
iter_input = self._final_estimator.predict_proba(
iter_input, **routed_params[self.steps[-1][0]].predict_proba
iter_input,
**routed_params[
self._non_post_processing_steps()[-1][0]
].predict_proba,
)
else:
iter_input = self._final_estimator.predict_proba(iter_input, **params)
Expand Down Expand Up @@ -854,3 +866,129 @@ def classes_(self) -> list[Any] | npt.NDArray[Any]:
if hasattr(last_step, "classes_"):
return last_step.classes_
raise ValueError("Last step has no classes_ attribute.")

def __sklearn_tags__(self) -> Tags:
"""Return the sklearn tags.

Returns
-------
Tags
The sklearn tags.
"""
tags = super().__sklearn_tags__()

if not self.steps:
return tags

try:
if self.steps[0][1] is not None and self.steps[0][1] != "passthrough":
tags.input_tags.pairwise = get_tags(
self.steps[0][1]
).input_tags.pairwise
# WARNING: the sparse tag can be incorrect.
# Some Pipelines accepting sparse data are wrongly tagged sparse=False.
# For example Pipeline([PCA(), estimator]) accepts sparse data
# even if the estimator doesn't as PCA outputs a dense array.
tags.input_tags.sparse = all(
get_tags(step).input_tags.sparse
for name, step in self.steps
if step != "passthrough"
)
except (ValueError, AttributeError, TypeError):
# This happens when the `steps` is not a list of (name, estimator)
# tuples and `fit` is not called yet to validate the steps.
pass

try:
# Only the _final_estimator is changed from the original implementation
if (
self._final_estimator is not None
and self._final_estimator != "passthrough"
):
last_step_tags = get_tags(self._final_estimator)
tags.estimator_type = last_step_tags.estimator_type
tags.target_tags.multi_output = last_step_tags.target_tags.multi_output
tags.classifier_tags = deepcopy(last_step_tags.classifier_tags)
tags.regressor_tags = deepcopy(last_step_tags.regressor_tags)
tags.transformer_tags = deepcopy(last_step_tags.transformer_tags)
except (ValueError, AttributeError, TypeError):
# This happens when the `steps` is not a list of (name, estimator)
# tuples and `fit` is not called yet to validate the steps.
pass

return tags

def get_metadata_routing(self) -> MetadataRouter:
"""Get metadata routing of this object.

Please check :ref:`User Guide <metadata_routing>` on how the routing
mechanism works.

Returns
-------
MetadataRouter
A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
routing information.
"""
router = MetadataRouter(owner=self.__class__.__name__)

# first we add all steps except the last one
for _, name, trans in self._iter(with_final=False, filter_passthrough=True):
method_mapping = MethodMapping()
# fit, fit_predict, and fit_transform call fit_transform if it
# exists, or else fit and transform
if hasattr(trans, "fit_transform"):
(
method_mapping.add(caller="fit", callee="fit_transform")
.add(caller="fit_transform", callee="fit_transform")
.add(caller="fit_predict", callee="fit_transform")
)
else:
(
method_mapping.add(caller="fit", callee="fit")
.add(caller="fit", callee="transform")
.add(caller="fit_transform", callee="fit")
.add(caller="fit_transform", callee="transform")
.add(caller="fit_predict", callee="fit")
.add(caller="fit_predict", callee="transform")
)

(
method_mapping.add(caller="predict", callee="transform")
.add(caller="predict", callee="transform")
.add(caller="predict_proba", callee="transform")
.add(caller="decision_function", callee="transform")
.add(caller="predict_log_proba", callee="transform")
.add(caller="transform", callee="transform")
.add(caller="inverse_transform", callee="inverse_transform")
.add(caller="score", callee="transform")
)

router.add(method_mapping=method_mapping, **{name: trans})

final_name, final_est = self._non_post_processing_steps()[-1]
if final_est is None or final_est == "passthrough":
return router

# then we add the last step
method_mapping = MethodMapping()
if hasattr(final_est, "fit_transform"):
method_mapping.add(caller="fit_transform", callee="fit_transform")
else:
method_mapping.add(caller="fit", callee="fit").add(
caller="fit", callee="transform"
)
(
method_mapping.add(caller="fit", callee="fit")
.add(caller="predict", callee="predict")
.add(caller="fit_predict", callee="fit_predict")
.add(caller="predict_proba", callee="predict_proba")
.add(caller="decision_function", callee="decision_function")
.add(caller="predict_log_proba", callee="predict_log_proba")
.add(caller="transform", callee="transform")
.add(caller="inverse_transform", callee="inverse_transform")
.add(caller="score", callee="score")
)

router.add(method_mapping=method_mapping, **{final_name: final_est})
return router
38 changes: 37 additions & 1 deletion tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
import pandas as pd
from joblib import Memory
from sklearn.base import BaseEstimator
from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

from molpipeline import ErrorFilter, Pipeline
from molpipeline import ErrorFilter, FilterReinserter, Pipeline, PostPredictionWrapper
from molpipeline.any2mol import AutoToMol, SmilesToMol
from molpipeline.mol2any import MolToMorganFP, MolToRDKitPhysChem, MolToSmiles
from molpipeline.mol2mol import (
Expand Down Expand Up @@ -375,5 +376,40 @@ def test_gridsearch_cache(self) -> None:
self.assertTrue(np.allclose(prediction_dict[True], prediction_dict[False]))


class PipelineCompatibilityTest(unittest.TestCase):
"""Test if the pipeline is compatible with other sklearn functionalities."""

def test_calibrated_classifier(self) -> None:
"""Test if the pipeline can be used with a CalibratedClassifierCV."""
smi2mol = SmilesToMol()
mol2morgan = MolToMorganFP(radius=FP_RADIUS, n_bits=FP_SIZE)
d_tree = DecisionTreeClassifier()
error_filter = ErrorFilter(filter_everything=True)
s_pipeline = Pipeline(
[
("smi2mol", smi2mol),
("morgan", mol2morgan),
("error_filter", error_filter),
("decision_tree", d_tree),
(
"error_replacer",
PostPredictionWrapper(
FilterReinserter.from_error_filter(error_filter, np.nan)
),
),
]
)
calibrated_pipeline = CalibratedClassifierCV(
s_pipeline, cv=2, ensemble=True, method="isotonic"
)
calibrated_pipeline.fit(TEST_SMILES, CONTAINS_OX)
predicted_value_array = calibrated_pipeline.predict(TEST_SMILES)
predicted_proba_array = calibrated_pipeline.predict_proba(TEST_SMILES)
self.assertIsInstance(predicted_value_array, np.ndarray)
self.assertIsInstance(predicted_proba_array, np.ndarray)
self.assertEqual(predicted_value_array.shape, (len(TEST_SMILES),))
self.assertEqual(predicted_proba_array.shape, (len(TEST_SMILES),))


if __name__ == "__main__":
unittest.main()
Loading