|
1 | 1 | """Defines a pipeline that is exposed to the user, accessible via pipeline."""
|
2 | 2 |
|
| 3 | +# pylint: disable=too-many-lines |
| 4 | + |
3 | 5 | from __future__ import annotations
|
4 | 6 |
|
| 7 | +from copy import deepcopy |
5 | 8 | from typing import Any, Iterable, List, Literal, Optional, Tuple, TypeVar, Union
|
6 | 9 |
|
7 | 10 | try:
|
|
13 | 16 | import numpy as np
|
14 | 17 | import numpy.typing as npt
|
15 | 18 | from loguru import logger
|
16 |
| -from sklearn.base import _fit_context # pylint: disable=protected-access |
17 |
| -from sklearn.base import clone |
| 19 | +from sklearn.base import _fit_context, clone |
18 | 20 | from sklearn.pipeline import Pipeline as _Pipeline
|
19 | 21 | from sklearn.pipeline import _final_estimator_has, _fit_transform_one
|
20 | 22 | from sklearn.utils import Bunch
|
| 23 | +from sklearn.utils._tags import Tags, get_tags |
21 | 24 | from sklearn.utils.metadata_routing import (
|
22 |
| - _routing_enabled, # pylint: disable=protected-access |
23 |
| -) |
24 |
| -from sklearn.utils.metadata_routing import ( |
| 25 | + MetadataRouter, |
| 26 | + MethodMapping, |
| 27 | + _routing_enabled, |
25 | 28 | process_routing,
|
26 | 29 | )
|
27 | 30 | from sklearn.utils.metaestimators import available_if
|
@@ -213,8 +216,8 @@ def _final_estimator(
|
213 | 216 | # pylint: disable=too-many-locals,too-many-branches
|
214 | 217 | def _fit(
|
215 | 218 | self,
|
216 |
| - X: Any, # pylint: disable=invalid-name |
217 |
| - y: Any = None, # pylint: disable=invalid-name |
| 219 | + X: Any, |
| 220 | + y: Any = None, |
218 | 221 | routed_params: dict[str, Any] | None = None,
|
219 | 222 | raw_params: dict[str, Any] | None = None,
|
220 | 223 | ) -> tuple[Any, Any]:
|
@@ -482,7 +485,9 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self:
|
482 | 485 | "All input rows were filtered out! Model is not fitted!"
|
483 | 486 | )
|
484 | 487 | else:
|
485 |
| - fit_params_last_step = routed_params[self.steps[-1][0]] |
| 488 | + fit_params_last_step = routed_params[ |
| 489 | + self._non_post_processing_steps()[-1][0] |
| 490 | + ] |
486 | 491 | self._final_estimator.fit(Xt, yt, **fit_params_last_step["fit"])
|
487 | 492 |
|
488 | 493 | return self
|
@@ -552,7 +557,9 @@ def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any:
|
552 | 557 | elif is_empty(iter_input):
|
553 | 558 | logger.warning("All input rows were filtered out! Model is not fitted!")
|
554 | 559 | else:
|
555 |
| - last_step_params = routed_params[self.steps[-1][0]] |
| 560 | + last_step_params = routed_params[ |
| 561 | + self._non_post_processing_steps()[-1][0] |
| 562 | + ] |
556 | 563 | if hasattr(last_step, "fit_transform"):
|
557 | 564 | iter_input = last_step.fit_transform(
|
558 | 565 | iter_input, iter_label, **last_step_params["fit_transform"]
|
@@ -615,7 +622,8 @@ def predict(self, X: Any, **params: Any) -> Any:
|
615 | 622 | elif hasattr(self._final_estimator, "predict"):
|
616 | 623 | if _routing_enabled():
|
617 | 624 | iter_input = self._final_estimator.predict(
|
618 |
| - iter_input, **routed_params[self._final_estimator].predict |
| 625 | + iter_input, |
| 626 | + **routed_params[self._non_post_processing_steps()[-1][0]].predict, |
619 | 627 | )
|
620 | 628 | else:
|
621 | 629 | iter_input = self._final_estimator.predict(iter_input, **params)
|
@@ -661,11 +669,9 @@ def fit_predict(self, X: Any, y: Any = None, **params: Any) -> Any:
|
661 | 669 | Result of calling `fit_predict` on the final estimator.
|
662 | 670 | """
|
663 | 671 | routed_params = self._check_method_params(method="fit_predict", props=params)
|
664 |
| - iter_input, iter_label = self._fit( |
665 |
| - X, y, routed_params |
666 |
| - ) # pylint: disable=invalid-name |
| 672 | + iter_input, iter_label = self._fit(X, y, routed_params) |
667 | 673 |
|
668 |
| - params_last_step = routed_params[self.steps[-1][0]] |
| 674 | + params_last_step = routed_params[self._non_post_processing_steps()[-1][0]] |
669 | 675 | with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
|
670 | 676 | if self._final_estimator == "passthrough":
|
671 | 677 | y_pred = iter_input
|
@@ -724,7 +730,10 @@ def predict_proba(self, X: Any, **params: Any) -> Any:
|
724 | 730 | elif hasattr(self._final_estimator, "predict_proba"):
|
725 | 731 | if _routing_enabled():
|
726 | 732 | iter_input = self._final_estimator.predict_proba(
|
727 |
| - iter_input, **routed_params[self.steps[-1][0]].predict_proba |
| 733 | + iter_input, |
| 734 | + **routed_params[ |
| 735 | + self._non_post_processing_steps()[-1][0] |
| 736 | + ].predict_proba, |
728 | 737 | )
|
729 | 738 | else:
|
730 | 739 | iter_input = self._final_estimator.predict_proba(iter_input, **params)
|
@@ -854,3 +863,140 @@ def classes_(self) -> list[Any] | npt.NDArray[Any]:
|
854 | 863 | if hasattr(last_step, "classes_"):
|
855 | 864 | return last_step.classes_
|
856 | 865 | raise ValueError("Last step has no classes_ attribute.")
|
| 866 | + |
def __sklearn_tags__(self) -> Tags:
    """Assemble the sklearn tags describing this pipeline.

    The ``pairwise`` input tag is taken from the first step, the ``sparse``
    input tag is the conjunction of the ``sparse`` tags of all steps that are
    not ``"passthrough"``, and the estimator type together with the target,
    classifier, regressor and transformer tags are taken from the final
    estimator.

    Note
    ----
    This method is copied from the original sklearn implementation.
    Changes are marked with a comment.

    Returns
    -------
    Tags
        The sklearn tags.
    """
    tags = super().__sklearn_tags__()

    # Without any steps there is nothing to derive further tags from.
    if not self.steps:
        return tags

    try:
        first_step = self.steps[0][1]
        if first_step is not None and first_step != "passthrough":
            tags.input_tags.pairwise = get_tags(first_step).input_tags.pairwise
        # WARNING: the sparse tag can be incorrect.
        # Some Pipelines accepting sparse data are wrongly tagged sparse=False.
        # For example Pipeline([PCA(), estimator]) accepts sparse data
        # even if the estimator doesn't as PCA outputs a dense array.
        accepts_sparse = True
        for _, step in self.steps:
            if step != "passthrough" and not get_tags(step).input_tags.sparse:
                # One step without sparse support settles the answer.
                accepts_sparse = False
                break
        tags.input_tags.sparse = accepts_sparse
    except (ValueError, AttributeError, TypeError):
        # This happens when the `steps` is not a list of (name, estimator)
        # tuples and `fit` is not called yet to validate the steps.
        pass

    try:
        # Deviation from the original sklearn implementation: the final
        # estimator is obtained via ``self._final_estimator``.
        final_estimator = self._final_estimator
        if final_estimator is not None and final_estimator != "passthrough":
            final_tags = get_tags(final_estimator)
            tags.estimator_type = final_tags.estimator_type
            tags.target_tags.multi_output = final_tags.target_tags.multi_output
            # Deep copies keep the pipeline's tags independent of the tag
            # objects returned for the final estimator.
            tags.classifier_tags = deepcopy(final_tags.classifier_tags)
            tags.regressor_tags = deepcopy(final_tags.regressor_tags)
            tags.transformer_tags = deepcopy(final_tags.transformer_tags)
    except (ValueError, AttributeError, TypeError):
        # This happens when the `steps` is not a list of (name, estimator)
        # tuples and `fit` is not called yet to validate the steps.
        pass

    return tags
| 922 | + |
def get_metadata_routing(self) -> MetadataRouter:
    """Get metadata routing of this object.

    Please check :ref:`User Guide <metadata_routing>` on how the routing
    mechanism works.

    Every step yielded by ``self._iter(with_final=False, ...)`` is registered
    with a mapping that routes the pipeline's methods to the step's
    ``fit``/``transform``-style methods; the final estimator is registered
    with a mapping that routes each pipeline method to the method of the
    same name.

    Note
    ----
    This method is copied from the original sklearn implementation.
    Changes are marked with a comment.

    Returns
    -------
    MetadataRouter
        A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
        routing information.
    """
    # Route tables as (caller, callee) pairs, applied in the listed order.

    # fit, fit_predict, and fit_transform call fit_transform if it
    # exists, or else fit and transform
    routes_with_fit_transform = (
        ("fit", "fit_transform"),
        ("fit_transform", "fit_transform"),
        ("fit_predict", "fit_transform"),
    )
    routes_without_fit_transform = (
        ("fit", "fit"),
        ("fit", "transform"),
        ("fit_transform", "fit"),
        ("fit_transform", "transform"),
        ("fit_predict", "fit"),
        ("fit_predict", "transform"),
    )
    # Routes shared by every non-final step.
    shared_step_routes = (
        ("predict", "transform"),
        # Registered twice on purpose: this mirrors the copied implementation.
        ("predict", "transform"),
        ("predict_proba", "transform"),
        ("decision_function", "transform"),
        ("predict_log_proba", "transform"),
        ("transform", "transform"),
        ("inverse_transform", "inverse_transform"),
        ("score", "transform"),
    )
    # Routes registered for the final estimator regardless of its methods.
    shared_final_routes = (
        ("fit", "fit"),
        ("predict", "predict"),
        ("fit_predict", "fit_predict"),
        ("predict_proba", "predict_proba"),
        ("decision_function", "decision_function"),
        ("predict_log_proba", "predict_log_proba"),
        ("transform", "transform"),
        ("inverse_transform", "inverse_transform"),
        ("score", "score"),
    )

    router = MetadataRouter(owner=self.__class__.__name__)

    # first we add all steps except the last one
    # NOTE(review): `_iter` is defined outside this view - confirm that
    # `with_final=False` skips the same estimator that is selected through
    # `_non_post_processing_steps()` below.
    for _, name, trans in self._iter(with_final=False, filter_passthrough=True):
        step_mapping = MethodMapping()
        if hasattr(trans, "fit_transform"):
            fit_routes = routes_with_fit_transform
        else:
            fit_routes = routes_without_fit_transform
        for caller, callee in (*fit_routes, *shared_step_routes):
            step_mapping.add(caller=caller, callee=callee)
        router.add(method_mapping=step_mapping, **{name: trans})

    # Deviation from the original sklearn implementation: the final step is
    # looked up via ``self._non_post_processing_steps()``.
    final_name, final_est = self._non_post_processing_steps()[-1]
    if final_est is None or final_est == "passthrough":
        return router

    # then we add the last step
    final_mapping = MethodMapping()
    if hasattr(final_est, "fit_transform"):
        leading_routes = (("fit_transform", "fit_transform"),)
    else:
        leading_routes = (("fit", "fit"), ("fit", "transform"))
    for caller, callee in (*leading_routes, *shared_final_routes):
        final_mapping.add(caller=caller, callee=callee)

    router.add(method_mapping=final_mapping, **{final_name: final_est})
    return router
0 commit comments