Skip to content

Commit ab1f758

Browse files
authored
Add update_endpoint support to session class and add optional flag to… (#606)
1 parent f1ba75e commit ab1f758

File tree

11 files changed

+191
-5
lines changed

11 files changed

+191
-5
lines changed

CHANGELOG.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ CHANGELOG
44

55

66
1.18.2.dev
7-
======
7+
==========
8+
89
* enhancement: Include SageMaker Notebook Instance version number in boto3 user agent, if available.
10+
* feature: Support for updating existing endpoint
911

1012
1.18.1
1113
======

README.rst

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,37 @@ Here is an end to end example of how to use a SageMaker Estimator:
192192
# Tears down the SageMaker endpoint
193193
mxnet_estimator.delete_endpoint()
194194
195+
196+
Additionally, it is possible to deploy a different endpoint configuration, which links to your model, to an already existing SageMaker endpoint.
197+
This can be done by specifying the existing endpoint name for the ``endpoint_name`` parameter along with the ``update_endpoint`` parameter as ``True`` within your ``deploy()`` call.
198+
For more `information <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.update_endpoint>`__.
199+
200+
.. code:: python
201+
202+
from sagemaker.mxnet import MXNet
203+
204+
# Configure an MXNet Estimator (no training happens yet)
205+
mxnet_estimator = MXNet('train.py',
206+
role='SageMakerRole',
207+
train_instance_type='ml.p2.xlarge',
208+
train_instance_count=1,
209+
framework_version='1.2.1')
210+
211+
# Starts a SageMaker training job and waits until completion.
212+
mxnet_estimator.fit('s3://my_bucket/my_training_data/')
213+
214+
# Deploys the model that was generated by fit() to an existing SageMaker endpoint
215+
mxnet_predictor = mxnet_estimator.deploy(initial_instance_count=1,
216+
instance_type='ml.p2.xlarge',
217+
update_endpoint=True,
218+
endpoint_name='existing-endpoint')
219+
220+
# Serializes data and makes a prediction request to the SageMaker endpoint
221+
response = mxnet_predictor.predict(data)
222+
223+
# Tears down the SageMaker endpoint
224+
mxnet_estimator.delete_endpoint()
225+
195226
Training Metrics
196227
~~~~~~~~~~~~~~~~
197228
The SageMaker Python SDK allows you to specify a name and a regular expression for metrics you want to track for training.

src/sagemaker/estimator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m
330330
return estimator
331331

332332
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None,
333-
use_compiled_model=False, **kwargs):
333+
use_compiled_model=False, update_endpoint=False, **kwargs):
334334
"""Deploy the trained model to an Amazon SageMaker endpoint and return a ``sagemaker.RealTimePredictor`` object.
335335
336336
More information:
@@ -347,6 +347,9 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
347347
endpoint_name (str): Name to use for creating an Amazon SageMaker endpoint. If not specified, the name of
348348
the training job is used.
349349
use_compiled_model (bool): Flag to select whether to use compiled (optimized) model. Default: False.
350+
update_endpoint (bool): Flag to update the model in an existing Amazon SageMaker endpoint.
351+
If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources
352+
corresponding to the previous EndpointConfig. Default: False
350353
**kwargs: Passed to invocation of ``create_model()``. Implementations may customize
351354
``create_model()`` to accept ``**kwargs`` to customize model creation during deploy.
352355
For more, see the implementation docs.
@@ -370,7 +373,8 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
370373
instance_type=instance_type,
371374
initial_instance_count=initial_instance_count,
372375
accelerator_type=accelerator_type,
373-
endpoint_name=endpoint_name)
376+
endpoint_name=endpoint_name,
377+
update_endpoint=update_endpoint)
374378

375379
@property
376380
def model_data(self):

src/sagemaker/local/local_session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ def create_endpoint(self, EndpointName, EndpointConfigName):
143143
LocalSagemakerClient._endpoints[EndpointName] = endpoint
144144
endpoint.serve()
145145

146+
def update_endpoint(self, EndpointName, EndpointConfigName): # pylint: disable=unused-argument
147+
raise NotImplementedError('Update endpoint name is not supported in local session.')
148+
146149
def delete_endpoint(self, EndpointName):
147150
if EndpointName in LocalSagemakerClient._endpoints:
148151
LocalSagemakerClient._endpoints[EndpointName].stop()

src/sagemaker/model.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ def compile(self, target_instance_family, input_shape, output_path, role,
195195
self._is_compiled_model = True
196196
return self
197197

198-
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None, tags=None):
198+
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None,
199+
update_endpoint=False, tags=None):
199200
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
200201
201202
Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an ``Endpoint`` from this ``Model``.
@@ -217,6 +218,9 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
217218
For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
218219
endpoint_name (str): The name of the endpoint to create (default: None).
219220
If not specified, a unique endpoint name will be created.
221+
update_endpoint (bool): Flag to update the model in an existing Amazon SageMaker endpoint.
222+
If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources
223+
corresponding to the previous EndpointConfig. If False, a new endpoint will be created. Default: False
220224
tags(List[dict[str, str]]): The list of tags to attach to this specific endpoint.
221225
222226
Returns:
@@ -245,7 +249,18 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
245249
self.endpoint_name = self.name
246250
if self._is_compiled_model and not self.endpoint_name.endswith(compiled_model_suffix):
247251
self.endpoint_name += compiled_model_suffix
248-
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags)
252+
253+
if update_endpoint:
254+
endpoint_config_name = self.sagemaker_session.create_endpoint_config(
255+
name=self.name,
256+
model_name=self.name,
257+
initial_instance_count=initial_instance_count,
258+
instance_type=instance_type,
259+
accelerator_type=accelerator_type)
260+
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
261+
else:
262+
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags)
263+
249264
if self.predictor_cls:
250265
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)
251266

src/sagemaker/session.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,26 @@ def create_endpoint(self, endpoint_name, config_name, wait=True):
750750
self.wait_for_endpoint(endpoint_name)
751751
return endpoint_name
752752

753+
def update_endpoint(self, endpoint_name, endpoint_config_name):
754+
""" Update an Amazon SageMaker ``Endpoint`` according to the endpoint configuration specified in the request
755+
756+
Raise an error if endpoint with endpoint_name does not exist.
757+
758+
Args:
759+
endpoint_name (str): Name of the Amazon SageMaker ``Endpoint`` to update.
760+
endpoint_config_name (str): Name of the Amazon SageMaker endpoint configuration to deploy.
761+
762+
Returns:
763+
str: Name of the Amazon SageMaker ``Endpoint`` being updated.
764+
"""
765+
if not _deployment_entity_exists(lambda: self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name)):
766+
raise ValueError('Endpoint with name "{}" does not exist; please use an existing endpoint name'
767+
.format(endpoint_name))
768+
769+
self.sagemaker_client.update_endpoint(EndpointName=endpoint_name,
770+
EndpointConfigName=endpoint_config_name)
771+
return endpoint_name
772+
753773
def delete_endpoint(self, endpoint_name):
754774
"""Delete an Amazon SageMaker ``Endpoint``.
755775

