Skip to content

Commit 1be4460

Browse files
fix: breaking deviations in _create_sagemaker_model call (#3919)
1 parent a5719f8 commit 1be4460

File tree

2 files changed

+79
-3
lines changed

2 files changed

+79
-3
lines changed

src/sagemaker/model.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import json
1818
import logging
1919
import os
20+
import re
2021
import copy
2122
from typing import List, Dict, Optional, Union
2223

@@ -1662,6 +1663,10 @@ def __init__(
16621663
)
16631664

16641665

1666+
# works for MODEL_PACKAGE_ARN with or without version info.
1667+
MODEL_PACKAGE_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)(?:/(\d+))?$"
1668+
1669+
16651670
class ModelPackage(Model):
16661671
"""A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
16671672

@@ -1769,14 +1774,19 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
17691774
model_package_name = self._created_model_package_name
17701775
else:
17711776
# When a ModelPackageArn is provided we just create the Model
1772-
model_package_name = self.model_package_arn
1777+
match = re.match(MODEL_PACKAGE_ARN_PATTERN, self.model_package_arn)
1778+
if match:
1779+
model_package_name = match.group(3)
1780+
else:
1781+
# model_package_arn can be just the name if your account owns the Model Package
1782+
model_package_name = self.model_package_arn
17731783

1774-
container_def = {"ModelPackageName": model_package_name}
1784+
container_def = {"ModelPackageName": self.model_package_arn}
17751785

17761786
if self.env != {}:
17771787
container_def["Environment"] = self.env
17781788

1779-
self._ensure_base_name_if_needed(model_package_name.split("/")[-1])
1789+
self._ensure_base_name_if_needed(model_package_name)
17801790
self._set_model_name_if_needed()
17811791

17821792
self.sagemaker_session.create_model(
@@ -1785,6 +1795,7 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
17851795
container_def,
17861796
vpc_config=self.vpc_config,
17871797
enable_network_isolation=self.enable_network_isolation(),
1798+
tags=kwargs.get("tags"),
17881799
)
17891800

17901801
def _ensure_base_name_if_needed(self, base_name):

tests/unit/sagemaker/model/test_model_package.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,42 @@ def test_create_sagemaker_model_uses_model_name(name_from_base, sagemaker_sessio
115115
{"ModelPackageName": model_package_name},
116116
vpc_config=None,
117117
enable_network_isolation=False,
118+
tags=None,
119+
)
120+
121+
122+
@pytest.mark.parametrize(
123+
"model_package_arn",
124+
[
125+
"arn:aws:sagemaker:us-east-2:123:model-package/my-model-package-arn",
126+
"arn:aws:sagemaker:us-east-2:123:model-package/my-model-package-arn/12",
127+
],
128+
)
129+
@patch("sagemaker.utils.name_from_base")
130+
def test_create_sagemaker_model_uses_model_package_arn(
131+
name_from_base, sagemaker_session, model_package_arn
132+
):
133+
model_name = "my-model"
134+
135+
model_package = ModelPackage(
136+
role="role",
137+
name=model_name,
138+
model_package_arn=model_package_arn,
139+
sagemaker_session=sagemaker_session,
140+
)
141+
142+
model_package._create_sagemaker_model()
143+
144+
assert model_name == model_package.name
145+
name_from_base.assert_not_called()
146+
147+
sagemaker_session.create_model.assert_called_with(
148+
model_name,
149+
"role",
150+
{"ModelPackageName": model_package_arn},
151+
vpc_config=None,
152+
enable_network_isolation=False,
153+
tags=None,
118154
)
119155

120156

@@ -141,6 +177,35 @@ def test_create_sagemaker_model_include_environment_variable(sagemaker_session):
141177
{"ModelPackageName": model_package_name, "Environment": environment},
142178
vpc_config=None,
143179
enable_network_isolation=False,
180+
tags=None,
181+
)
182+
183+
184+
def test_create_sagemaker_model_include_tags(sagemaker_session):
185+
model_name = "my-model"
186+
model_package_name = "my-model-package"
187+
env_key = "env_key"
188+
env_value = "env_value"
189+
environment = {env_key: env_value}
190+
tags = {"Key": "foo", "Value": "bar"}
191+
192+
model_package = ModelPackage(
193+
role="role",
194+
name=model_name,
195+
model_package_arn=model_package_name,
196+
env=environment,
197+
sagemaker_session=sagemaker_session,
198+
)
199+
200+
model_package._create_sagemaker_model(tags=tags)
201+
202+
sagemaker_session.create_model.assert_called_with(
203+
model_name,
204+
"role",
205+
{"ModelPackageName": model_package_name, "Environment": environment},
206+
vpc_config=None,
207+
enable_network_isolation=False,
208+
tags=tags,
144209
)
145210

146211

0 commit comments

Comments
 (0)