|
23 | 23 | RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'resources')
|
24 | 24 |
|
25 | 25 |
|
| 26 | +@pytest.mark.skip_cpu |
| 27 | +@pytest.mark.skip_generic |
| 28 | +def test_distributed_training_horovod_gpu( |
| 29 | + sagemaker_local_session, image_uri, tmpdir, framework_version |
| 30 | +): |
| 31 | + _test_distributed_training_horovod( |
| 32 | + 1, 2, sagemaker_local_session, image_uri, tmpdir, framework_version, 'local_gpu' |
| 33 | + ) |
| 34 | + |
| 35 | + |
26 | 36 | @pytest.mark.skip_gpu
|
27 | 37 | @pytest.mark.skip_generic
|
28 |
| -@pytest.mark.parametrize('instances, processes', [ |
29 |
| - [1, 2], |
30 |
| - (2, 1), |
31 |
| - (2, 2), |
32 |
| - (5, 2)]) |
33 |
| -def test_distributed_training_horovod_basic(instances, |
34 |
| - processes, |
35 |
| - sagemaker_local_session, |
36 |
| - image_uri, |
37 |
| - tmpdir, |
38 |
| - framework_version): |
| 38 | +@pytest.mark.parametrize( |
| 39 | + 'instances, processes', [(1, 2), (2, 1), (2, 2), (5, 2)] |
| 40 | +) |
| 41 | +def test_distributed_training_horovod_cpu( |
| 42 | + instances, processes, sagemaker_local_session, image_uri, tmpdir, framework_version |
| 43 | +): |
| 44 | + _test_distributed_training_horovod( |
| 45 | + instances, processes, sagemaker_local_session, image_uri, tmpdir, framework_version, 'local' |
| 46 | + ) |
| 47 | + |
| 48 | + |
| 49 | +def _test_distributed_training_horovod( |
| 50 | + instances, processes, session, image_uri, tmpdir, framework_version, instance_type |
| 51 | +): |
39 | 52 | output_path = 'file://%s' % tmpdir
|
40 | 53 | estimator = TensorFlow(
|
41 | 54 | entry_point=os.path.join(RESOURCE_PATH, 'hvdbasic', 'train_hvd_basic.py'),
|
42 | 55 | role='SageMakerRole',
|
43 |
| - train_instance_type='local', |
44 |
| - sagemaker_session=sagemaker_local_session, |
| 56 | + train_instance_type=instance_type, |
| 57 | + sagemaker_session=session, |
45 | 58 | train_instance_count=instances,
|
46 | 59 | image_name=image_uri,
|
47 | 60 | output_path=output_path,
|
|
0 commit comments