Skip to content

Commit 0f6e75f

Browse files
authored
Add Model.transformer() (#587)
There is no straightforward way to create a batch transform job from a SageMaker Model class. This change supports the use case of bringing an existing model artifact to SageMaker and performing batch inference from within the SDK.
1 parent bc17302 commit 0f6e75f

File tree

3 files changed

+54
-38
lines changed

3 files changed

+54
-38
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ CHANGELOG
1414
* doc-fix: Add missing classes to API docs
1515
* doc-fix: Add information on necessary AWS permissions
1616
* bug-fix: Remove PyYAML to let docker-compose install the right version
17+
* enhancement: Add Model.transformer()
1718

1819
1.16.3
1920
======

src/sagemaker/model.py

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,44 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
249249
if self.predictor_cls:
250250
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)
251251

252+
def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
253+
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
254+
max_payload=None, tags=None, volume_kms_key=None):
255+
"""Return a ``Transformer`` that uses this Model.
256+
257+
Args:
258+
instance_count (int): Number of EC2 instances to use.
259+
instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
260+
strategy (str): The strategy used to decide how to batch records in a single request (default: None).
261+
Valid values: 'MULTI_RECORD' and 'SINGLE_RECORD'.
262+
assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'.
263+
output_path (str): S3 location for saving the transform result. If not specified, results are stored to
264+
a default bucket.
265+
output_kms_key (str): Optional. KMS key ID for encrypting the transform output (default: None).
266+
accept (str): The content type accepted by the endpoint deployed during the transform job.
267+
env (dict): Environment variables to be set for use during the transform job (default: None).
268+
max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
269+
each individual transform container at one time.
270+
max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
271+
tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for
272+
the training job are used for the transform job.
273+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
274+
transform jobs. If not specified, the role from the Model will be used.
275+
model_server_workers (int): Optional. The number of worker processes used by the inference server.
276+
If None, server will use one worker per vCPU.
277+
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
278+
compute instance (default: None).
279+
"""
280+
self._create_sagemaker_model(instance_type)
281+
if self.enable_network_isolation():
282+
env = None
283+
284+
return Transformer(self.name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with,
285+
output_path=output_path, output_kms_key=output_kms_key, accept=accept,
286+
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
287+
env=env, tags=tags, base_transform_job_name=self.name,
288+
volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session)
289+
252290

253291
SCRIPT_PARAM_NAME = 'sagemaker_program'
254292
DIR_PARAM_NAME = 'sagemaker_submit_directory'
@@ -457,44 +495,6 @@ def _is_marketplace(self):
457495
return True
458496
return False
459497

460-
def transformer(self, instance_count, instance_type, strategy=None, assemble_with=None, output_path=None,
461-
output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
462-
max_payload=None, tags=None, volume_kms_key=None):
463-
"""Return a ``Transformer`` that uses this ModelPackage.
464-
465-
Args:
466-
instance_count (int): Number of EC2 instances to use.
467-
instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
468-
strategy (str): The strategy used to decide how to batch records in a single request (default: None).
469-
Valid values: 'MULTI_RECORD' and 'SINGLE_RECORD'.
470-
assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'.
471-
output_path (str): S3 location for saving the transform result. If not specified, results are stored to
472-
a default bucket.
473-
output_kms_key (str): Optional. KMS key ID for encrypting the transform output (default: None).
474-
accept (str): The content type accepted by the endpoint deployed during the transform job.
475-
env (dict): Environment variables to be set for use during the transform job (default: None).
476-
max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
477-
each individual transform container at one time.
478-
max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
479-
tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for
480-
the training job are used for the transform job.
481-
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
482-
transform jobs. If not specified, the role from the Model will be used.
483-
model_server_workers (int): Optional. The number of worker processes used by the inference server.
484-
If None, server will use one worker per vCPU.
485-
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
486-
compute instance (default: None).
487-
"""
488-
self._create_sagemaker_model(instance_type)
489-
if self._is_marketplace():
490-
env = None
491-
492-
return Transformer(self.name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with,
493-
output_path=output_path, output_kms_key=output_kms_key, accept=accept,
494-
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
495-
env=env, tags=tags, base_transform_job_name=self.name,
496-
volume_kms_key=volume_kms_key, sagemaker_session=self.sagemaker_session)
497-
498498
def _create_sagemaker_model(self, *args): # pylint: disable=unused-argument
499499
"""Create a SageMaker Model Entity
500500

tests/unit/test_model.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,21 @@ def test_model_enable_network_isolation(sagemaker_session):
229229
assert model.enable_network_isolation() is False
230230

231231

232+
@patch('sagemaker.model.Model._create_sagemaker_model', Mock())
233+
def test_model_create_transformer(sagemaker_session):
234+
sagemaker_session.sagemaker_client.describe_model_package = Mock(
235+
return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE)
236+
237+
model = DummyFrameworkModel(sagemaker_session=sagemaker_session)
238+
model.name = 'auto-generated-model'
239+
transformer = model.transformer(instance_count=1, instance_type='ml.m4.xlarge',
240+
env={'test': True})
241+
assert isinstance(transformer, sagemaker.transformer.Transformer)
242+
assert transformer.model_name == 'auto-generated-model'
243+
assert transformer.instance_type == 'ml.m4.xlarge'
244+
assert transformer.env == {'test': True}
245+
246+
232247
def test_model_package_enable_network_isolation_with_no_product_id(sagemaker_session):
233248
sagemaker_session.sagemaker_client.describe_model_package = Mock(
234249
return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE)

0 commit comments

Comments
 (0)