Skip to content

Commit 71468e5

Browse files
authored
Fetch HF metadata only when explicit type is not selected (#4407)
1 parent 8eea80a commit 71468e5

File tree

4 files changed

+14
-138
lines changed

4 files changed

+14
-138
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -579,19 +579,19 @@ def build(
579579

580580
self.serve_settings = self._get_serve_setting()
581581

582-
hf_model_md = get_huggingface_model_metadata(
583-
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
584-
)
585-
586582
if isinstance(self.model, str):
587583
if self._is_jumpstart_model_id():
588584
return self._build_for_jumpstart()
589-
if self._is_djl():
585+
if self._is_djl(): # pylint: disable=R1705
590586
return self._build_for_djl()
591-
if hf_model_md.get("pipeline_tag") == "text-generation": # pylint: disable=R1705
592-
return self._build_for_tgi()
593587
else:
594-
return self._build_for_transformers()
588+
hf_model_md = get_huggingface_model_metadata(
589+
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
590+
)
591+
if hf_model_md.get("pipeline_tag") == "text-generation": # pylint: disable=R1705
592+
return self._build_for_tgi()
593+
else:
594+
return self._build_for_transformers()
595595

596596
self._build_validations()
597597

tests/integ/sagemaker/serve/test_serve_js_happy.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from __future__ import absolute_import
1414

1515
import pytest
16-
from unittest.mock import patch, Mock
1716
from sagemaker.serve.builder.model_builder import ModelBuilder
1817
from sagemaker.serve.builder.schema_builder import SchemaBuilder
1918
from tests.integ.sagemaker.serve.constants import (
@@ -33,7 +32,6 @@
3332
]
3433
JS_MODEL_ID = "huggingface-textgeneration1-gpt-neo-125m-fp16"
3534
ROLE_NAME = "SageMakerRole"
36-
MOCK_HF_MODEL_METADATA_JSON = {"mock_key": "mock_value"}
3735

3836

3937
@pytest.fixture
@@ -47,23 +45,14 @@ def happy_model_builder(sagemaker_session):
4745
)
4846

4947

50-
@patch("sagemaker.huggingface.llm_utils.urllib")
51-
@patch("sagemaker.huggingface.llm_utils.json")
5248
@pytest.mark.skipif(
5349
PYTHON_VERSION_IS_NOT_310,
5450
reason="The goal of these test are to test the serving components of our feature",
5551
)
5652
@pytest.mark.slow_test
57-
def test_happy_tgi_sagemaker_endpoint(
58-
mock_urllib, mock_json, happy_model_builder, gpu_instance_type
59-
):
53+
def test_happy_tgi_sagemaker_endpoint(happy_model_builder, gpu_instance_type):
6054
logger.info("Running in SAGEMAKER_ENDPOINT mode...")
6155
caught_ex = None
62-
63-
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
64-
mock_hf_model_metadata_url = Mock()
65-
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
66-
6756
model = happy_model_builder.build()
6857

6958
with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):

tests/integ/sagemaker/serve/test_serve_pt_happy.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import io
2020
import numpy as np
2121

22-
from unittest.mock import patch, Mock
2322
from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
2423
from sagemaker.serve.builder.schema_builder import SchemaBuilder, CustomPayloadTranslator
2524
from sagemaker.serve.spec.inference_spec import InferenceSpec
@@ -38,7 +37,6 @@
3837
logger = logging.getLogger(__name__)
3938

4039
ROLE_NAME = "SageMakerRole"
41-
MOCK_HF_MODEL_METADATA_JSON = {"mock_key": "mock_value"}
4240

4341

4442
@pytest.fixture
@@ -183,8 +181,6 @@ def model_builder(request):
183181
# ), f"{caught_ex} was thrown when running pytorch squeezenet local container test"
184182

185183

