Skip to content

Commit b160988

Browse files
authored
Update all MXNet integ tests to use the latest version (#589)
1 parent f822054 commit b160988

File tree

3 files changed

+7
-79
lines changed

3 files changed

+7
-79
lines changed

tests/data/mxnet_mnist/mnist_framework_mode.py

Lines changed: 0 additions & 73 deletions
This file was deleted.

tests/integ/test_local_mode.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,12 +286,13 @@ def test_mxnet_local_mode(sagemaker_local_session, mxnet_full_version):
286286
mx.delete_endpoint()
287287

288288

289-
def test_mxnet_local_data_local_script():
289+
def test_mxnet_local_data_local_script(mxnet_full_version):
290290
data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
291-
script_path = os.path.join(data_path, 'mnist_framework_mode.py')
291+
script_path = os.path.join(data_path, 'mnist.py')
292292

293293
mx = MXNet(entry_point=script_path, role='SageMakerRole',
294294
train_instance_count=1, train_instance_type='local',
295+
framework_version=mxnet_full_version,
295296
sagemaker_session=LocalNoS3Session())
296297

297298
train_input = 'file://' + os.path.join(data_path, 'train')

tests/integ/test_tuner.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,20 +404,20 @@ def test_stop_tuning_job(sagemaker_session):
404404

405405

406406
@pytest.mark.continuous_testing
407-
def test_tuning_mxnet(sagemaker_session):
407+
def test_tuning_mxnet(sagemaker_session, mxnet_full_version):
408408
with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
409-
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist_framework_mode.py')
409+
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
410410
data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
411411

412412
estimator = MXNet(entry_point=script_path,
413413
role='SageMakerRole',
414414
py_version=PYTHON_VERSION,
415415
train_instance_count=1,
416416
train_instance_type='ml.m4.xlarge',
417-
framework_version='1.2.1',
417+
framework_version=mxnet_full_version,
418418
sagemaker_session=sagemaker_session)
419419

420-
hyperparameter_ranges = {'learning_rate': ContinuousParameter(0.01, 0.2)}
420+
hyperparameter_ranges = {'learning-rate': ContinuousParameter(0.01, 0.2)}
421421
objective_metric_name = 'Validation-accuracy'
422422
metric_definitions = [
423423
{'Name': 'Validation-accuracy', 'Regex': 'Validation-accuracy=([0-9\\.]+)'}]

0 commit comments

Comments
 (0)