
Commit 1d4b916

Lokiiiiiinish21 and Nishanth Hegde authored
Feature: Cluster setup for MultiWorkerMirroredStrategy (#415)
* Feature: Cluster setup for MultiWorkerMirroredStrategy
* Configuring tests to use the new hyperparameter for MWMS
* Black formatted files
* fixing failing tests
* Removing references to py versions older than py37
* Converting py36 tests to py37
* fix: linting and changed variable name to sagemaker_multi_worker_mirrored_strategy_enabled
* fix: freezing protobuf version
* fix: renaming MWMS variable name
* fix: rename functions for _mwm to _mwms
* Revert "fix: freezing protobuf version" (reverts commit c3e6819)
* Revert "Converting py36 tests to py37" (reverts commit 86701b4)
* Revert "Removing references to py versions older than py37" (reverts commit 718e5c7)
* fix: variable name changes for MWMS
* fix: renaming training script to train_dummy.py
* fix: freezing latest sagemaker toolkit version
* trigger ci
* fix: adding epochs and steps to failing MWMS test
* fix: changing MWMS testcase
* fix: logic error in MWMS
* fix: logic error in MWMS
* fix: Updating MWMS tests to check for log lines
* fix: linting
* trigger ci
* trigger ci

Co-authored-by: Nishanth Hegde <[email protected]>
1 parent a58d124 commit 1d4b916
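In short: when the new sagemaker_multi_worker_mirrored_strategy_enabled hyperparameter is set, the container builds a TF_CONFIG cluster spec that lists every training host as a worker and exports it to the user script before launch. A minimal sketch of the value produced for a hypothetical two-host cluster (the host names algo-1/algo-2 are illustrative; port 8890 matches the default in _build_tf_config_for_mwms below):

import json

# Hypothetical two-host SageMaker cluster; host names are illustrative.
hosts = ["algo-1", "algo-2"]
current_host = "algo-1"

tf_config = {
    "cluster": {"worker": ["{}:8890".format(h) for h in hosts]},
    "environment": "cloud",
    "task": {"index": hosts.index(current_host), "type": "worker"},
}

# The training toolkit serializes this dict into the TF_CONFIG environment variable.
print(json.dumps(tf_config))
# {"cluster": {"worker": ["algo-1:8890", "algo-2:8890"]}, "environment": "cloud", "task": {"index": 0, "type": "worker"}}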

File tree

6 files changed: +200 -22 lines changed


setup.py

+1-1
@@ -74,7 +74,7 @@ def read_version():
         "Programming Language :: Python :: 3.9",
     ],
     install_requires=[
-        "sagemaker-training>=4.1.0",
+        "sagemaker-training>=4.1.3",
         "numpy",
         "scipy",
         "sklearn",

src/sagemaker_tensorflow_container/training.py

+53-8
@@ -28,14 +28,17 @@

 SAGEMAKER_PARAMETER_SERVER_ENABLED = "sagemaker_parameter_server_enabled"
 SAGEMAKER_DISTRIBUTED_DATAPARALLEL_ENABLED = "sagemaker_distributed_dataparallel_enabled"
+SAGEMAKER_MULTI_WORKER_MIRRORED_STRATEGY_ENABLED = (
+    "sagemaker_multi_worker_mirrored_strategy_enabled"
+)
 MODEL_DIR = "/opt/ml/model"


 def _is_host_master(hosts, current_host):
     return current_host == hosts[0]


-def _build_tf_config(hosts, current_host, ps_task=False):
+def _build_tf_config_for_ps(hosts, current_host, ps_task=False):
     """Builds a dictionary containing cluster information based on number of hosts and number of
     parameter servers.

@@ -85,6 +88,31 @@ def host_addresses(hosts, port=2222):
     return tf_config


+def _build_tf_config_for_mwms(hosts, current_host):
+    """Builds a dictionary containing cluster information based on number of workers
+    for Multi Worker Mirrored distribution strategy.
+
+    Args:
+        hosts (list[str]): List of host names in the cluster
+        current_host (str): Current host name
+
+    Returns:
+        dict[str: dict]: A dictionary describing the cluster setup for distributed training.
+        For more information regarding TF_CONFIG:
+        https://cloud.google.com/ml-engine/docs/tensorflow/distributed-training-details
+    """
+    workers = hosts
+
+    def host_addresses(hosts, port=8890):
+        return ["{}:{}".format(host, port) for host in hosts]
+
+    tf_config = {"cluster": {}, "environment": "cloud"}
+    tf_config["cluster"]["worker"] = host_addresses(workers)
+    tf_config["task"] = {"index": workers.index(current_host), "type": "worker"}
+
+    return tf_config
+
+
 def _run_ps(env, cluster):
     logger.info("Running distributed training job with parameter servers")

@@ -134,17 +162,35 @@ def train(env, cmd_args):
     Args:
         env (sagemaker_training.environment.Environment): Instance of Environment class
     """
-    parameter_server_enabled = env.additional_framework_parameters.get(
-        SAGEMAKER_PARAMETER_SERVER_ENABLED, False
+    parameter_server_enabled = (
+        env.additional_framework_parameters.get(SAGEMAKER_PARAMETER_SERVER_ENABLED, False)
+        and len(env.hosts) > 1
+    )
+    multi_worker_mirrored_strategy_enabled = env.additional_framework_parameters.get(
+        SAGEMAKER_MULTI_WORKER_MIRRORED_STRATEGY_ENABLED, False
     )
     sagemaker_distributed_dataparallel_enabled = env.additional_framework_parameters.get(
         SAGEMAKER_DISTRIBUTED_DATAPARALLEL_ENABLED, False
     )
-    if len(env.hosts) > 1 and parameter_server_enabled:

-        tf_config = _build_tf_config(hosts=env.hosts, current_host=env.current_host)
+    env_vars = env.to_env_vars()
+
+    # Setup
+    if parameter_server_enabled:

+        tf_config = _build_tf_config_for_ps(hosts=env.hosts, current_host=env.current_host)
         logger.info("Running distributed training job with parameter servers")
+
+    elif multi_worker_mirrored_strategy_enabled:
+
+        env_vars["TF_CONFIG"] = json.dumps(
+            _build_tf_config_for_mwms(hosts=env.hosts, current_host=env.current_host)
+        )
+        logger.info("Running distributed training job with multi_worker_mirrored_strategy setup")
+
+    # Run
+    if parameter_server_enabled:
+
         logger.info("Launching parameter server process")
         _run_ps(env, tf_config["cluster"])
         logger.info("Launching worker process")
@@ -168,7 +214,7 @@ def train(env, cmd_args):
         uri=env.module_dir,
         user_entry_point=env.user_entry_point,
         args=cmd_args,
-        env_vars=env.to_env_vars(),
+        env_vars=env_vars,
         capture_error=True,
         runner_type=runner_type,
     )
@@ -217,8 +263,7 @@ def _model_dir_with_training_job(model_dir, job_name):


 def main():
-    """Training entry point
-    """
+    """Training entry point"""
     hyperparameters = environment.read_hyperparameters()
     env = environment.Environment(hyperparameters=hyperparameters)

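Note that the TF_CONFIG dictionary is placed into env_vars and handed to sagemaker_training.entry_point.run, so it reaches the entry point's process environment before the user code constructs its strategy. A rough sketch of what a training script would then see, using an illustrative two-worker value (TFConfigClusterResolver is what MultiWorkerMirroredStrategy falls back to when TF_CONFIG is set and no resolver is passed):

import json
import os

import tensorflow as tf

# Illustrative TF_CONFIG; in a real job the toolkit exports this before launching the script.
os.environ["TF_CONFIG"] = json.dumps(
    {
        "cluster": {"worker": ["algo-1:8890", "algo-2:8890"]},
        "environment": "cloud",
        "task": {"index": 0, "type": "worker"},
    }
)

# The resolver parses TF_CONFIG from the environment, so the variable must be set
# before tf.distribute.MultiWorkerMirroredStrategy() is constructed.
resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
print(resolver.task_type, resolver.task_id)  # worker 0
print(resolver.cluster_spec().as_dict())     # {'worker': ['algo-1:8890', 'algo-2:8890']}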
New integration test for MultiWorkerMirroredStrategy

+42-0

@@ -0,0 +1,42 @@
+# Copyright 2017-2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import os
+
+from sagemaker.tensorflow import TensorFlow
+from sagemaker.utils import unique_name_from_base
+
+
+RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources")
+
+
+def test_multi_node(sagemaker_session, instance_type, image_uri, tmpdir, framework_version, capsys):
+    estimator = TensorFlow(
+        entry_point=os.path.join(RESOURCE_PATH, "multi_worker_mirrored", "train_dummy.py"),
+        role="SageMakerRole",
+        instance_type=instance_type,
+        instance_count=2,
+        image_name=image_uri,
+        framework_version=framework_version,
+        py_version="py3",
+        hyperparameters={
+            "sagemaker_multi_worker_mirrored_strategy_enabled": True,
+        },
+        sagemaker_session=sagemaker_session,
+    )
+    estimator.fit(job_name=unique_name_from_base("test-tf-mwms"))
+    captured = capsys.readouterr()
+    logs = captured.out + captured.err
+    assert "Running distributed training job with multi_worker_mirrored_strategy setup" in logs
+    assert "TF_CONFIG=" in logs
New test resource (license header only; likely the package __init__.py)

+13-0

@@ -0,0 +1,13 @@
+# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+from __future__ import absolute_import
New test resource: multi_worker_mirrored/train_dummy.py (path taken from the integration test above)

+48-0

@@ -0,0 +1,48 @@
+# Please refer to https://github.com/tensorflow/docs/blob/master/site/en/tutorials/distribute/multi_worker_with_keras.ipynb
+
+import tensorflow as tf
+import numpy as np
+import os
+import json
+
+
+def mnist_dataset(batch_size):
+    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
+    # The `x` arrays are in uint8 and have values in the [0, 255] range.
+    # You need to convert them to float32 with values in the [0, 1] range.
+    x_train = x_train / np.float32(255)
+    y_train = y_train.astype(np.int64)
+    train_dataset = tf.data.Dataset.from_tensor_slices(
+        (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
+    return train_dataset
+
+def build_and_compile_cnn_model():
+    model = tf.keras.Sequential([
+        tf.keras.layers.InputLayer(input_shape=(28, 28)),
+        tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
+        tf.keras.layers.Conv2D(32, 3, activation='relu'),
+        tf.keras.layers.Flatten(),
+        tf.keras.layers.Dense(128, activation='relu'),
+        tf.keras.layers.Dense(10)
+    ])
+    model.compile(
+        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
+        metrics=['accuracy'])
+    return model
+
+
+per_worker_batch_size = 64
+tf_config = json.loads(os.environ['TF_CONFIG'])
+num_workers = len(tf_config['cluster']['worker'])
+
+strategy = tf.distribute.MultiWorkerMirroredStrategy()
+
+global_batch_size = per_worker_batch_size * num_workers
+multi_worker_dataset = mnist_dataset(global_batch_size)
+
+with strategy.scope():
+    # Model building/compiling need to be within `strategy.scope()`.
+    multi_worker_model = build_and_compile_cnn_model()
+
+multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
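A quick check of the batch-size arithmetic in the script above, assuming the two-instance job that the integration test launches (instance_count=2, so two entries in the worker list):

# Values from train_dummy.py above; num_workers assumes the 2-instance test job.
per_worker_batch_size = 64
num_workers = 2
global_batch_size = per_worker_batch_size * num_workers  # 64 * 2 = 128
steps_per_epoch = 70
print(global_batch_size * steps_per_epoch)  # 8960 samples drawn per epoch from the 60000 MNIST images

With Keras model.fit under MultiWorkerMirroredStrategy, each 128-sample batch is split across the two workers, so each worker computes on roughly 64 samples per step.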

test/unit/test_training.py

+43-13
@@ -35,6 +35,8 @@
     "worker": ["{}:2222".format(HOST2)],
     "ps": ["{}:2223".format(HOST1), "{}:2223".format(HOST2)],
 }
+CLUSTER_WITH_MWMS = {"worker": ["{}:8890".format(HOST) for HOST in HOST_LIST]}
+
 MASTER_TASK = {"index": 0, "type": "master"}
 WORKER_TASK = {"index": 0, "type": "worker"}
 PS_TASK_1 = {"index": 0, "type": "ps"}
@@ -109,7 +111,9 @@ def test_train_horovod(run_module, single_machine_training_env):

 @patch("sagemaker_training.entry_point.run")
 def test_train_smdataparallel(run_module, single_machine_training_env):
-    single_machine_training_env.additional_framework_parameters["sagemaker_distributed_dataparallel_enabled"] = True
+    single_machine_training_env.additional_framework_parameters[
+        "sagemaker_distributed_dataparallel_enabled"
+    ] = True

     training.train(single_machine_training_env, MODEL_DIR_CMD_LIST)
     run_module.assert_called_with(
@@ -124,7 +128,8 @@ def test_train_smdataparallel(run_module, single_machine_training_env):

 @pytest.mark.skip_on_pipeline
 @pytest.mark.skipif(
-    sys.version_info.major != 3, reason="Skip this for python 2 because of dict key order mismatch"
+    sys.version_info.major != 3,
+    reason="Skip this for python 2 because of dict key order mismatch",
 )
 @patch("tensorflow.train.ClusterSpec")
 @patch("tensorflow.distribute.Server")
@@ -135,7 +140,11 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai
     training.train(distributed_training_env, MODEL_DIR_CMD_LIST)

     cluster_spec.assert_called_with(
-        {"worker": ["host2:2222"], "master": ["host1:2222"], "ps": ["host1:2223", "host2:2223"]}
+        {
+            "worker": ["host2:2222"],
+            "master": ["host1:2222"],
+            "ps": ["host1:2223", "host2:2223"],
+        }
     )

     tf_server.assert_called_with(
@@ -166,7 +175,8 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai

 @pytest.mark.skip_on_pipeline
 @pytest.mark.skipif(
-    sys.version_info.major != 3, reason="Skip this for python 2 because of dict key order mismatch"
+    sys.version_info.major != 3,
+    reason="Skip this for python 2 because of dict key order mismatch",
 )
 @patch("tensorflow.train.ClusterSpec")
 @patch("tensorflow.distribute.Server")
@@ -179,7 +189,11 @@ def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_trai
     training.train(distributed_training_env, MODEL_DIR_CMD_LIST)

     cluster_spec.assert_called_with(
-        {"worker": ["host2:2222"], "master": ["host1:2222"], "ps": ["host1:2223", "host2:2223"]}
+        {
+            "worker": ["host2:2222"],
+            "master": ["host1:2222"],
+            "ps": ["host1:2223", "host2:2223"],
+        }
     )

     tf_server.assert_called_with(
@@ -226,32 +240,45 @@ def test_train_distributed_no_ps(run, distributed_training_env):
     )


-def test_build_tf_config():
-    assert training._build_tf_config(HOST_LIST, HOST1) == {
+def test_build_tf_config_for_mwms():
+    assert training._build_tf_config_for_mwms(HOST_LIST, HOST1) == {
+        "cluster": CLUSTER_WITH_MWMS,
+        "environment": "cloud",
+        "task": {"index": HOST_LIST.index(HOST1), "type": "worker"},
+    }
+    assert training._build_tf_config_for_mwms(HOST_LIST, HOST2) == {
+        "cluster": CLUSTER_WITH_MWMS,
+        "environment": "cloud",
+        "task": {"index": HOST_LIST.index(HOST2), "type": "worker"},
+    }
+
+
+def test_build_tf_config_for_ps():
+    assert training._build_tf_config_for_ps(HOST_LIST, HOST1) == {
         "cluster": CLUSTER_WITH_PS,
         "environment": "cloud",
         "task": MASTER_TASK,
     }
-    assert training._build_tf_config(HOST_LIST, HOST1, ps_task=True) == {
+    assert training._build_tf_config_for_ps(HOST_LIST, HOST1, ps_task=True) == {
         "cluster": CLUSTER_WITH_PS,
         "environment": "cloud",
         "task": PS_TASK_1,
     }
-    assert training._build_tf_config(HOST_LIST, HOST2) == {
+    assert training._build_tf_config_for_ps(HOST_LIST, HOST2) == {
         "cluster": CLUSTER_WITH_PS,
         "environment": "cloud",
         "task": WORKER_TASK,
     }
-    assert training._build_tf_config(HOST_LIST, HOST2, ps_task=True) == {
+    assert training._build_tf_config_for_ps(HOST_LIST, HOST2, ps_task=True) == {
         "cluster": CLUSTER_WITH_PS,
         "environment": "cloud",
         "task": PS_TASK_2,
     }


-def test_build_tf_config_error():
+def test_build_tf_config_for_ps_error():
     with pytest.raises(ValueError) as error:
-        training._build_tf_config([HOST1], HOST1, ps_task=True)
+        training._build_tf_config_for_ps([HOST1], HOST1, ps_task=True)
     assert "Cannot have a ps task if there are no parameter servers in the cluster" in str(
         error.value
     )
@@ -327,7 +354,10 @@ def test_main(
 @patch("sagemaker_tensorflow_container.training.train")
 @patch("logging.Logger.setLevel")
 @patch("sagemaker_training.environment.Environment")
-@patch("sagemaker_training.environment.read_hyperparameters", return_value={"model_dir": MODEL_DIR})
+@patch(
+    "sagemaker_training.environment.read_hyperparameters",
+    return_value={"model_dir": MODEL_DIR},
+)
 @patch("sagemaker_tensorflow_container.s3_utils.configure")
 def test_main_simple_training_model_dir(
     configure_s3_env,
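For reference, with the host fixtures already used elsewhere in this module (the literal host1/host2 values in the cluster_spec assertions imply HOST1 = "host1", HOST2 = "host2", HOST_LIST = [HOST1, HOST2]), the new constant and task indices evaluate to:

HOST1, HOST2 = "host1", "host2"  # implied by the existing cluster_spec assertions
HOST_LIST = [HOST1, HOST2]

CLUSTER_WITH_MWMS = {"worker": ["{}:8890".format(HOST) for HOST in HOST_LIST]}
print(CLUSTER_WITH_MWMS)                               # {'worker': ['host1:8890', 'host2:8890']}
print(HOST_LIST.index(HOST1), HOST_LIST.index(HOST2))  # 0 1

So test_build_tf_config_for_mwms asserts that every host is listed as a worker on port 8890 and that each host's task index matches its position in the host list.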
