
Commit f822054

Make include_cls_metadata default to False for everything except Frameworks (#573)
When a HyperparameterTuner is created using attach(), the tuner needs to pick an Estimator class to use. It looks for the following things (in this order):

1. whether the hyperparameters include class metadata (e.g. a module path like sagemaker.tensorflow.estimator.TensorFlow)
2. whether the image being used corresponds to one of our 1P estimators
3. if neither is present, it falls back to the generic Estimator class

This change helps people who are using the generic Estimator class with hyperparameter tuning.
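For illustration, a minimal sketch of this fallback from the user's side, attaching a tuning job that was started from the generic Estimator class ('my-tuning-job' is a placeholder name, and the printed type assumes steps 1 and 2 above find nothing):

    from sagemaker.tuner import HyperparameterTuner

    # 'my-tuning-job' is a placeholder. With this change, jobs started from the
    # generic Estimator carry no class metadata, so attach() falls through
    # steps 1 and 2 above and reconstructs a generic Estimator.
    attached_tuner = HyperparameterTuner.attach('my-tuning-job')
    print(type(attached_tuner.estimator))  # <class 'sagemaker.estimator.Estimator'>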
1 parent 0f6e75f commit f822054

4 files changed: +28 −37 lines

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ CHANGELOG
 * doc-fix: Add information on necessary AWS permissions
 * bug-fix: Remove PyYAML to let docker-compose install the right version
 * enhancement: Add Model.transformer()
+* bug-fix: HyperparameterTuner: make ``include_cls_metadata`` default to ``False`` for everything except Frameworks

 1.16.3
 ======

README.rst

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -654,15 +654,9 @@ In addition, the ``fit()`` call uses a list of ``RecordSet`` objects instead of
654654
my_tuner.fit([train_records, test_records])
655655
656656
To help attach a previously-started hyperparameter tuning job to a ``HyperparameterTuner`` instance,
657-
``fit()`` adds the module path of the class used to create the tuner to the list of static hyperparameters by default.
658-
If the algorithm you are using cannot handle unknown hyperparameters
659-
(for example, an Amazon SageMaker built-in algorithm that does not have a custom estimator in the Python SDK),
660-
set ``include_cls_metadata`` to ``False`` when you call ``fit``, so that it does not add the module path as a static hyperparameter:
661-
662-
.. code:: python
663-
664-
my_tuner.fit({'train': 's3://my_bucket/my_training_data', 'test': 's3://my_bucket_my_testing_data'},
665-
include_cls_metadata=False)
657+
``fit()`` adds the module path of the class used to create the hyperparameter tuner to the list of static hyperparameters by default.
658+
If you are using your own custom estimator class (i.e. not one provided in this SDK) and want that class to be used when attaching a hyperparamter tuning job,
659+
set ``include_cls_metadata`` to ``True`` when you call ``fit`` to add the module path as static hyperparameters.
666660
667661
There is also an analytics object associated with each ``HyperparameterTuner`` instance that contains useful information about the hyperparameter tuning job.
668662
For example, the ``dataframe`` method gets a pandas dataframe summarizing the associated training jobs:
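For the README's new opt-in flow, a hedged usage sketch with a generic estimator (the image URI, role, metric regex, and S3 paths are placeholders, not values from this commit):

    from sagemaker.estimator import Estimator
    from sagemaker.parameter import ContinuousParameter
    from sagemaker.tuner import HyperparameterTuner

    # A generic (non-Framework) estimator; image, role, and instance settings are placeholders.
    estimator = Estimator(image_name='123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest',
                          role='SageMakerRole',
                          train_instance_count=1,
                          train_instance_type='ml.m4.xlarge')

    tuner = HyperparameterTuner(estimator=estimator,
                                objective_metric_name='validation:accuracy',
                                hyperparameter_ranges={'learning_rate': ContinuousParameter(0.01, 0.2)},
                                metric_definitions=[{'Name': 'validation:accuracy',
                                                     'Regex': 'accuracy = ([0-9\\.]+)'}],
                                max_jobs=4,
                                max_parallel_jobs=2)

    # Opt in so that attach() can later reconstruct this exact estimator class.
    tuner.fit({'train': 's3://my_bucket/my_training_data',
               'test': 's3://my_bucket/my_testing_data'},
              include_cls_metadata=True)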

src/sagemaker/tuner.py

Lines changed: 12 additions & 10 deletions
@@ -18,7 +18,7 @@
 from enum import Enum

 import sagemaker
-from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, RecordSet
+from sagemaker.amazon.amazon_estimator import RecordSet
 from sagemaker.amazon.hyperparameter import Hyperparameter as hp  # noqa
 from sagemaker.analytics import HyperparameterTuningJobAnalytics
 from sagemaker.estimator import Framework
@@ -219,7 +219,7 @@ def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metr
         self.warm_start_config = warm_start_config
         self.early_stopping_type = early_stopping_type

-    def _prepare_for_training(self, job_name=None, include_cls_metadata=True):
+    def _prepare_for_training(self, job_name=None, include_cls_metadata=False):
         if job_name is not None:
             self._current_job_name = job_name
         else:
@@ -230,14 +230,14 @@ def _prepare_for_training(self, job_name=None, include_cls_metadata=True):
         for hyperparameter_name in self._hyperparameter_ranges.keys():
             self.static_hyperparameters.pop(hyperparameter_name, None)

-        # For attach() to know what estimator to use for non-1P algorithms
-        # (1P algorithms don't accept extra hyperparameters)
-        if include_cls_metadata and not isinstance(self.estimator, AmazonAlgorithmEstimatorBase):
+        # For attach() to know what estimator to use for frameworks
+        # (other algorithms may not accept extra hyperparameters)
+        if include_cls_metadata or isinstance(self.estimator, Framework):
             self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_CLASS_NAME] = json.dumps(
                 self.estimator.__class__.__name__)
             self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_MODULE] = json.dumps(self.estimator.__module__)

-    def fit(self, inputs=None, job_name=None, include_cls_metadata=True, **kwargs):
+    def fit(self, inputs=None, job_name=None, include_cls_metadata=False, **kwargs):
         """Start a hyperparameter tuning job.

         Args:
@@ -260,10 +260,12 @@ def fit(self, inputs=None, job_name=None, include_cls_metadata=True, **kwargs):

             job_name (str): Tuning job name. If not specified, the tuner generates a default job name,
                 based on the training image name and current timestamp.
-            include_cls_metadata (bool): Whether or not the hyperparameter tuning job should include information about
-                the estimator class (default: True). This information is passed as a hyperparameter, so if
-                the algorithm you are using cannot handle unknown hyperparameters (e.g. an Amazon ML algorithm that
-                does not have a custom estimator in the Python SDK), then set ``include_cls_metadata`` to ``False``.
+            include_cls_metadata (bool): Whether or not the hyperparameter tuning job should include
+                information about the estimator class (default: False). This information is passed
+                as a hyperparameter, so if the algorithm you are using cannot handle
+                unknown hyperparameters (e.g. an Amazon SageMaker built-in algorithm that
+                does not have a custom estimator in the Python SDK), then set
+                ``include_cls_metadata`` to ``False``.
             **kwargs: Other arguments needed for training. Please refer to the ``fit()`` method of the associated
                 estimator to see what other arguments are needed.
         """

tests/unit/test_tuner.py

Lines changed: 12 additions & 18 deletions
@@ -13,22 +13,21 @@
 from __future__ import absolute_import

 import copy
-import json
-
 import os
+
 import pytest
 from mock import Mock

 from sagemaker import RealTimePredictor
-from sagemaker.amazon.pca import PCA
 from sagemaker.amazon.amazon_estimator import RecordSet
+from sagemaker.amazon.pca import PCA
 from sagemaker.estimator import Estimator
-from sagemaker.parameter import (ParameterRange, ContinuousParameter,
-                                 IntegerParameter, CategoricalParameter)
-from sagemaker.tuner import (HyperparameterTuner, _TuningJob, WarmStartConfig, WarmStartTypes,
-                             create_identical_dataset_and_algorithm_tuner,
-                             create_transfer_learning_tuner)
 from sagemaker.mxnet import MXNet
+from sagemaker.parameter import (CategoricalParameter, ContinuousParameter,
+                                 IntegerParameter, ParameterRange)
+from sagemaker.tuner import (_TuningJob, create_identical_dataset_and_algorithm_tuner,
+                             create_transfer_learning_tuner, HyperparameterTuner, WarmStartConfig,
+                             WarmStartTypes)

 DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
 MODEL_DATA = "s3://bucket/model.tar.gz"
@@ -157,14 +156,9 @@ def test_prepare_for_training(tuner):

     assert tuner._current_job_name.startswith(IMAGE_NAME)

-    assert len(tuner.static_hyperparameters) == 3
+    assert len(tuner.static_hyperparameters) == 1
     assert tuner.static_hyperparameters['another_one'] == '0'

-    class_name = json.dumps(tuner.estimator.__class__.__name__)
-    assert tuner.static_hyperparameters['sagemaker_estimator_class_name'] == class_name
-    module = json.dumps(tuner.estimator.__module__)
-    assert tuner.static_hyperparameters['sagemaker_estimator_module'] == module
-

 def test_prepare_for_training_with_amazon_estimator(tuner, sagemaker_session):
     tuner.estimator = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS,
@@ -175,10 +169,10 @@ def test_prepare_for_training_with_amazon_estimator(tuner, sagemaker_session):
     assert 'sagemaker_estimator_module' not in tuner.static_hyperparameters


-def test_prepare_for_training_dont_include_estimator_cls(tuner):
-    tuner._prepare_for_training(include_cls_metadata=False)
-    assert 'sagemaker_estimator_class_name' not in tuner.static_hyperparameters
-    assert 'sagemaker_estimator_module' not in tuner.static_hyperparameters
+def test_prepare_for_training_include_estimator_cls(tuner):
+    tuner._prepare_for_training(include_cls_metadata=True)
+    assert 'sagemaker_estimator_class_name' in tuner.static_hyperparameters
+    assert 'sagemaker_estimator_module' in tuner.static_hyperparameters


 def test_prepare_for_training_with_job_name(tuner):