tests/integ/test_mxnet_train.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,48 @@ def test_deploy_model(mxnet_training_job, sagemaker_session):
7272
predictor.predict(data)
7373

7474

75+
def test_deploy_model_with_update_endpoint(mxnet_training_job, sagemaker_session):
76+
endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())
77+
78+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
79+
desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=mxnet_training_job)
80+
model_data = desc['ModelArtifacts']['S3ModelArtifacts']
81+
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
82+
model = MXNetModel(model_data, 'SageMakerRole', entry_point=script_path,
83+
py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session)
84+
model.deploy(1, 'ml.t2.medium', endpoint_name=endpoint_name)
85+
old_endpoint = sagemaker_session.describe_endpoint(EndpointName=endpoint_name)
86+
old_config_name = old_endpoint['EndpointConfigName']
87+
88+
model.deploy(1, 'ml.m4.xlarge', update_endpoint=True, endpoint_name=endpoint_name)
89+
new_endpoint = sagemaker_session.describe_endpoint(EndpointName=endpoint_name)['ProductionVariants']
90+
new_production_variants = new_endpoint['ProductionVariants']
91+
new_config_name = new_endpoint['EndpointConfigName']
92+
93+
assert old_config_name != new_config_name
94+
assert new_production_variants['InstanceType'] == 'ml.m4.xlarge'
95+
assert new_production_variants['InitialInstanceCount'] == 1
96+
assert new_production_variants['AcceleratorType'] is None
97+
98+
99+
def test_deploy_model_with_update_non_existing_endpoint(mxnet_training_job, sagemaker_session):
100+
endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())
101+
expected_error_message = 'Endpoint with name "{}" does not exist; ' \
102+
'please use an existing endpoint name'.format(endpoint_name)
103+
104+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
105+
desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=mxnet_training_job)
106+
model_data = desc['ModelArtifacts']['S3ModelArtifacts']
107+
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
108+
model = MXNetModel(model_data, 'SageMakerRole', entry_point=script_path,
109+
py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session)
110+
model.deploy(1, 'ml.t2.medium', endpoint_name=endpoint_name)
111+
sagemaker_session.describe_endpoint(EndpointName=endpoint_name)
112+
113+
with pytest.raises(ValueError, message=expected_error_message):
114+
model.deploy(1, 'ml.m4.xlarge', update_endpoint=True, endpoint_name='non-existing-endpoint')
115+
116+
75117
@pytest.mark.continuous_testing
76118
@pytest.mark.regional_testing
77119
@pytest.mark.skipif(tests.integ.test_region() not in tests.integ.EI_SUPPORTED_REGIONS,

tests/unit/test_estimator.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,21 @@ def test_generic_deploy_accelerator_type(sagemaker_session):
990990
assert args[1][0]['InstanceType'] == INSTANCE_TYPE
991991

992992

993+
def test_deploy_with_update_endpoint(sagemaker_session):
994+
estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH,
995+
sagemaker_session=sagemaker_session)
996+
estimator.set_hyperparameters(**HYPERPARAMS)
997+
estimator.fit({'train': 's3://bucket/training-prefix'})
998+
endpoint_name = 'endpoint-name'
999+
estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE, endpoint_name=endpoint_name, update_endpoint=True)
1000+
1001+
update_endpoint_args = sagemaker_session.update_endpoint.call_args[0]
1002+
assert update_endpoint_args[0] == endpoint_name
1003+
assert update_endpoint_args[1].startWith(IMAGE_NAME)
1004+
1005+
sagemaker_session.create_endpoint.assert_not_called()
1006+
1007+
9931008
@patch('sagemaker.estimator.LocalSession')
9941009
@patch('sagemaker.estimator.Session')
9951010
def test_local_mode(session_class, local_session_class):

