
Commit 9209452

Fix prediction fails with MOO ensemble and dummy is best (#1518)
* Init commit
* Fix DummyClassifiers in _load_pareto_set
* Add test for dummy only in classifiers
* Update no ensemble docstring
* Add automl case where automl only has dummy
* Remove tmp file
* Fix `include` statement to be regressor
1 parent 9914168 commit 9209452
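In short: with multi-objective optimization the pareto set is materialized as VotingClassifier/VotingRegressor instances, and prediction failed whenever one of their members was a bare dummy model instead of a full auto-sklearn pipeline. The sketch below is a hedged reproduction, not taken from this PR; it assumes auto-sklearn's public multi-objective API (a list-valued `metric` and `get_pareto_set()`), and the failure only shows up when the dummy baseline ends up on the pareto front.

# Hedged reproduction sketch; not part of the commit. Assumes the public
# multi-objective API: list-valued `metric` and `get_pareto_set()`.
import sklearn.datasets
import sklearn.model_selection

import autosklearn.classification
import autosklearn.metrics

X, y = sklearn.datasets.load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
    X, y, random_state=1
)

automl = autosklearn.classification.AutoSklearnClassifier(
    time_left_for_this_task=60,
    metric=[autosklearn.metrics.accuracy, autosklearn.metrics.balanced_accuracy],
)
automl.fit(X_train, y_train)

# Before this commit, using the pareto-set voters could raise when the only
# surviving member was the dummy model.
for voter in automl.get_pareto_set():
    voter.predict(X_test)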


2 files changed: +131 -16


autosklearn/automl.py

+32 -6
@@ -48,6 +48,7 @@
     BaseShuffleSplit,
     _RepeatedSplits,
 )
+from sklearn.pipeline import Pipeline
 from sklearn.utils import check_random_state
 from sklearn.utils.validation import check_is_fitted
 from smac.callbacks import IncorporateRunResultCallback
@@ -1473,6 +1474,7 @@ def predict(self, X, batch_size=None, n_jobs=1):
         # Each process computes predictions in chunks of batch_size rows.
         try:
             for i, tmp_model in enumerate(self.models_.values()):
+                # TODO, modify this
                 if isinstance(tmp_model, (DummyRegressor, DummyClassifier)):
                     check_is_fitted(tmp_model)
                 else:
@@ -1683,10 +1685,8 @@ def _load_best_individual_model(self):
         return ensemble
 
     def _load_pareto_set(self) -> Sequence[VotingClassifier | VotingRegressor]:
-        if self._ensemble_class is not None:
+        if self.ensemble_ is None:
             self.ensemble_ = self._backend.load_ensemble(self._seed)
-        else:
-            self.ensemble_ = None
 
         # If no ensemble is loaded we cannot do anything
         if not self.ensemble_:
@@ -1716,8 +1716,10 @@ def _load_pareto_set(self) -> Sequence[VotingClassifier | VotingRegressor]:
                 estimators=None,
                 voting="soft",
             )
+            kind = "classifier"
         else:
             voter = VotingRegressor(estimators=None)
+            kind = "regressor"
 
         if self._resampling_strategy in ("cv", "cv-iterative-fit"):
             models = self._backend.load_cv_models_by_identifiers(identifiers)
@@ -1730,8 +1732,32 @@ def _load_pareto_set(self) -> Sequence[VotingClassifier | VotingRegressor]:
         weight_vector = []
         estimators = []
         for identifier in identifiers:
-            weight_vector.append(weights[identifier])
-            estimators.append(models[identifier])
+            estimator = models[identifier]
+            weight = weights[identifier]
+
+            # Kind of hacky, really the dummy models should
+            # act like everything else does. Doing this is
+            # required so that the VotingClassifier/Regressor
+            # can use it as intended
+            if not isinstance(estimator, Pipeline):
+                if kind == "classifier":
+                    steps = [
+                        ("data_preprocessor", None),
+                        ("balancing", None),
+                        ("feature_preprocessor", None),
+                        (kind, estimator),
+                    ]
+                else:
+                    steps = [
+                        ("data_preprocessor", None),
+                        ("feature_preprocessor", None),
+                        (kind, estimator),
+                    ]
+
+                estimator = Pipeline(steps=steps)
+
+            weight_vector.append(weight)
+            estimators.append(estimator)
 
         voter.estimators = estimators
         voter.estimators_ = estimators
@@ -2148,7 +2174,7 @@ def show_models(self) -> dict[int, Any]:
 
         ensemble_dict = {}
 
-        if self._ensemble_class is not None:
+        if self._ensemble_class is None:
             warnings.warn(
                 "No models in the ensemble. Kindly provide an ensemble class."
            )
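The core of the automl.py change is the Pipeline wrapping in `_load_pareto_set`: dummy models are stored as bare estimators rather than auto-sklearn pipelines, so the voters built for the pareto set could not treat them like the other members. Below is a minimal, self-contained sketch of that idea in plain scikit-learn; it is not auto-sklearn's code. The step names mirror the diff, and "passthrough" is used where the diff uses None (scikit-learn skips both kinds of step).

# Minimal sketch of the wrapping idea, using plain scikit-learn only.
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.pipeline import Pipeline

rng = np.random.RandomState(0)
X = rng.random((20, 3))
y = np.array([0, 1] * 10)

# The dummy baseline is fitted directly, not through the pipeline machinery.
dummy = DummyClassifier(strategy="most_frequent").fit(X, y)

# Wrapping the already-fitted dummy gives it the same shape as every other
# ensemble member: no-op preprocessing steps plus a final "classifier" step.
wrapped = Pipeline(
    steps=[
        ("data_preprocessor", "passthrough"),
        ("balancing", "passthrough"),
        ("feature_preprocessor", "passthrough"),
        ("classifier", dummy),
    ]
)

# The wrapper predicts exactly like the bare dummy, but a VotingClassifier or
# VotingRegressor can now treat it like any other pipeline in the pareto set.
assert (wrapped.predict(X) == dummy.predict(X)).all()
assert np.allclose(wrapped.predict_proba(X), dummy.predict_proba(X))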

test/test_automl/cases.py

+99 -10
@@ -14,7 +14,6 @@
     {fitted} - If the automl case has been fitted
     {cv, holdout} - Whether explicitly cv or holdout was used
     {no_ensemble} - Fit with no ensemble size
-    {cached} - If the resulting case is then cached
     {multiobjective} - If the automl instance is multiobjective
 """
 from __future__ import annotations
@@ -24,17 +23,27 @@
 from pathlib import Path
 
 import numpy as np
+import sklearn.model_selection
 
 import autosklearn.metrics
 from autosklearn.automl import AutoMLClassifier, AutoMLRegressor
 from autosklearn.automl_common.common.utils.backend import Backend
+from autosklearn.evaluation.abstract_evaluator import (
+    MyDummyClassifier,
+    MyDummyRegressor,
+)
 
 from pytest_cases import case, parametrize
 
 from test.fixtures.backend import copy_backend
 from test.fixtures.caching import Cache
 
 
+def stop_at_first(smbo, run_info, result, time_left) -> bool:
+    """Used in some cases to enforce the only valid model is the dummy model"""
+    return False
+
+
 @case(tags=["classifier"])
 def case_classifier(
     tmp_dir: str,
@@ -60,7 +69,7 @@ def case_regressor(
 # ###################################
 # The following are fitted and cached
 # ###################################
-@case(tags=["classifier", "fitted", "holdout", "cached"])
+@case(tags=["classifier", "fitted", "holdout"])
 @parametrize("dataset", ["iris"])
 def case_classifier_fitted_holdout_iterative(
     dataset: str,
@@ -97,7 +106,7 @@ def case_classifier_fitted_holdout_iterative(
     return model
 
 
-@case(tags=["classifier", "fitted", "cv", "cached"])
+@case(tags=["classifier", "fitted", "cv"])
 @parametrize("dataset", ["iris"])
 def case_classifier_fitted_cv(
     make_cache: Callable[[str], Cache],
@@ -134,7 +143,7 @@ def case_classifier_fitted_cv(
     return model
 
 
-@case(tags=["classifier", "fitted", "holdout", "cached", "multiobjective"])
+@case(tags=["classifier", "fitted", "holdout", "multiobjective"])
 @parametrize("dataset", ["iris"])
 def case_classifier_fitted_holdout_multiobjective(
     dataset: str,
@@ -177,7 +186,7 @@ def case_classifier_fitted_holdout_multiobjective(
     return model
 
 
-@case(tags=["regressor", "fitted", "holdout", "cached"])
+@case(tags=["regressor", "fitted", "holdout"])
 @parametrize("dataset", ["boston"])
 def case_regressor_fitted_holdout(
     make_cache: Callable[[str], Cache],
@@ -212,7 +221,7 @@ def case_regressor_fitted_holdout(
     return model
 
 
-@case(tags=["regressor", "fitted", "cv", "cached"])
+@case(tags=["regressor", "fitted", "cv"])
 @parametrize("dataset", ["boston"])
 def case_regressor_fitted_cv(
     make_cache: Callable[[str], Cache],
@@ -249,7 +258,7 @@ def case_regressor_fitted_cv(
     return model
 
 
-@case(tags=["classifier", "fitted", "no_ensemble", "cached"])
+@case(tags=["classifier", "fitted", "no_ensemble"])
 @parametrize("dataset", ["iris"])
 def case_classifier_fitted_no_ensemble(
     make_cache: Callable[[str], Cache],
@@ -258,8 +267,7 @@ def case_classifier_fitted_no_ensemble(
     make_automl_classifier: Callable[..., AutoMLClassifier],
     make_sklearn_dataset: Callable[..., Tuple[np.ndarray, ...]],
 ) -> AutoMLClassifier:
-    """Case of a fitted classifier but ensemble was disabled by
-    not writing models to disk"""
+    """Case of a fitted classifier but ensemble was disabled"""
     key = f"case_classifier_fitted_no_ensemble_{dataset}"
 
     # This locks the cache for this item while we check, required for pytest-xdist
@@ -270,7 +278,6 @@ def case_classifier_fitted_no_ensemble(
                 temporary_directory=cache.path("backend"),
                 delete_tmp_folder_after_terminate=False,
                 ensemble_class=None,
-                disable_evaluator_output=True,
             )
 
             X, y, Xt, yt = make_sklearn_dataset(name=dataset)
@@ -282,3 +289,85 @@
     model._backend = copy_backend(old=model._backend, new=make_backend())
 
     return model
+
+
+@case(tags=["classifier", "fitted"])
+def case_classifier_fitted_only_dummy(
+    make_cache: Callable[[str], Cache],
+    make_backend: Callable[..., Backend],
+    make_automl_classifier: Callable[..., AutoMLClassifier],
+) -> AutoMLClassifier:
+    """Case of a fitted classifier but only dummy was found"""
+    key = "case_classifier_fitted_only_dummy"
+
+    # This locks the cache for this item while we check, required for pytest-xdist
+
+    with make_cache(key) as cache:
+        if "model" not in cache:
+            model = make_automl_classifier(
+                temporary_directory=cache.path("backend"),
+                delete_tmp_folder_after_terminate=False,
+                include={"classifier": ["bernoulli_nb"]},  # Just a meh model
+                get_trials_callback=stop_at_first,
+            )
+            rand = np.random.RandomState(2)
+            _X = rand.random((100, 50))
+            _y = rand.randint(0, 2, (100,))
+            X, Xt, y, yt = sklearn.model_selection.train_test_split(
+                _X, _y, random_state=1  # Required to ensure dummy is best
+            )
+            model.fit(X, y, dataset_name="random")
+
+            # We now validate that indeed, the only model is the Dummy
+            members = list(model.models_.values())
+            if len(members) != 1 or not isinstance(members[0], MyDummyClassifier):
+                raise ValueError("Should only have one model, dummy\n", members)
+
+            cache.save(model, "model")
+
+    model = cache.load("model")
+    model._backend = copy_backend(old=model._backend, new=make_backend())
+
+    return model
+
+
+@case(tags=["regressor", "fitted"])
+def case_regressor_fitted_only_dummy(
+    make_cache: Callable[[str], Cache],
+    make_backend: Callable[..., Backend],
+    make_automl_regressor: Callable[..., AutoMLRegressor],
+) -> AutoMLRegressor:
+    """Case of a fitted regressor but only dummy was found"""
+    key = "case_regressor_fitted_only_dummy"
+
+    # This locks the cache for this item while we check, required for pytest-xdist
+
+    with make_cache(key) as cache:
+        if "model" not in cache:
+            model = make_automl_regressor(
+                temporary_directory=cache.path("backend"),
+                delete_tmp_folder_after_terminate=False,
+                include={"regressor": ["k_nearest_neighbors"]},  # Just a meh model
+                get_trials_callback=stop_at_first,
+            )
+
+            rand = np.random.RandomState(2)
+            _X = rand.random((100, 50))
+            _y = rand.random((100,))
+
+            X, Xt, y, yt = sklearn.model_selection.train_test_split(
+                _X, _y, random_state=1  # Required to ensure dummy is best
+            )
+            model.fit(X, y, dataset_name="random")
+
+            # We now validate that indeed, the only model is the Dummy
+            members = list(model.models_.values())
+            if len(members) != 1 or not isinstance(members[0], MyDummyRegressor):
+                raise ValueError("Should only have one model, dummy\n", members)
+
+            cache.save(model, "model")
+
+    model = cache.load("model")
+    model._backend = copy_backend(old=model._backend, new=make_backend())
+
+    return model
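For context, this is roughly how such case functions are consumed with pytest-cases; the sketch below is illustrative and is not the test added by this PR. It assumes the cases module is importable as `test.test_automl.cases` (mirroring the `test.fixtures` imports above) and relies only on the `models_` attribute already used in this diff.

# Illustrative sketch; assumes the import path and tag filter described above.
from pytest_cases import parametrize_with_cases

from autosklearn.automl import AutoMLClassifier

from test.test_automl import cases


@parametrize_with_cases("automl", cases=cases, has_tag=["classifier", "fitted"])
def test_fitted_classifier_has_models(automl: AutoMLClassifier) -> None:
    # Every fitted case, including case_classifier_fitted_only_dummy where the
    # dummy is the sole member, should expose at least one model.
    assert len(automl.models_) >= 1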
