Skip to content

Commit 1d3713e

Browse files
feature: support heterogeneous cluster for training
Co-authored-by: Navin Soni <[email protected]>
1 parent e21b457 commit 1d3713e

File tree

8 files changed

+220
-26
lines changed

8 files changed

+220
-26
lines changed

src/sagemaker/algorithm.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -147,19 +147,19 @@ def __init__(
147147
self.algorithm_arn = algorithm_arn
148148
super(AlgorithmEstimator, self).__init__(
149149
role,
150-
instance_count,
151-
instance_type,
152-
volume_size,
153-
volume_kms_key,
154-
max_run,
155-
input_mode,
156-
output_path,
157-
output_kms_key,
158-
base_job_name,
159-
sagemaker_session,
160-
tags,
161-
subnets,
162-
security_group_ids,
150+
instance_count=instance_count,
151+
instance_type=instance_type,
152+
volume_size=volume_size,
153+
volume_kms_key=volume_kms_key,
154+
max_run=max_run,
155+
input_mode=input_mode,
156+
output_path=output_path,
157+
output_kms_key=output_kms_key,
158+
base_job_name=base_job_name,
159+
sagemaker_session=sagemaker_session,
160+
tags=tags,
161+
subnets=subnets,
162+
security_group_ids=security_group_ids,
163163
model_uri=model_uri,
164164
model_channel_name=model_channel_name,
165165
metric_definitions=metric_definitions,

src/sagemaker/estimator.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def __init__(
145145
code_location: Optional[str] = None,
146146
entry_point: Optional[str] = None,
147147
dependencies: Optional[List[Union[str]]] = None,
148+
instance_groups=None,
148149
**kwargs,
149150
):
150151
"""Initialize an ``EstimatorBase`` instance.
@@ -156,9 +157,10 @@ def __init__(
156157
artifacts. After the endpoint is created, the inference code
157158
might use the IAM role, if it needs to access an AWS resource.
158159
instance_count (int): Number of Amazon EC2 instances to use
159-
for training.
160+
for training. Required if instance_groups is not set.
160161
instance_type (str): Type of EC2 instance to use for training,
161-
for example, 'ml.c4.xlarge'.
162+
for example, 'ml.c4.xlarge'. Required if instance_groups is
163+
not set.
162164
volume_size (int): Size in GB of the EBS volume to use for
163165
storing input data during training (default: 30). Must be large
164166
enough to store training data if File Mode is used (which is the
@@ -424,7 +426,10 @@ def __init__(
424426
>>> |------ virtual-env
425427
426428
This is not supported with "local code" in Local Mode.
427-
429+
instance_groups (list[InstanceGroup]): Optional. List of InstanceGroup
430+
for specifying different instance groups for heterogeneous cluster.
431+
For example: [sagemaker.InstanceGroup('worker','ml.p3dn.24xlarge',64),
432+
sagemaker.InstanceGroup('server','ml.c5n.18xlarge',64)]
428433
"""
429434
instance_count = renamed_kwargs(
430435
"train_instance_count", "instance_count", instance_count, kwargs
@@ -442,12 +447,10 @@ def __init__(
442447
"train_volume_kms_key", "volume_kms_key", volume_kms_key, kwargs
443448
)
444449

445-
if instance_count is None or instance_type is None:
446-
raise ValueError("Both instance_count and instance_type are required.")
447-
448450
self.role = role
449451
self.instance_count = instance_count
450452
self.instance_type = instance_type
453+
self.instance_groups = instance_groups
451454
self.volume_size = volume_size
452455
self.volume_kms_key = volume_kms_key
453456
self.max_run = max_run
@@ -2103,6 +2106,7 @@ def __init__(
21032106
code_location: Optional[str] = None,
21042107
entry_point: Optional[str] = None,
21052108
dependencies: Optional[List[str]] = None,
2109+
instance_groups=None,
21062110
**kwargs,
21072111
):
21082112
"""Initialize an ``Estimator`` instance.
@@ -2115,9 +2119,10 @@ def __init__(
21152119
artifacts. After the endpoint is created, the inference code
21162120
might use the IAM role, if it needs to access an AWS resource.
21172121
instance_count (int): Number of Amazon EC2 instances to use
2118-
for training.
2122+
for training. Required if instance_groups is not set.
21192123
instance_type (str): Type of EC2 instance to use for training,
2120-
for example, 'ml.c4.xlarge'.
2124+
for example, 'ml.c4.xlarge'. Required if instance_groups is
2125+
not set.
21212126
volume_size (int): Size in GB of the EBS volume to use for
21222127
storing input data during training (default: 30). Must be large
21232128
enough to store training data if File Mode is used (which is the
@@ -2379,13 +2384,18 @@ def __init__(
23792384
>>> |------ virtual-env
23802385
23812386
This is not supported with "local code" in Local Mode.
2387+
instance_groups (list[InstanceGroup]): Optional. List of InstanceGroup
2388+
for specifying different instance groups for heterogeneous cluster.
2389+
For example: [sagemaker.InstanceGroup('worker','ml.p3dn.24xlarge',64),
2390+
sagemaker.InstanceGroup('server','ml.c5n.18xlarge',64)]
23822391
"""
23832392
self.image_uri = image_uri
23842393
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
23852394
super(Estimator, self).__init__(
23862395
role,
23872396
instance_count,
23882397
instance_type,
2398+
instance_groups,
23892399
volume_size,
23902400
volume_kms_key,
23912401
max_run,

src/sagemaker/inputs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(
3535
content_type=None,
3636
record_wrapping=None,
3737
s3_data_type="S3Prefix",
38+
instance_groups=None,
3839
input_mode=None,
3940
attribute_names=None,
4041
target_attribute_name=None,
@@ -60,6 +61,8 @@ def __init__(
6061
listing the S3 data to train on. Both the ManifestFile and
6162
AugmentedManifestFile formats are described in the SageMaker API documentation:
6263
https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html
64+
instance_groups (list[str]): Optional. List of InstanceGroupNames to send data to
65+
(default: None). By default, data will be sent to all groups.
6366
input_mode (str): Optional override for this channel's input mode (default: None).
6467
By default, channels will use the input mode defined on
6568
``sagemaker.estimator.EstimatorBase.input_mode``, but they will ignore
@@ -97,6 +100,8 @@ def __init__(
97100
self.config["ContentType"] = content_type
98101
if record_wrapping is not None:
99102
self.config["RecordWrapperType"] = record_wrapping
103+
if instance_groups is not None:
104+
self.config["DataSource"]["S3DataSource"]["InstanceGroupNames"] = instance_groups
100105
if input_mode is not None:
101106
self.config["InputMode"] = input_mode
102107
if attribute_names is not None:

src/sagemaker/instance_group.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This file defines instance group for heterogeneous cluster."""
14+
from __future__ import absolute_import
15+
16+
17+
class InstanceGroup(object):
    """Hold the parameters of one instance group of a heterogeneous cluster.

    The ``_to_request_dict`` method turns the stored parameters into a dict.
    """

    def __init__(
        self,
        instance_group_name=None,
        instance_type=None,
        instance_count=None,
    ):
        """Initialize an ``InstanceGroup`` instance.

        The arguments are stored as given; no validation is performed here.

        Args:
            instance_group_name (str): Name of the instance group.
            instance_type (str): Type of EC2 instance to use in the instance group,
                for example, 'ml.c4.xlarge'.
            instance_count (int): Number of EC2 instances to use in the instance group.
        """
        self.instance_group_name = instance_group_name
        self.instance_type = instance_type
        self.instance_count = instance_count

    def __repr__(self):
        """Return a constructor-style representation, useful in logs and test failures."""
        return "{}(instance_group_name={!r}, instance_type={!r}, instance_count={!r})".format(
            type(self).__name__,
            self.instance_group_name,
            self.instance_type,
            self.instance_count,
        )

    def _to_request_dict(self):
        """Generate a request dictionary from the parameters given to the constructor.

        Returns:
            dict: The keys ``InstanceGroupName``, ``InstanceType`` and
            ``InstanceCount`` mapped to the stored values. Values that were
            never set are emitted as ``None``.
        """
        return {
            "InstanceGroupName": self.instance_group_name,
            "InstanceType": self.instance_type,
            "InstanceCount": self.instance_count,
        }

src/sagemaker/job.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True):
7474
resource_config = _Job._prepare_resource_config(
7575
estimator.instance_count,
7676
estimator.instance_type,
77+
estimator.instance_groups,
7778
estimator.volume_size,
7879
estimator.volume_kms_key,
7980
)
@@ -283,15 +284,31 @@ def _prepare_output_config(s3_path, kms_key_id):
283284
return config
284285

285286
@staticmethod
def _prepare_resource_config(
    instance_count, instance_type, instance_groups, volume_size, volume_kms_key
):
    """Build the resource configuration dict from the instance and volume settings.

    Exactly one way of describing the instances must be used: either
    ``instance_groups``, or ``instance_count`` together with ``instance_type``.

    Args:
        instance_count (int): Number of instances. Must be None when
            ``instance_groups`` is given.
        instance_type (str): Instance type. Must be None when
            ``instance_groups`` is given.
        instance_groups (list): Objects exposing ``_to_request_dict()``,
            or None to use ``instance_count`` and ``instance_type``.
        volume_size (int): Stored under ``VolumeSizeInGB``.
        volume_kms_key (str): Stored under ``VolumeKmsKeyId`` when not None.

    Returns:
        dict: Always contains ``VolumeSizeInGB``; contains ``VolumeKmsKeyId``
        when a key was given; and contains either ``InstanceGroups`` or the
        pair ``InstanceCount`` / ``InstanceType``.

    Raises:
        ValueError: If ``instance_groups`` is combined with ``instance_count``
            or ``instance_type``, or if ``instance_groups`` is None and either
            of the other two is missing.
    """
    resource_config = {"VolumeSizeInGB": volume_size}
    if volume_kms_key is not None:
        resource_config["VolumeKmsKeyId"] = volume_kms_key

    if instance_groups is None:
        # Homogeneous cluster: both the count and the type are mandatory.
        if instance_count is None or instance_type is None:
            raise ValueError(
                "instance_count and instance_type must be set if instance_groups is not set"
            )
        resource_config["InstanceCount"] = instance_count
        resource_config["InstanceType"] = instance_type
        return resource_config

    # Heterogeneous cluster: the per-group settings replace the top-level ones,
    # so mixing both styles is rejected rather than silently ignoring one.
    if instance_count is not None or instance_type is not None:
        raise ValueError(
            "instance_count and instance_type cannot be set when instance_groups is set"
        )
    resource_config["InstanceGroups"] = [group._to_request_dict() for group in instance_groups]
    return resource_config
297314

tests/unit/test_estimator.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from sagemaker.estimator import Estimator, EstimatorBase, Framework, _TrainingJob
4444
from sagemaker.fw_utils import PROFILER_UNSUPPORTED_REGIONS
4545
from sagemaker.inputs import ShuffleConfig
46+
from sagemaker.instance_group import InstanceGroup
4647
from sagemaker.model import FrameworkModel
4748
from sagemaker.mxnet.estimator import MXNet
4849
from sagemaker.predictor import Predictor
@@ -323,6 +324,31 @@ def test_framework_all_init_args(sagemaker_session):
323324
}
324325

325326

327+
def test_framework_with_heterogeneous_cluster(sagemaker_session):
    """Fit an estimator built from instance groups and check the resource config sent to train."""
    f = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_groups=[
            InstanceGroup("group1", "ml.c4.xlarge", 1),
            InstanceGroup("group2", "ml.m4.xlarge", 2),
        ],
    )
    f.fit("s3://mydata")
    sagemaker_session.train.assert_called_once()
    # call_args unpacks to (positional args, keyword args); only the kwargs are inspected.
    _, train_kwargs = sagemaker_session.train.call_args
    # Compare the whole list so that a missing, extra, or reordered group fails the test.
    assert train_kwargs["resource_config"]["InstanceGroups"] == [
        {
            "InstanceGroupName": "group1",
            "InstanceCount": 1,
            "InstanceType": "ml.c4.xlarge",
        },
        {
            "InstanceGroupName": "group2",
            "InstanceCount": 2,
            "InstanceType": "ml.m4.xlarge",
        },
    ]
350+
351+
326352
def test_framework_with_debugger_and_built_in_rule(sagemaker_session):
327353
debugger_built_in_rule_with_custom_args = Rule.sagemaker(
328354
base_config=rule_configs.stalled_training_rule(),

tests/unit/test_inputs.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,44 @@ def test_training_input_all_arguments():
6767
assert result.config == expected
6868

6969

70+
def test_training_input_all_arguments_heterogeneous_cluster():
    """Check that ``instance_groups`` lands in the S3 data source config as InstanceGroupNames."""
    prefix = "pre"
    distribution = "FullyReplicated"
    compression = "Gzip"
    content_type = "text/csv"
    record_wrapping = "RecordIO"
    # Spelled as the S3DataSource API documents it ("ManifestFile").
    s3_data_type = "ManifestFile"
    instance_groups = ["data-server"]
    input_mode = "Pipe"
    result = TrainingInput(
        s3_data=prefix,
        distribution=distribution,
        compression=compression,
        input_mode=input_mode,
        content_type=content_type,
        record_wrapping=record_wrapping,
        s3_data_type=s3_data_type,
        instance_groups=instance_groups,
    )

    expected = {
        "DataSource": {
            "S3DataSource": {
                "S3DataDistributionType": distribution,
                "S3DataType": s3_data_type,
                "S3Uri": prefix,
                "InstanceGroupNames": instance_groups,
            }
        },
        "CompressionType": compression,
        "ContentType": content_type,
        "RecordWrapperType": record_wrapping,
        "InputMode": input_mode,
    }

    assert result.config == expected
106+
107+
70108
def test_file_system_input_default_access_mode():
71109
file_system_id = "fs-0a48d2a1"
72110
file_system_type = "EFS"

0 commit comments

Comments
 (0)