Skip to content

Commit 92e232d

Browse files
authored
infra: add single-instance, multi-process Horovod test for local GPU (#390)
1 parent 6a2903d commit 92e232d

File tree

1 file changed

+26
-13
lines changed

1 file changed

+26
-13
lines changed

test/integration/local/test_horovod.py

+26 -13
Original file line number | Diff line number | Diff line change
@@ -23,25 +23,38 @@
2323
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'resources')
2424

2525

26+
@pytest.mark.skip_cpu
27+
@pytest.mark.skip_generic
28+
def test_distributed_training_horovod_gpu(
29+
sagemaker_local_session, image_uri, tmpdir, framework_version
30+
):
31+
_test_distributed_training_horovod(
32+
1, 2, sagemaker_local_session, image_uri, tmpdir, framework_version, 'local_gpu'
33+
)
34+
35+
2636
@pytest.mark.skip_gpu
2737
@pytest.mark.skip_generic
28-
@pytest.mark.parametrize('instances, processes', [
29-
[1, 2],
30-
(2, 1),
31-
(2, 2),
32-
(5, 2)])
33-
def test_distributed_training_horovod_basic(instances,
34-
processes,
35-
sagemaker_local_session,
36-
image_uri,
37-
tmpdir,
38-
framework_version):
38+
@pytest.mark.parametrize(
39+
'instances, processes', [(1, 2), (2, 1), (2, 2), (5, 2)]
40+
)
41+
def test_distributed_training_horovod_cpu(
42+
instances, processes, sagemaker_local_session, image_uri, tmpdir, framework_version
43+
):
44+
_test_distributed_training_horovod(
45+
instances, processes, sagemaker_local_session, image_uri, tmpdir, framework_version, 'local'
46+
)
47+
48+
49+
def _test_distributed_training_horovod(
50+
instances, processes, session, image_uri, tmpdir, framework_version, instance_type
51+
):
3952
output_path = 'file://%s' % tmpdir
4053
estimator = TensorFlow(
4154
entry_point=os.path.join(RESOURCE_PATH, 'hvdbasic', 'train_hvd_basic.py'),
4255
role='SageMakerRole',
43-
train_instance_type='local',
44-
sagemaker_session=sagemaker_local_session,
56+
train_instance_type=instance_type,
57+
sagemaker_session=session,
4558
train_instance_count=instances,
4659
image_name=image_uri,
4760
output_path=output_path,

0 commit comments

Comments (0)