Skip to content

Commit 3d1a4f7

Browse files
authored
fix: allow for inf spec and server override to be passed (#4769)
* fix: allow for just inf spec and server overide to pass * fix formatting
1 parent ce10e01 commit 3d1a4f7

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -881,8 +881,8 @@ def _build_for_model_server(self): # pylint: disable=R0911, R1710
881881
if self.model_metadata:
882882
mlflow_path = self.model_metadata.get(MLFLOW_MODEL_PATH)
883883

884-
if not self.model and not mlflow_path:
885-
raise ValueError("Missing required parameter `model` or 'ml_flow' path")
884+
if not self.model and not mlflow_path and not self.inference_spec:
885+
raise ValueError("Missing required parameter `model` or 'ml_flow' path or inf_spec")
886886

887887
if self.model_server == ModelServer.TORCHSERVE:
888888
return self._build_for_torchserve()

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def test_model_server_override_djl_without_model_or_mlflow(self, mock_serve_sett
147147
)
148148
self.assertRaisesRegex(
149149
Exception,
150-
"Missing required parameter `model` or 'ml_flow' path",
150+
"Missing required parameter `model` or 'ml_flow' path or inf_spec",
151151
builder.build,
152152
Mode.SAGEMAKER_ENDPOINT,
153153
mock_role_arn,
@@ -168,12 +168,26 @@ def test_model_server_override_torchserve_with_model(
168168

169169
mock_build_for_ts.assert_called_once()
170170

171+
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
172+
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_torchserve")
173+
def test_model_server_override_torchserve_with_inf_spec(
174+
self, mock_build_for_ts, mock_serve_settings
175+
):
176+
mock_setting_object = mock_serve_settings.return_value
177+
mock_setting_object.role_arn = mock_role_arn
178+
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
179+
180+
builder = ModelBuilder(model_server=ModelServer.TORCHSERVE, inference_spec="some value")
181+
builder.build(sagemaker_session=mock_session)
182+
183+
mock_build_for_ts.assert_called_once()
184+
171185
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
172186
def test_model_server_override_torchserve_without_model_or_mlflow(self, mock_serve_settings):
173187
builder = ModelBuilder(model_server=ModelServer.TORCHSERVE)
174188
self.assertRaisesRegex(
175189
Exception,
176-
"Missing required parameter `model` or 'ml_flow' path",
190+
"Missing required parameter `model` or 'ml_flow' path or inf_spec",
177191
builder.build,
178192
Mode.SAGEMAKER_ENDPOINT,
179193
mock_role_arn,

0 commit comments

Comments
 (0)