186-
@patch("sagemaker.huggingface.llm_utils.urllib")
187-
@patch("sagemaker.huggingface.llm_utils.json")
188184
@pytest.mark.skipif(
189185
PYTHON_VERSION_IS_NOT_310, # or NOT_RUNNING_ON_INF_EXP_DEV_PIPELINE,
190186
reason="The goal of these test are to test the serving components of our feature",
@@ -194,17 +190,12 @@ def model_builder(request):
194190
)
195191
@pytest.mark.slow_test
196192
def test_happy_pytorch_sagemaker_endpoint(
197-
mock_urllib,
198-
mock_json,
199193
sagemaker_session,
200194
model_builder,
201195
cpu_instance_type,
202196
test_image,
203197
):
204198
logger.info("Running in SAGEMAKER_ENDPOINT mode...")
205-
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
206-
mock_hf_model_metadata_url = Mock()
207-
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
208199
caught_ex = None
209200

210201
iam_client = sagemaker_session.boto_session.client("iam")

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 5 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
mock_s3_model_data_url = "sample s3 data url"
4343
mock_secret_key = "mock_secret_key"
4444
mock_instance_type = "mock instance type"
45-
MOCK_HF_MODEL_METADATA_JSON = {"mock_key": "mock_value"}
4645

4746
supported_model_server = {
4847
ModelServer.TORCHSERVE,
@@ -55,15 +54,7 @@
5554

5655
class TestModelBuilder(unittest.TestCase):
5756
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
58-
@patch("sagemaker.huggingface.llm_utils.urllib")
59-
@patch("sagemaker.huggingface.llm_utils.json")
60-
def test_validation_in_progress_mode_not_supported(
61-
self, mock_serveSettings, mock_urllib, mock_json
62-
):
63-
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
64-
mock_hf_model_metadata_url = Mock()
65-
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
66-
57+
def test_validation_in_progress_mode_not_supported(self, mock_serveSettings):
6758
builder = ModelBuilder()
6859
self.assertRaisesRegex(
6960
Exception,
@@ -75,15 +66,7 @@ def test_validation_in_progress_mode_not_supported(
7566
)
7667

7768
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
78-
@patch("sagemaker.huggingface.llm_utils.urllib")
79-
@patch("sagemaker.huggingface.llm_utils.json")
80-
def test_validation_cannot_set_both_model_and_inference_spec(
81-
self, mock_serveSettings, mock_urllib, mock_json
82-
):
83-
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
84-
mock_hf_model_metadata_url = Mock()
85-
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
86-
69+
def test_validation_cannot_set_both_model_and_inference_spec(self, mock_serveSettings):
8770
builder = ModelBuilder(inference_spec="some value", model=Mock(spec=object))
8871
self.assertRaisesRegex(
8972
Exception,
@@ -95,15 +78,7 @@ def test_validation_cannot_set_both_model_and_inference_spec(
9578
)
9679

9780
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
98-
@patch("sagemaker.huggingface.llm_utils.urllib")
99-
@patch("sagemaker.huggingface.llm_utils.json")
100-
def test_validation_unsupported_model_server_type(
101-
self, mock_serveSettings, mock_urllib, mock_json
102-
):
103-
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
104-
mock_hf_model_metadata_url = Mock()
105-
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
106-
81+
def test_validation_unsupported_model_server_type(self, mock_serveSettings):
10782
builder = ModelBuilder(model_server="invalid_model_server")
10883
self.assertRaisesRegex(
10984
Exception,
@@ -116,15 +91,7 @@ def test_validation_unsupported_model_server_type(
11691
)
11792

11893
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
119-
@patch("sagemaker.huggingface.llm_utils.urllib")
120-
@patch("sagemaker.huggingface.llm_utils.json")
121-
def test_validation_model_server_not_set_with_image_uri(
122-
self, mock_serveSettings, mock_urllib, mock_json
123-
):
124-
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
125-
mock_hf_model_metadata_url = Mock()
126-
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
127-
94+
def test_validation_model_server_not_set_with_image_uri(self, mock_serveSettings):
12895
builder = ModelBuilder(image_uri="image_uri")
12996
self.assertRaisesRegex(
13097
Exception,
@@ -137,15 +104,9 @@ def test_validation_model_server_not_set_with_image_uri(
137104
)
138105

139106
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
140-
@patch("sagemaker.huggingface.llm_utils.urllib")
141-
@patch("sagemaker.huggingface.llm_utils.json")
142107
def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set(
143-
self, mock_serveSettings, mock_urllib, mock_json
108+
self, mock_serveSettings
144109
):
145-
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
146-
mock_hf_model_metadata_url = Mock()
147-
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
148-
149110
builder = ModelBuilder(inference_spec=None, model=None)
150111
self.assertRaisesRegex(
151112
Exception,
@@ -165,12 +126,8 @@ def test_save_model_throw_exception_when_none_of_model_and_inference_spec_is_set
165126
@patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
166127
@patch("sagemaker.serve.builder.model_builder.Model")
167128
@patch("os.path.exists")
168-
@patch("sagemaker.huggingface.llm_utils.urllib")
169-
@patch("sagemaker.huggingface.llm_utils.json")
170129
def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
171130
self,
172-
mock_urllib,
173-
mock_json,
174131
mock_path_exists,
175132
mock_sdk_model,
176133
mock_sageMakerEndpointMode,
@@ -189,10 +146,6 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
189146
else None
190147
)
191148

192-
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
193-
mock_hf_model_metadata_url = Mock()
194-
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
195-
196149
mock_detect_fw_version.return_value = framework, version
197150

198151
mock_prepare_for_torchserve.side_effect = (
@@ -273,12 +226,8 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_byoc(
273226
@patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
274227
@patch("sagemaker.serve.builder.model_builder.Model")
275228
@patch("os.path.exists")
276-
@patch("sagemaker.huggingface.llm_utils.urllib")
277-
@patch("sagemaker.huggingface.llm_utils.json")
278229
def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
279230
self,
280-
mock_urllib,
281-
mock_json,
282231
mock_path_exists,
283232
mock_sdk_model,
284233
mock_sageMakerEndpointMode,
@@ -296,11 +245,6 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
296245
and instance_type == "ml.c5.xlarge"
297246
else None
298247
)
299-
300-
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
301-
mock_hf_model_metadata_url = Mock()
302-
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
303-
304248
mock_detect_fw_version.return_value = framework, version
305249

306250
mock_prepare_for_torchserve.side_effect = (
@@ -381,12 +325,8 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_1p_dlc_as_byoc(
381325
@patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
382326
@patch("sagemaker.serve.builder.model_builder.Model")
383327
@patch("os.path.exists")
384-
@patch("sagemaker.huggingface.llm_utils.urllib")
385-
@patch("sagemaker.huggingface.llm_utils.json")
386328
def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
387329
self,
388-
mock_urllib,
389-
mock_json,
390330
mock_path_exists,
391331
mock_sdk_model,
392332
mock_sageMakerEndpointMode,
@@ -402,10 +342,6 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
402342
lambda model_path: mock_native_model if model_path == MODEL_PATH else None
403343
)
404344

405-
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
406-
mock_hf_model_metadata_url = Mock()
407-
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
408-
409345
mock_detect_fw_version.return_value = framework, version
410346

411347
mock_detect_container.side_effect = (
@@ -490,12 +426,8 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_and_inference_spec(
490426
@patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
491427
@patch("sagemaker.serve.builder.model_builder.Model")
492428
@patch("os.path.exists")
493-
@patch("sagemaker.huggingface.llm_utils.urllib")
494-
@patch("sagemaker.huggingface.llm_utils.json")
495429
def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
496430
self,
497-
mock_urllib,
498-
mock_json,
499431
mock_path_exists,
500432
mock_sdk_model,
501433
mock_sageMakerEndpointMode,
@@ -514,10 +446,6 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
514446
else None
515447
)
516448

517-
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
518-
mock_hf_model_metadata_url = Mock()
519-
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
520-
521449
mock_detect_fw_version.return_value = framework, version
522450

523451
mock_prepare_for_torchserve.side_effect = (
@@ -601,12 +529,8 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_model(
601529
@patch("sagemaker.serve.builder.model_builder.SageMakerEndpointMode")
602530
@patch("sagemaker.serve.builder.model_builder.Model")
603531
@patch("os.path.exists")
604-
@patch("sagemaker.huggingface.llm_utils.urllib")
605-
@patch("sagemaker.huggingface.llm_utils.json")
606532
def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
607533
self,
608-
mock_urllib,
609-
mock_json,
610534
mock_path_exists,
611535
mock_sdk_model,
612536
mock_sageMakerEndpointMode,
@@ -626,10 +550,6 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
626550
else None
627551
)
628552

629-
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
630-
mock_hf_model_metadata_url = Mock()
631-
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
632-
633553
mock_detect_fw_version.return_value = "xgboost", version
634554

635555
mock_prepare_for_torchserve.side_effect = (
@@ -714,12 +634,8 @@ def test_build_happy_path_with_sagemakerEndpoint_mode_and_xgboost_model(
714634
@patch("sagemaker.serve.builder.model_builder.LocalContainerMode")
715635
@patch("sagemaker.serve.builder.model_builder.Model")
716636
@patch("os.path.exists")
717-
@patch("sagemaker.huggingface.llm_utils.urllib")
718-
@patch("sagemaker.huggingface.llm_utils.json")
719637
def test_build_happy_path_with_local_container_mode(
720638
self,
721-
mock_urllib,
722-
mock_json,
723639
mock_path_exists,
724640
mock_sdk_model,
725641
mock_localContainerMode,
@@ -734,10 +650,6 @@ def test_build_happy_path_with_local_container_mode(
734650
lambda model_path: mock_native_model if model_path == MODEL_PATH else None
735651
)
736652

737-
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
738-
mock_hf_model_metadata_url = Mock()
739-
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
740-
741653
mock_detect_container.side_effect = (
742654
lambda model, region, instance_type: mock_image_uri
743655
if model == mock_native_model
@@ -816,12 +728,8 @@ def test_build_happy_path_with_local_container_mode(
816728
@patch("sagemaker.serve.builder.model_builder.LocalContainerMode")
817729
@patch("sagemaker.serve.builder.model_builder.Model")
818730
@patch("os.path.exists")
819-
@patch("sagemaker.huggingface.llm_utils.urllib")
820-
@patch("sagemaker.huggingface.llm_utils.json")
821731
def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mode(
822732
self,
823-
mock_urllib,
824-
mock_json,
825733
mock_path_exists,
826734
mock_sdk_model,
827735
mock_localContainerMode,
@@ -838,10 +746,6 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo
838746
lambda model_path: mock_native_model if model_path == MODEL_PATH else None
839747
)
840748

841-
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
842-
mock_hf_model_metadata_url = Mock()
843-
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
844-
845749
mock_detect_fw_version.return_value = framework, version
846750

847751
mock_detect_container.side_effect = (
@@ -965,12 +869,8 @@ def test_build_happy_path_with_localContainer_mode_overwritten_with_sagemaker_mo
965869
@patch("sagemaker.serve.builder.model_builder.LocalContainerMode")
966870
@patch("sagemaker.serve.builder.model_builder.Model")
967871
@patch("os.path.exists")
968-
@patch("sagemaker.huggingface.llm_utils.urllib")
969-
@patch("sagemaker.huggingface.llm_utils.json")
970872
def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_container(
971873
self,
972-
mock_urllib,
973-
mock_json,
974874
mock_path_exists,
975875
mock_sdk_model,
976876
mock_localContainerMode,
@@ -984,10 +884,6 @@ def test_build_happy_path_with_sagemaker_endpoint_mode_overwritten_with_local_co
984884
# setup mocks
985885
mock_detect_fw_version.return_value = framework, version
986886

987-
mock_json.load.return_value = MOCK_HF_MODEL_METADATA_JSON
988-
mock_hf_model_metadata_url = Mock()
989-
mock_urllib.request.Request.side_effect = mock_hf_model_metadata_url
990-
991887
mock_detect_container.side_effect = (
992888
lambda model, region, instance_type: mock_image_uri
993889
if model == mock_fw_model

0 commit comments

Comments
 (0)