tests/unit/test_local_session.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,16 @@ def test_create_endpoint(describe_model, describe_endpoint_config, request, *arg
310310
assert 'my-endpoint' in sagemaker.local.local_session.LocalSagemakerClient._endpoints
311311

312312

313+
@patch('sagemaker.local.local_session.LocalSession')
314+
def test_update_endpoint(LocalSession):
315+
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
316+
endpoint_name = 'my-endpoint'
317+
endpoint_config = 'my-endpoint-config'
318+
expected_error_message = 'Update endpoint name is not supported in local session.'
319+
with pytest.raises(NotImplementedError, message=expected_error_message):
320+
local_sagemaker_client.update_endpoint(endpoint_name, endpoint_config)
321+
322+
313323
@patch('sagemaker.local.image._SageMakerContainer.serve')
314324
@patch('urllib3.PoolManager.request')
315325
def test_serve_endpoint_with_correct_accelerator(request, *args):

tests/unit/test_model.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,33 @@ def test_deploy_creates_correct_session(local_session, session, tmpdir):
224224
assert model.sagemaker_session == session.return_value
225225

226226

227+
@patch('sagemaker.fw_utils.tar_and_upload_dir', MagicMock())
228+
def test_deploy_update_endpoint(sagemaker_session, tmpdir):
229+
model = DummyFrameworkModel(sagemaker_session, source_dir=tmpdir)
230+
endpoint_name = 'endpoint-name'
231+
model.deploy(instance_type=INSTANCE_TYPE,
232+
initial_instance_count=1,
233+
endpoint_name=endpoint_name,
234+
update_endpoint=True,
235+
accelerator_type=ACCELERATOR_TYPE)
236+
sagemaker_session.create_endpoint_config.assert_called_with(
237+
name=model.name,
238+
model_name=model.name,
239+
initial_instance_count=INSTANCE_COUNT,
240+
instance_type=INSTANCE_TYPE,
241+
accelerator_type=ACCELERATOR_TYPE
242+
)
243+
config_name = sagemaker_session.create_endpoint_config(
244+
name=model.name,
245+
model_name=model.name,
246+
initial_instance_count=INSTANCE_COUNT,
247+
instance_type=INSTANCE_TYPE,
248+
accelerator_type=ACCELERATOR_TYPE
249+
)
250+
sagemaker_session.update_endpoint.assert_called_with(endpoint_name, config_name)
251+
sagemaker_session.create_endpoint.assert_not_called()
252+
253+
227254
def test_model_enable_network_isolation(sagemaker_session):
228255
model = DummyFrameworkModel(sagemaker_session=sagemaker_session)
229256
assert model.enable_network_isolation() is False

tests/unit/test_session.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -920,6 +920,23 @@ def test_endpoint_from_production_variants_with_accelerator_type(sagemaker_sessi
920920
Tags=tags)
921921

922922

923+
def test_update_endpoint_succeed(sagemaker_session):
924+
sagemaker_session.sagemaker_client.describe_endpoint = Mock(return_value={'EndpointStatus': 'InService'})
925+
endpoint_name = "some-endpoint"
926+
endpoint_config = "some-endpoint-config"
927+
returned_endpoint_name = sagemaker_session.update_endpoint(endpoint_name, endpoint_config)
928+
assert returned_endpoint_name == endpoint_name
929+
930+
931+
def test_update_endpoint_non_existing_endpoint(sagemaker_session):
932+
error = ClientError({'Error': {'Code': 'ValidationException', 'Message': 'Could not find entity'}}, 'foo')
933+
expected_error_message = 'Endpoint with name "non-existing-endpoint" does not exist; ' \
934+
'please use an existing endpoint name'
935+
sagemaker_session.sagemaker_client.describe_endpoint = Mock(side_effect=error)
936+
with pytest.raises(ValueError, message=expected_error_message):
937+
sagemaker_session.update_endpoint("non-existing-endpoint", "non-existing-config")
938+
939+
923940
@patch('time.sleep')
924941
def test_wait_for_tuning_job(sleep, sagemaker_session):
925942
hyperparameter_tuning_job_desc = {'HyperParameterTuningJobStatus': 'Completed'}

0 commit comments

Comments
 (0)