Skip to content

Commit 5cfa6b3

Browse files
authored
change: handle image_uri rename for estimators and models in v2 migration tool (aws#1675)
1 parent 9a0f8ac commit 5cfa6b3

File tree

7 files changed

+193
-13
lines changed

7 files changed

+193
-13
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from sagemaker.cli.compatibility.v2 import modifiers
1919

2020
FUNCTION_CALL_MODIFIERS = [
21+
modifiers.renamed_params.EstimatorImageURIRenamer(),
22+
modifiers.renamed_params.ModelImageURIRenamer(),
2123
modifiers.framework_version.FrameworkVersionEnforcer(),
2224
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
2325
modifiers.tf_legacy_mode.TensorBoardParameterRemover(),

src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
2020

2121
FRAMEWORK_ARG = "framework_version"
22+
IMAGE_ARG = "image_uri"
2223
PY_ARG = "py_version"
2324

2425
FRAMEWORK_DEFAULTS = {
@@ -70,11 +71,8 @@ def node_should_be_modified(self, node):
7071
bool: If the ``ast.Call`` is instantiating a framework class that
7172
should specify ``framework_version``, but doesn't.
7273
"""
73-
if matching.matches_any(node, ESTIMATORS):
74-
return _version_args_needed(node, "image_name")
75-
76-
if matching.matches_any(node, MODELS):
77-
return _version_args_needed(node, "image")
74+
if matching.matches_any(node, ESTIMATORS) or matching.matches_any(node, MODELS):
75+
return _version_args_needed(node)
7876

7977
return False
8078

@@ -169,13 +167,13 @@ def _framework_from_node(node):
169167
return framework, is_model
170168

171169

172-
def _version_args_needed(node, image_arg):
170+
def _version_args_needed(node):
173171
"""Determines if image_arg or version_arg was supplied
174172
175173
Applies similar logic as ``validate_version_or_image_args``
176174
"""
177175
# if image_arg is present, no need to supply version arguments
178-
if matching.has_arg(node, image_arg):
176+
if matching.has_arg(node, IMAGE_ARG):
179177
return False
180178

181179
# if framework_version is None, need args

src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py

+62
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,65 @@ def node_should_be_modified(self, node):
147147
return False
148148

149149
return super(S3SessionRenamer, self).node_should_be_modified(node)
150+
151+
152+
class EstimatorImageURIRenamer(ParamRenamer):
153+
"""A class to rename the ``image_name`` attribute to ``image_uri`` in estimators."""
154+
155+
@property
156+
def calls_to_modify(self):
157+
"""A dictionary mapping estimators with the ``image_name`` attribute to their
158+
respective namespaces.
159+
"""
160+
return {
161+
"Chainer": ("sagemaker.chainer", "sagemaker.chainer.estimator"),
162+
"Estimator": ("sagemaker.estimator",),
163+
"Framework": ("sagemaker.estimator",),
164+
"MXNet": ("sagemaker.mxnet", "sagemaker.mxnet.estimator"),
165+
"PyTorch": ("sagemaker.pytorch", "sagemaker.pytorch.estimator"),
166+
"RLEstimator": ("sagemaker.rl", "sagemaker.rl.estimator"),
167+
"SKLearn": ("sagemaker.sklearn", "sagemaker.sklearn.estimator"),
168+
"TensorFlow": ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator"),
169+
"XGBoost": ("sagemaker.xgboost", "sagemaker.xgboost.estimator"),
170+
}
171+
172+
@property
173+
def old_param_name(self):
174+
"""The previous name for the image URI argument."""
175+
return "image_name"
176+
177+
@property
178+
def new_param_name(self):
179+
"""The new name for the image URI argument."""
180+
return "image_uri"
181+
182+
183+
class ModelImageURIRenamer(ParamRenamer):
184+
"""A class to rename the ``image`` attribute to ``image_uri`` in models."""
185+
186+
@property
187+
def calls_to_modify(self):
188+
"""A dictionary mapping models with the ``image`` attribute to their
189+
respective namespaces.
190+
"""
191+
return {
192+
"ChainerModel": ("sagemaker.chainer", "sagemaker.chainer.model"),
193+
"Model": ("sagemaker.model",),
194+
"MultiDataModel": ("sagemaker.multidatamodel",),
195+
"FrameworkModel": ("sagemaker.model",),
196+
"MXNetModel": ("sagemaker.mxnet", "sagemaker.mxnet.model"),
197+
"PyTorchModel": ("sagemaker.pytorch", "sagemaker.pytorch.model"),
198+
"SKLearnModel": ("sagemaker.sklearn", "sagemaker.sklearn.model"),
199+
"TensorFlowModel": ("sagemaker.tensorflow", "sagemaker.tensorflow.model"),
200+
"XGBoostModel": ("sagemaker.xgboost", "sagemaker.xgboost.model"),
201+
}
202+
203+
@property
204+
def old_param_name(self):
205+
"""The previous name for the image URI argument."""
206+
return "image"
207+
208+
@property
209+
def new_param_name(self):
210+
"""The new name for the image URI argument."""
211+
return "image_uri"

src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -116,15 +116,15 @@ def modify_node(self, node):
116116
hp_key = self._hyperparameter_key_for_param(kw.arg)
117117
additional_hps[hp_key] = kw.value
118118
kw_to_remove.append(kw)
119-
if kw.arg == "image_name":
119+
if kw.arg == "image_uri":
120120
add_image_uri = False
121121

122122
self._remove_keywords(node, kw_to_remove)
123123
self._add_updated_hyperparameters(node, base_hps, additional_hps)
124124

125125
if add_image_uri:
126126
image_uri = self._image_uri_from_args(node.keywords)
127-
node.keywords.append(ast.keyword(arg="image_name", value=ast.Str(s=image_uri)))
127+
node.keywords.append(ast.keyword(arg="image_uri", value=ast.Str(s=image_uri)))
128128

129129
node.keywords.append(ast.keyword(arg="model_dir", value=ast.NameConstant(value=False)))
130130

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def _templates(self, model=False):
5757
def _frameworks(self, versions=False, image=False):
5858
keywords = dict()
5959
if image:
60-
keywords["image_name"] = "my:image"
60+
keywords["image_uri"] = "my:image"
6161
if versions:
6262
keywords["framework_version"] = self.framework_version
6363
keywords["py_version"] = self.py_version
@@ -66,7 +66,7 @@ def _frameworks(self, versions=False, image=False):
6666
def _models(self, versions=False, image=False):
6767
keywords = dict()
6868
if image:
69-
keywords["image"] = "my:image"
69+
keywords["image_uri"] = "my:image"
7070
if versions:
7171
keywords["framework_version"] = self.framework_version
7272
if self.py_version_for_model:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pasta
16+
17+
from sagemaker.cli.compatibility.v2.modifiers import renamed_params
18+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
19+
20+
ESTIMATORS = {
21+
"Chainer": ("sagemaker.chainer", "sagemaker.chainer.estimator"),
22+
"Estimator": ("sagemaker.estimator",),
23+
"Framework": ("sagemaker.estimator",),
24+
"MXNet": ("sagemaker.mxnet", "sagemaker.mxnet.estimator"),
25+
"PyTorch": ("sagemaker.pytorch", "sagemaker.pytorch.estimator"),
26+
"RLEstimator": ("sagemaker.rl", "sagemaker.rl.estimator"),
27+
"SKLearn": ("sagemaker.sklearn", "sagemaker.sklearn.estimator"),
28+
"TensorFlow": ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator"),
29+
"XGBoost": ("sagemaker.xgboost", "sagemaker.xgboost.estimator"),
30+
}
31+
32+
MODELS = {
33+
"ChainerModel": ("sagemaker.chainer", "sagemaker.chainer.model"),
34+
"Model": ("sagemaker.model",),
35+
"MultiDataModel": ("sagemaker.multidatamodel",),
36+
"FrameworkModel": ("sagemaker.model",),
37+
"MXNetModel": ("sagemaker.mxnet", "sagemaker.mxnet.model"),
38+
"PyTorchModel": ("sagemaker.pytorch", "sagemaker.pytorch.model"),
39+
"SKLearnModel": ("sagemaker.sklearn", "sagemaker.sklearn.model"),
40+
"TensorFlowModel": ("sagemaker.tensorflow", "sagemaker.tensorflow.model"),
41+
"XGBoostModel": ("sagemaker.xgboost", "sagemaker.xgboost.model"),
42+
}
43+
44+
45+
def test_estimator_node_should_be_modified():
46+
modifier = renamed_params.EstimatorImageURIRenamer()
47+
48+
for estimator, namespaces in ESTIMATORS.items():
49+
call = "{}(image_name='my-image:latest')".format(estimator)
50+
assert modifier.node_should_be_modified(ast_call(call))
51+
52+
for namespace in namespaces:
53+
call = "{}.{}(image_name='my-image:latest')".format(namespace, estimator)
54+
assert modifier.node_should_be_modified(ast_call(call))
55+
56+
57+
def test_estimator_node_should_be_modified_no_distribution():
58+
modifier = renamed_params.EstimatorImageURIRenamer()
59+
60+
for estimator, namespaces in ESTIMATORS.items():
61+
call = "{}()".format(estimator)
62+
assert not modifier.node_should_be_modified(ast_call(call))
63+
64+
for namespace in namespaces:
65+
call = "{}.{}()".format(namespace, estimator)
66+
assert not modifier.node_should_be_modified(ast_call(call))
67+
68+
69+
def test_estimator_node_should_be_modified_random_function_call():
70+
modifier = renamed_params.EstimatorImageURIRenamer()
71+
assert not modifier.node_should_be_modified(ast_call("Session()"))
72+
73+
74+
def test_estimator_modify_node():
75+
node = ast_call("TensorFlow(image_name=my_image)")
76+
modifier = renamed_params.EstimatorImageURIRenamer()
77+
modifier.modify_node(node)
78+
79+
expected = "TensorFlow(image_uri=my_image)"
80+
assert expected == pasta.dump(node)
81+
82+
83+
def test_model_node_should_be_modified():
84+
modifier = renamed_params.ModelImageURIRenamer()
85+
86+
for model, namespaces in MODELS.items():
87+
call = "{}(image='my-image:latest')".format(model)
88+
assert modifier.node_should_be_modified(ast_call(call))
89+
90+
for namespace in namespaces:
91+
call = "{}.{}(image='my-image:latest')".format(namespace, model)
92+
assert modifier.node_should_be_modified(ast_call(call))
93+
94+
95+
def test_model_node_should_be_modified_no_distribution():
96+
modifier = renamed_params.ModelImageURIRenamer()
97+
98+
for model, namespaces in MODELS.items():
99+
call = "{}()".format(model)
100+
assert not modifier.node_should_be_modified(ast_call(call))
101+
102+
for namespace in namespaces:
103+
call = "{}.{}()".format(namespace, model)
104+
assert not modifier.node_should_be_modified(ast_call(call))
105+
106+
107+
def test_model_node_should_be_modified_random_function_call():
108+
modifier = renamed_params.ModelImageURIRenamer()
109+
assert not modifier.node_should_be_modified(ast_call("Session()"))
110+
111+
112+
def test_model_modify_node():
113+
node = ast_call("TensorFlowModel(image=my_image)")
114+
modifier = renamed_params.ModelImageURIRenamer()
115+
modifier.modify_node(node)
116+
117+
expected = "TensorFlowModel(image_uri=my_image)"
118+
assert expected == pasta.dump(node)

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def test_modify_node_set_model_dir_and_image_name(create_image_uri, boto_session
9090
node = ast_call(constructor)
9191
modifier.modify_node(node)
9292

93-
assert "TensorFlow(image_name='{}', model_dir=False)".format(IMAGE_URI) == pasta.dump(node)
93+
assert "TensorFlow(image_uri='{}', model_dir=False)".format(IMAGE_URI) == pasta.dump(node)
9494
create_image_uri.assert_called_with(
9595
REGION_NAME, "tensorflow", "ml.m4.xlarge", "1.11.0", "py2"
9696
)
@@ -111,7 +111,7 @@ def test_modify_node_set_image_name_from_args(create_image_uri, boto_session):
111111

112112
expected_string = (
113113
"TensorFlow(train_instance_type='ml.p2.xlarge', framework_version='1.4.0', "
114-
"image_name='{}', model_dir=False)".format(IMAGE_URI)
114+
"image_uri='{}', model_dir=False)".format(IMAGE_URI)
115115
)
116116
assert expected_string == pasta.dump(node)
117117

0 commit comments

Comments
 (0)