Skip to content

Commit c9a0875

Browse files
author
Joseph Zhang
committed
Fix quantization/compilation config handling for optimize().
1 parent 828ad60 commit c9a0875

File tree

4 files changed

+47
-27
lines changed

4 files changed

+47
-27
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -727,8 +727,8 @@ def _optimize_for_jumpstart(
727727
pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)
728728

729729
# optimization_config can contain configs for both quantization and compilation
730-
optimization_config, override_env = _extract_optimization_config_and_env(
731-
quantization_config, compilation_config
730+
optimization_config, quantization_override_env, compilation_override_env = (
731+
_extract_optimization_config_and_env(quantization_config, compilation_config)
732732
)
733733
if (
734734
not optimization_config or not optimization_config.get("ModelCompilationConfig")
@@ -738,13 +738,16 @@ def _optimize_for_jumpstart(
738738
optimization_config = {}
739739

740740
# Fallback to default if compilation_override_env is None or empty
741-
if not override_env:
742-
override_env = pysdk_model_env_vars
741+
if not compilation_override_env:
742+
compilation_override_env = pysdk_model_env_vars
743743

744744
# Update optimization_config with ModelCompilationConfig
745-
optimization_config["ModelCompilationConfig"] = {
746-
"OverrideEnvironment": override_env,
747-
}
745+
override_compilation_config = (
746+
{"OverrideEnvironment": compilation_override_env}
747+
if compilation_override_env
748+
else {}
749+
)
750+
optimization_config["ModelCompilationConfig"] = override_compilation_config
748751

749752
if speculative_decoding_config:
750753
self._set_additional_model_source(speculative_decoding_config)
@@ -798,7 +801,13 @@ def _optimize_for_jumpstart(
798801
"AcceptEula": True
799802
}
800803

801-
optimization_env_vars = _update_environment_variables(optimization_env_vars, override_env)
804+
optimization_env_vars = _update_environment_variables(
805+
optimization_env_vars,
806+
{
807+
**(quantization_override_env or {}),
808+
**(compilation_override_env or {}),
809+
},
810+
)
802811
if optimization_env_vars:
803812
self.pysdk_model.env.update(optimization_env_vars)
804813
if quantization_config or is_compilation:

src/sagemaker/serve/builder/model_builder.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,13 +1361,18 @@ def _optimize_for_hf(
13611361
model_source = _generate_model_source(self.pysdk_model.model_data, False)
13621362
create_optimization_job_args["ModelSource"] = model_source
13631363

1364-
optimization_config, override_env = _extract_optimization_config_and_env(
1365-
quantization_config, compilation_config
1364+
optimization_config, quantization_override_env, compilation_override_env = (
1365+
_extract_optimization_config_and_env(quantization_config, compilation_config)
13661366
)
13671367
create_optimization_job_args["OptimizationConfigs"] = [
13681368
{k: v} for k, v in optimization_config.items()
13691369
]
1370-
self.pysdk_model.env.update(override_env)
1370+
self.pysdk_model.env.update(
1371+
{
1372+
**(quantization_override_env or {}),
1373+
**(compilation_override_env or {}),
1374+
}
1375+
)
13711376

13721377
output_config = {"S3OutputLocation": output_path}
13731378
if kms_key:

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool:
260260

261261
def _extract_optimization_config_and_env(
262262
quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None
263-
) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]:
263+
) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]:
264264
"""Extracts optimization config and environment variables.
265265
266266
Args:
@@ -272,25 +272,24 @@ def _extract_optimization_config_and_env(
272272
The optimization config, the quantization override environment variables, and the compilation override environment variables.
273273
"""
274274
optimization_config = {}
275-
quantization_override_env = {}
276-
compilation_override_env = {}
275+
quantization_override_env = (
276+
quantization_config.get("OverrideEnvironment", {}) if quantization_config else None
277+
)
278+
compilation_override_env = (
279+
compilation_config.get("OverrideEnvironment", {}) if compilation_config else None
280+
)
277281

278-
if quantization_config:
282+
if quantization_config is not None:
279283
optimization_config["ModelQuantizationConfig"] = quantization_config
280-
quantization_override_env = quantization_config.get("OverrideEnvironment")
281284

282-
if compilation_config:
285+
if compilation_config is not None:
283286
optimization_config["ModelCompilationConfig"] = compilation_config
284-
compilation_override_env = compilation_config.get("OverrideEnvironment")
285287

286288
# Return the config and both override environments if either config was provided
287289
if optimization_config:
288-
return optimization_config, {
289-
**(quantization_override_env or {}),
290-
**(compilation_override_env or {}),
291-
}
290+
return optimization_config, quantization_override_env, compilation_override_env
292291

293-
return None, None
292+
return None, None, None
294293

295294

296295
def _custom_speculative_decoding(

tests/unit/sagemaker/serve/utils/test_optimize_utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def test_is_s3_uri(s3_uri, expected):
261261

262262

263263
@pytest.mark.parametrize(
264-
"quantization_config, compilation_config, expected_config, expected_env",
264+
"quantization_config, compilation_config, expected_config, expected_quant_env, expected_compilation_env",
265265
[
266266
(
267267
None,
@@ -277,6 +277,7 @@ def test_is_s3_uri(s3_uri, expected):
277277
}
278278
},
279279
},
280+
None,
280281
{
281282
"OPTION_TENSOR_PARALLEL_DEGREE": "2",
282283
},
@@ -298,16 +299,22 @@ def test_is_s3_uri(s3_uri, expected):
298299
{
299300
"OPTION_TENSOR_PARALLEL_DEGREE": "2",
300301
},
302+
None,
301303
),
302-
(None, None, None, None),
304+
(None, None, None, None, None),
303305
],
304306
)
305307
def test_extract_optimization_config_and_env(
306-
quantization_config, compilation_config, expected_config, expected_env
308+
quantization_config,
309+
compilation_config,
310+
expected_config,
311+
expected_quant_env,
312+
expected_compilation_env,
307313
):
308314
assert _extract_optimization_config_and_env(quantization_config, compilation_config) == (
309315
expected_config,
310-
expected_env,
316+
expected_quant_env,
317+
expected_compilation_env,
311318
)
312319

313320

0 commit comments

Comments
 (0)