Commit 1dc4cb0

yl-to authored and jessicazhu3 committed

feature: heterogeneous cluster set up in distribution config

Co-authored-by: Jessica Zhu <[email protected]>

1 parent 1d3713e · commit 1dc4cb0

File tree

7 files changed: +298 −36 lines changed

src/sagemaker/estimator.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -3111,6 +3111,13 @@ def _distribution_configuration(self, distribution):
         """
         distribution_config = {}
 
+        mpi_enabled = False
+        smdataparallel_enabled = False
+        if "instance_groups" in distribution:
+            distribution_config["sagemaker_distribution_instance_groups"] = distribution[
+                "instance_groups"
+            ]
+
         if "parameter_server" in distribution:
             ps_enabled = distribution.get("parameter_server").get("enabled", False)
             distribution_config[self.LAUNCH_PS_ENV_NAME] = ps_enabled
@@ -3146,6 +3153,13 @@ def _distribution_configuration(self, distribution):
                 "dataparallel"
             ].get("custom_mpi_options", "")
 
+        if not (mpi_enabled or smdataparallel_enabled) and distribution_config.get(
+            "sagemaker_distribution_instance_groups"
+        ) not in [None, []]:
+            raise ValueError(
+                "Don't set training instance groups when no distribution strategy is enabled!"
+            )
+
         return distribution_config
```
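For reference, the added branch simply copies the user-facing `instance_groups` list into the job's distribution config, while the new guard rejects instance groups when no distribution strategy is enabled. A minimal standalone sketch of that mapping (the surrounding estimator plumbing is elided, and the input value is illustrative — by this point `validate_distribution` has already replaced `InstanceGroup` objects with their names):

```python
# Illustrative distribution dict as it looks after validate_distribution
# has mapped InstanceGroup objects to their group names.
distribution = {"mpi": {"enabled": True}, "instance_groups": ["train_group"]}

distribution_config = {}
if "instance_groups" in distribution:
    distribution_config["sagemaker_distribution_instance_groups"] = distribution[
        "instance_groups"
    ]

print(distribution_config)
# {'sagemaker_distribution_instance_groups': ['train_group']}
```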

src/sagemaker/fw_utils.py

Lines changed: 101 additions & 1 deletion
```diff
@@ -27,7 +27,7 @@
 import sagemaker.utils
 from sagemaker.workflow import is_pipeline_variable
 
-from sagemaker.deprecations import renamed_warning
+from sagemaker.deprecations import renamed_warning, renamed_kwargs
 
 logger = logging.getLogger(__name__)
 
@@ -600,6 +600,106 @@ def _validate_smdataparallel_args(
         raise ValueError(err_msg)
 
 
+def validate_distribution(
+    distribution, instance_groups, framework_name, framework_version, py_version, image_uri, kwargs
+):
+    """Check if the distribution strategy is correctly invoked by the user.
+
+    Currently checks for `dataparallel`, `modelparallel`, and heterogeneous cluster set up,
+    and validates that the user-requested strategy is supported.
+
+    Args:
+        distribution (dict): A dictionary with information to enable distributed training.
+            (Defaults to None if distributed training is not enabled.) For example:
+
+            .. code:: python
+
+                {
+                    "smdistributed": {
+                        "dataparallel": {
+                            "enabled": True
+                        }
+                    }
+                }
+        instance_groups ([InstanceGroup]): A list that contains the instance groups used
+            for training.
+        framework_name (str): The name of the selected framework.
+        framework_version (str): The selected framework version.
+        py_version (str): The selected python version.
+        image_uri (str): A Docker image URI.
+        kwargs (dict): Additional kwargs passed to this function.
+
+    Returns:
+        distribution (dict): An updated dictionary with validated information
+            to enable distributed training.
+
+    Raises:
+        ValueError: if the distribution dictionary isn't correctly formatted, or
+            multiple strategies are requested simultaneously, or
+            an unsupported strategy is requested, or
+            strategy-specific inputs are incorrect/unsupported, or
+            the heterogeneous cluster set up is incorrect
+    """
+    train_instance_groups = distribution.get("instance_groups", [])
+    if instance_groups is None:
+        if len(train_instance_groups) >= 1:
+            # train_instance_groups are specified in distribution,
+            # but the estimator's instance_groups is not defined
+            raise ValueError("Instance groups not specified in the estimator!")
+    else:
+        if len(train_instance_groups) > len(instance_groups):
+            # more train_instance_groups in distribution than the estimator's instance_groups
+            raise ValueError("Train instance groups oversubscribed!")
+        if len(instance_groups) == 1 and len(train_instance_groups) == 0:
+            # if there is just one instance_group and distribution does not name it,
+            # set it for the user
+            train_instance_groups = instance_groups
+        elif len(instance_groups) > 1 and len(train_instance_groups) != 1:
+            # currently we support only one train instance group
+            raise ValueError("Distribution should only contain one instance group name!")
+
+    if len(train_instance_groups) != 0:
+        # in this case, we are handling a heterogeneous cluster training job
+        instance_group_names = []
+        for train_instance_group in train_instance_groups:
+            # a future version will support multiple train_instance_groups, hence the loop
+            if train_instance_group not in instance_groups:
+                # check that the train instance group belongs to the groups
+                # the user defined in the estimator set up
+                raise ValueError(
+                    f"Invalid training instance group {train_instance_group.instance_group_name}!"
+                )
+            instance_type = train_instance_group.instance_type
+            validate_smdistributed(
+                instance_type=instance_type,
+                framework_name=framework_name,
+                framework_version=framework_version,
+                py_version=py_version,
+                distribution=distribution,
+                image_uri=image_uri,
+            )
+            warn_if_parameter_server_with_multi_gpu(
+                training_instance_type=instance_type, distribution=distribution
+            )
+            # collect the instance group names
+            instance_group_names.append(train_instance_group.instance_group_name)
+        distribution["instance_groups"] = instance_group_names
+    else:
+        # in this case, we are handling a normal training job (without a heterogeneous cluster)
+        instance_type = renamed_kwargs(
+            "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
+        )
+        validate_smdistributed(
+            instance_type=instance_type,
+            framework_name=framework_name,
+            framework_version=framework_version,
+            py_version=py_version,
+            distribution=distribution,
+            image_uri=image_uri,
+        )
+        warn_if_parameter_server_with_multi_gpu(
+            training_instance_type=instance_type, distribution=distribution
+        )
+    return distribution
+
+
 def python_deprecation_warning(framework, latest_supported_version):
     """Placeholder docstring"""
     return PYTHON_2_DEPRECATION_WARNING.format(
```
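Taken together, the validation rules are: the estimator must define `instance_groups` whenever the distribution references any; the distribution may name at most one training group for now; a lone estimator group is filled in automatically; and every referenced group must be one the estimator defines. A hedged usage sketch — the instance types, framework, and versions below are illustrative, and it assumes an SDK build that includes this commit:

```python
from sagemaker.fw_utils import validate_distribution
from sagemaker.instance_group import InstanceGroup

train_group = InstanceGroup("train_group", "ml.p3.16xlarge", 2)
data_group = InstanceGroup("data_group", "ml.c5.xlarge", 1)

distribution = validate_distribution(
    {"mpi": {"enabled": True}, "instance_groups": [train_group]},  # distribution
    [train_group, data_group],  # instance_groups configured on the estimator
    "tensorflow",  # framework_name
    "2.9",  # framework_version (illustrative)
    "py39",  # py_version (illustrative)
    None,  # image_uri
    {},  # kwargs
)

# On success, InstanceGroup objects are replaced by their group names:
assert distribution["instance_groups"] == ["train_group"]
```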

src/sagemaker/pytorch/estimator.py

Lines changed: 12 additions & 21 deletions
```diff
@@ -17,15 +17,13 @@
 
 from packaging.version import Version
 
-from sagemaker.deprecations import renamed_kwargs
 from sagemaker.estimator import Framework, EstimatorBase
 from sagemaker.fw_utils import (
     framework_name_from_image,
     framework_version_from_tag,
     python_deprecation_warning,
     validate_version_or_image_args,
-    warn_if_parameter_server_with_multi_gpu,
-    validate_smdistributed,
+    validate_distribution,
 )
 from sagemaker.pytorch import defaults
 from sagemaker.pytorch.model import PyTorchModel
@@ -196,24 +194,6 @@ def __init__(
         self.framework_version = framework_version
         self.py_version = py_version
 
-        if distribution is not None:
-            instance_type = renamed_kwargs(
-                "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
-            )
-
-            validate_smdistributed(
-                instance_type=instance_type,
-                framework_name=self._framework_name,
-                framework_version=framework_version,
-                py_version=py_version,
-                distribution=distribution,
-                image_uri=image_uri,
-            )
-
-            warn_if_parameter_server_with_multi_gpu(
-                training_instance_type=instance_type, distribution=distribution
-            )
-
         if "enable_sagemaker_metrics" not in kwargs:
             # enable sagemaker metrics for PT v1.3 or greater:
             if self.framework_version and Version(self.framework_version) >= Version("1.3"):
@@ -222,6 +202,17 @@ def __init__(
         super(PyTorch, self).__init__(
             entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
         )
+        if distribution is not None:
+            distribution = validate_distribution(
+                distribution,
+                self.instance_groups,
+                self._framework_name,
+                framework_version,
+                py_version,
+                image_uri,
+                kwargs,
+            )
+
         self.distribution = distribution or {}
 
     def hyperparameters(self):
```
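Because the `validate_distribution` call now runs after `super().__init__`, `self.instance_groups` (populated by the base estimator) is available to the validator. A hypothetical end-to-end sketch of configuring a heterogeneous-cluster PyTorch job — the role ARN, versions, and instance type are placeholders, and it assumes AWS credentials and a default region are configured:

```python
from sagemaker.instance_group import InstanceGroup
from sagemaker.pytorch import PyTorch

train_group = InstanceGroup("train_group", "ml.p3.16xlarge", 2)

estimator = PyTorch(
    entry_point="train.py",
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder role
    framework_version="1.12",  # illustrative version
    py_version="py38",
    instance_groups=[train_group],
    distribution={
        "mpi": {"enabled": True},
        "instance_groups": [train_group],
    },
)

# After validation, the stored config references the group by name:
# estimator.distribution == {"mpi": {"enabled": True}, "instance_groups": ["train_group"]}
```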

src/sagemaker/tensorflow/estimator.py

Lines changed: 10 additions & 13 deletions
```diff
@@ -183,25 +183,22 @@ def __init__(
         self.py_version = py_version
         self.instance_type = instance_type
 
-        if distribution is not None:
-            fw.warn_if_parameter_server_with_multi_gpu(
-                training_instance_type=instance_type, distribution=distribution
-            )
-            fw.validate_smdistributed(
-                instance_type=instance_type,
-                framework_name=self._framework_name,
-                framework_version=framework_version,
-                py_version=py_version,
-                distribution=distribution,
-                image_uri=image_uri,
-            )
-
         if "enable_sagemaker_metrics" not in kwargs:
             # enable sagemaker metrics for TF v1.15 or greater:
             if framework_version and version.Version(framework_version) >= version.Version("1.15"):
                 kwargs["enable_sagemaker_metrics"] = True
 
         super(TensorFlow, self).__init__(image_uri=image_uri, **kwargs)
+        if distribution is not None:
+            distribution = fw.validate_distribution(
+                distribution,
+                self.instance_groups,
+                self._framework_name,
+                framework_version,
+                py_version,
+                image_uri,
+                kwargs,
+            )
         self.model_dir = model_dir
         self.distribution = distribution or {}
```
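As in the PyTorch change above, the validation is deliberately moved to after `super(TensorFlow, self).__init__`, since `validate_distribution` reads `self.instance_groups`, which is only set once the base estimator initializer has run.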

tests/unit/sagemaker/tensorflow/test_estimator.py

Lines changed: 22 additions & 0 deletions
```diff
@@ -22,6 +22,7 @@
 
 from sagemaker.estimator import _TrainingJob
 from sagemaker.tensorflow import TensorFlow
+from sagemaker.instance_group import InstanceGroup
 from tests.unit import DATA_DIR
 
 SCRIPT_FILE = "dummy_script.py"
@@ -538,3 +539,24 @@ def test_custom_image(sagemaker_session):
     custom_image = "tensorflow:latest"
     tf = _build_tf(sagemaker_session, image_uri=custom_image)
     assert custom_image == tf.training_image_uri()
+
+
+def test_tf_heterogeneous_cluster_distribution_config(
+    sagemaker_session, tensorflow_training_version, tensorflow_training_py_version
+):
+    if version.Version(tensorflow_training_version) < version.Version("2.0"):
+        pytest.skip("This test is for TF 2.0 and higher.")
+
+    training_group = InstanceGroup("train_group", "ml.c4.xlarge", 1)
+    expected_return = {"mpi": {"enabled": True}, "instance_groups": ["train_group"]}
+    tf = _build_tf(
+        sagemaker_session,
+        framework_version=tensorflow_training_version,
+        py_version=tensorflow_training_py_version,
+        instance_groups=[training_group],
+        distribution={
+            "mpi": {"enabled": True},
+            "instance_groups": [training_group],
+        },
+    )
+    assert tf.distribution == expected_return
```
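The assertion exercises the name-mapping behavior added in `fw_utils.validate_distribution`: the caller passes `InstanceGroup` objects in `distribution["instance_groups"]`, and the stored `tf.distribution` refers to the group by name (`"train_group"`) only.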
