Skip to content

Commit 801db44

Browse files
Authored by: cj-zhang, Joseph Zhang, gwang111
Fix: ModelBuilder deployment & optimization of JumpStart llama-3.1 models (#4937)
* Emit warning when cpu cores are requested with sharded model deployment. * Reformat sharded model validations. * fix pop on none error in jumpstart draft model flow * set lmi config on js model optimize * re-format lmi config switch * add e2e UT for lmi + .optimize() * add e2e UT for lmi + .optimize() no override * add deep UTs to catch regressions and test E2E fully and more practically * work around flake8 bug * flake8 workaround * fix flake8 syntax error in py38 --------- Co-authored-by: Joseph Zhang <[email protected]> Co-authored-by: Gary Wang 😤 <[email protected]>
1 parent 7c14046 commit 801db44

File tree

8 files changed

+503
-15
lines changed

8 files changed

+503
-15
lines changed

src/sagemaker/jumpstart/model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,20 @@ def deploy(
817817
f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models."
818818
)
819819

820+
# No resources given to deploy() but present 'resources' key in deploy_kwargs means default
821+
# JumpStart resource requirements are being used
822+
if hasattr(self, "_is_sharded_model") and not resources and deploy_kwargs.resources:
823+
if (
824+
self._is_sharded_model
825+
and deploy_kwargs.resources.num_cpus
826+
and deploy_kwargs.resources.num_cpus > 0
827+
):
828+
JUMPSTART_LOGGER.warning(
829+
"NumOfCpuCoresRequired should be 0 for the best experience with SageMaker Fast "
830+
"Model Loading. Overriding the requested `num_cpus` to 0."
831+
)
832+
deploy_kwargs.resources.num_cpus = 0
833+
820834
self.additional_model_data_sources = _add_model_access_configs_to_model_data_sources(
821835
self.additional_model_data_sources,
822836
deploy_kwargs.model_access_configs,

src/sagemaker/jumpstart/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,9 +1595,10 @@ def _add_model_access_configs_to_model_data_sources(
15951595
)
15961596
acked_model_data_sources.append(mutable_model_data_source)
15971597
else:
1598-
mutable_model_data_source.pop(
1599-
"HostingEulaKey"
1600-
) # pop when model access config is not applicable
1598+
if "HostingEulaKey" in mutable_model_data_source:
1599+
mutable_model_data_source.pop(
1600+
"HostingEulaKey"
1601+
) # pop when model access config is not applicable
16011602
acked_model_data_sources.append(mutable_model_data_source)
16021603
return acked_model_data_sources
16031604

src/sagemaker/model.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,18 +1600,25 @@ def deploy(
16001600
if self._base_name is not None:
16011601
self._base_name = "-".join((self._base_name, compiled_model_suffix))
16021602

1603-
if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
1604-
logging.warning(
1605-
"Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
1606-
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
1607-
)
1608-
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
1603+
if self._is_sharded_model:
1604+
if endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
1605+
logging.warning(
1606+
"Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
1607+
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
1608+
)
1609+
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
16091610

1610-
if self._is_sharded_model and self._enable_network_isolation:
1611-
raise ValueError(
1612-
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
1613-
"Loading of model requires network access."
1614-
)
1611+
if self._enable_network_isolation:
1612+
raise ValueError(
1613+
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
1614+
"Loading of model requires network access."
1615+
)
1616+
1617+
if resources and resources.num_cpus and resources.num_cpus > 0:
1618+
logger.warning(
1619+
"NumberOfCpuCoresRequired should be 0 for the best experience with SageMaker "
1620+
"Fast Model Loading. Configure by setting `num_cpus` to 0 in `resources`."
1621+
)
16151622

16161623
# Support multiple models on same endpoint
16171624
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
@@ -1655,7 +1662,7 @@ def deploy(
16551662
vpc_config=self.vpc_config,
16561663
enable_network_isolation=self._enable_network_isolation,
16571664
role=self.role,
1658-
live_logging=endpoint_logging,
1665+
live_logging=False, # TODO: enable when IC supports this
16591666
wait=wait,
16601667
)
16611668

src/sagemaker/serve/builder/model_builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,6 +1302,10 @@ def _model_builder_optimize_wrapper(
13021302
job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}"
13031303
if self._is_jumpstart_model_id():
13041304
self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)
1305+
if self.pysdk_model:
1306+
self.pysdk_model.set_deployment_config(
1307+
instance_type=instance_type, config_name="lmi"
1308+
)
13051309
input_args = self._optimize_for_jumpstart(
13061310
output_path=output_path,
13071311
instance_type=instance_type,
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
from unittest.mock import MagicMock, patch, ANY
15+
16+
from sagemaker.session import Session
17+
from sagemaker.serve.builder.model_builder import ModelBuilder
18+
from sagemaker.serve.builder.schema_builder import SchemaBuilder
19+
from sagemaker.resource_requirements import ResourceRequirements
20+
21+
ROLE_NAME = "SageMakerRole"
22+
23+
24+
def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_expected(
25+
sagemaker_session,
26+
):
27+
with patch.object(
28+
Session, "create_model", return_value="mock_model"
29+
) as mock_create_model, patch.object(
30+
Session, "endpoint_from_production_variants"
31+
) as mock_endpoint_from_production_variants:
32+
iam_client = sagemaker_session.boto_session.client("iam")
33+
role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]
34+
35+
schema_builder = SchemaBuilder("test", "test")
36+
model_builder = ModelBuilder(
37+
model="meta-textgeneration-llama-3-1-8b-instruct",
38+
schema_builder=schema_builder,
39+
sagemaker_session=sagemaker_session,
40+
role_arn=role_arn,
41+
)
42+
43+
optimized_model = model_builder.optimize(
44+
instance_type="ml.g5.xlarge", # set to small instance in case a network call is made
45+
speculative_decoding_config={
46+
"ModelProvider": "JumpStart",
47+
"ModelID": "meta-textgeneration-llama-3-2-1b",
48+
"AcceptEula": True,
49+
},
50+
accept_eula=True,
51+
)
52+
53+
optimized_model.deploy()
54+
55+
mock_create_model.assert_called_once_with(
56+
name=ANY,
57+
role=ANY,
58+
container_defs={
59+
"Image": ANY,
60+
"Environment": {
61+
"SAGEMAKER_PROGRAM": "inference.py",
62+
"ENDPOINT_SERVER_TIMEOUT": "3600",
63+
"MODEL_CACHE_ROOT": "/opt/ml/model",
64+
"SAGEMAKER_ENV": "1",
65+
"HF_MODEL_ID": "/opt/ml/model",
66+
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
67+
"OPTION_SPECULATIVE_DRAFT_MODEL": "/opt/ml/additional-model-data-sources/draft_model/",
68+
},
69+
"AdditionalModelDataSources": [
70+
{
71+
"ChannelName": "draft_model",
72+
"S3DataSource": {
73+
"S3Uri": ANY,
74+
"S3DataType": "S3Prefix",
75+
"CompressionType": "None",
76+
"ModelAccessConfig": {"AcceptEula": True},
77+
},
78+
}
79+
],
80+
"ModelDataSource": {
81+
"S3DataSource": {
82+
"S3Uri": ANY,
83+
"S3DataType": "S3Prefix",
84+
"CompressionType": "None",
85+
"ModelAccessConfig": {"AcceptEula": True},
86+
}
87+
},
88+
},
89+
vpc_config=None,
90+
enable_network_isolation=True,
91+
tags=ANY,
92+
)
93+
mock_endpoint_from_production_variants.assert_called_once()
94+
95+
96+
def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_expected(
97+
sagemaker_session,
98+
):
99+
with patch.object(
100+
Session,
101+
"wait_for_optimization_job",
102+
return_value={"OptimizationJobName": "mock_optimization_job"},
103+
), patch.object(
104+
Session, "create_model", return_value="mock_model"
105+
) as mock_create_model, patch.object(
106+
Session, "endpoint_from_production_variants", return_value="mock_endpoint_name"
107+
) as mock_endpoint_from_production_variants, patch.object(
108+
Session, "create_inference_component"
109+
) as mock_create_inference_component:
110+
iam_client = sagemaker_session.boto_session.client("iam")
111+
role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]
112+
113+
sagemaker_session.sagemaker_client.create_optimization_job = MagicMock()
114+
115+
schema_builder = SchemaBuilder("test", "test")
116+
model_builder = ModelBuilder(
117+
model="meta-textgeneration-llama-3-1-8b-instruct",
118+
schema_builder=schema_builder,
119+
sagemaker_session=sagemaker_session,
120+
role_arn=role_arn,
121+
)
122+
123+
optimized_model = model_builder.optimize(
124+
instance_type="ml.g5.xlarge", # set to small instance in case a network call is made
125+
sharding_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "8"}},
126+
accept_eula=True,
127+
)
128+
129+
optimized_model.deploy(
130+
resources=ResourceRequirements(requests={"memory": 196608, "num_accelerators": 8})
131+
)
132+
133+
mock_create_model.assert_called_once_with(
134+
name=ANY,
135+
role=ANY,
136+
container_defs={
137+
"Image": ANY,
138+
"Environment": {
139+
"SAGEMAKER_PROGRAM": "inference.py",
140+
"ENDPOINT_SERVER_TIMEOUT": "3600",
141+
"MODEL_CACHE_ROOT": "/opt/ml/model",
142+
"SAGEMAKER_ENV": "1",
143+
"HF_MODEL_ID": "/opt/ml/model",
144+
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
145+
"OPTION_TENSOR_PARALLEL_DEGREE": "8",
146+
},
147+
"ModelDataSource": {
148+
"S3DataSource": {
149+
"S3Uri": ANY,
150+
"S3DataType": "S3Prefix",
151+
"CompressionType": "None",
152+
"ModelAccessConfig": {"AcceptEula": True},
153+
}
154+
},
155+
},
156+
vpc_config=None,
157+
enable_network_isolation=False, # should be set to false
158+
tags=ANY,
159+
)
160+
mock_endpoint_from_production_variants.assert_called_once_with(
161+
name=ANY,
162+
production_variants=ANY,
163+
tags=ANY,
164+
kms_key=ANY,
165+
vpc_config=ANY,
166+
enable_network_isolation=False,
167+
role=ANY,
168+
live_logging=False, # this should be set to false for IC
169+
wait=True,
170+
)
171+
mock_create_inference_component.assert_called_once()
172+
173+
174+
def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are_expected(
175+
sagemaker_session,
176+
):
177+
with patch.object(
178+
Session,
179+
"wait_for_optimization_job",
180+
return_value={"OptimizationJobName": "mock_optimization_job"},
181+
), patch.object(
182+
Session, "create_model", return_value="mock_model"
183+
) as mock_create_model, patch.object(
184+
Session, "endpoint_from_production_variants", return_value="mock_endpoint_name"
185+
) as mock_endpoint_from_production_variants:
186+
iam_client = sagemaker_session.boto_session.client("iam")
187+
role_arn = iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"]
188+
189+
sagemaker_session.sagemaker_client.create_optimization_job = MagicMock()
190+
191+
schema_builder = SchemaBuilder("test", "test")
192+
model_builder = ModelBuilder(
193+
model="meta-textgeneration-llama-3-1-8b-instruct",
194+
schema_builder=schema_builder,
195+
sagemaker_session=sagemaker_session,
196+
role_arn=role_arn,
197+
)
198+
199+
optimized_model = model_builder.optimize(
200+
instance_type="ml.g5.xlarge", # set to small instance in case a network call is made
201+
quantization_config={
202+
"OverrideEnvironment": {
203+
"OPTION_QUANTIZE": "fp8",
204+
},
205+
},
206+
accept_eula=True,
207+
)
208+
209+
optimized_model.deploy()
210+
211+
mock_create_model.assert_called_once_with(
212+
name=ANY,
213+
role=ANY,
214+
container_defs={
215+
"Image": ANY,
216+
"Environment": {
217+
"SAGEMAKER_PROGRAM": "inference.py",
218+
"ENDPOINT_SERVER_TIMEOUT": "3600",
219+
"MODEL_CACHE_ROOT": "/opt/ml/model",
220+
"SAGEMAKER_ENV": "1",
221+
"HF_MODEL_ID": "/opt/ml/model",
222+
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
223+
"OPTION_QUANTIZE": "fp8",
224+
},
225+
"ModelDataSource": {
226+
"S3DataSource": {
227+
"S3Uri": ANY,
228+
"S3DataType": "S3Prefix",
229+
"CompressionType": "None",
230+
"ModelAccessConfig": {"AcceptEula": True},
231+
}
232+
},
233+
},
234+
vpc_config=None,
235+
enable_network_isolation=True, # should be set to false
236+
tags=ANY,
237+
)
238+
mock_endpoint_from_production_variants.assert_called_once()

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2318,6 +2318,28 @@ def test_multiple_gated_additional_model_data_source_should_accept_both(self):
23182318
+ self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
23192319
)
23202320

2321+
def test_gated_additional_model_data_source_already_accepted_with_no_hosting_eula_key_should_pass_through(
2322+
self,
2323+
):
2324+
mock_gated_deploy_config_additional_model_data_pre_accepted = [
2325+
{
2326+
"ChannelName": "draft_model",
2327+
"S3DataSource": {
2328+
"CompressionType": "None",
2329+
"S3DataType": "S3Prefix",
2330+
"S3Uri": "s3://jumpstart_bucket/path/to/gated/resources/",
2331+
"ModelAccessConfig": {"AcceptEula": True},
2332+
},
2333+
}
2334+
]
2335+
2336+
utils._add_model_access_configs_to_model_data_sources(
2337+
model_data_sources=mock_gated_deploy_config_additional_model_data_pre_accepted,
2338+
model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=False)},
2339+
model_id=self.MOCK_GATED_MODEL_ID,
2340+
region=JUMPSTART_DEFAULT_REGION_NAME,
2341+
)
2342+
23212343
# Mixed Positive Cases
23222344

23232345
def test_multiple_mixed_additional_model_data_source_should_pass_through_one_accept_the_other(

0 commit comments

Comments (0)