Skip to content

Commit 2345fb0

Browse files
authored
feat: meta llama fine tuning (#4021)
* feat: meta llama fine tuning * chore: simplify code * chore: rename function to _retrieve_model_package_model_artifact_s3_uri * chore: add unit tests, fixes * fix: env var collisions, unit test * chore: de morgans law simplification * chore: address PR comments * chore: improve docstring * fix: better metadata for unit tests * fix: flake8 issues
1 parent 93e6e3c commit 2345fb0

File tree

12 files changed

+458
-94
lines changed

12 files changed

+458
-94
lines changed

src/sagemaker/jumpstart/artifacts/__init__.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,21 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""This module imports all JumpStart artifact functions from the respective sub-module."""
14-
from sagemaker.jumpstart.artifacts.prepack import _model_supports_prepacked_inference # noqa: F401
1514
from sagemaker.jumpstart.artifacts.resource_names import ( # noqa: F401
1615
_retrieve_resource_name_base,
1716
)
1817
from sagemaker.jumpstart.artifacts.incremental_training import ( # noqa: F401
1918
_model_supports_incremental_training,
2019
)
2120
from sagemaker.jumpstart.artifacts.image_uris import _retrieve_image_uri # noqa: F401
22-
from sagemaker.jumpstart.artifacts.script_uris import _retrieve_script_uri # noqa: F401
23-
from sagemaker.jumpstart.artifacts.model_uris import _retrieve_model_uri # noqa: F401
21+
from sagemaker.jumpstart.artifacts.script_uris import ( # noqa: F401
22+
_retrieve_script_uri,
23+
_model_supports_inference_script_uri,
24+
)
25+
from sagemaker.jumpstart.artifacts.model_uris import ( # noqa: F401
26+
_retrieve_model_uri,
27+
_model_supports_training_model_uri,
28+
)
2429
from sagemaker.jumpstart.artifacts.hyperparameters import ( # noqa: F401
2530
_retrieve_default_hyperparameters,
2631
)
@@ -52,4 +57,7 @@
5257
_retrieve_supported_accept_types,
5358
_retrieve_supported_content_types,
5459
)
55-
from sagemaker.jumpstart.artifacts.model_packages import _retrieve_model_package_arn # noqa: F401
60+
from sagemaker.jumpstart.artifacts.model_packages import ( # noqa: F401
61+
_retrieve_model_package_arn,
62+
_retrieve_model_package_model_artifact_s3_uri,
63+
)

src/sagemaker/jumpstart/artifacts/model_packages.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,61 @@ def _retrieve_model_package_arn(
7575
return regional_arn
7676

7777
raise NotImplementedError(f"Model Package ARN not supported for scope: '{scope}'")
78+
79+
80+
def _retrieve_model_package_model_artifact_s3_uri(
    model_id: str,
    model_version: str,
    region: Optional[str],
    scope: Optional[str] = None,
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
) -> Optional[str]:
    """Looks up the S3 URI of the model package artifact for a JumpStart model.

    Only the training scope is handled; every other scope raises.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the model package artifact.
        model_version (str): Version of the JumpStart model for which to retrieve the
            model package artifact.
        region (Optional[str]): Region for which to retrieve the model package artifact.
            When None, ``JUMPSTART_DEFAULT_REGION_NAME`` is used instead.
        scope (Optional[str]): Scope for which to retrieve the model package artifact.
            (Default: None).
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).

    Returns:
        Optional[str]: The model package artifact URI registered for the region, or None
            when the model specs carry no training model package artifact URIs or have
            no entry for the region.

    Raises:
        NotImplementedError: If ``scope`` is not the training scope.
    """
    # Guard first: anything other than training is unsupported.
    if scope != JumpStartScriptScope.TRAINING:
        raise NotImplementedError(
            f"Model Package Artifact URI not supported for scope: '{scope}'"
        )

    resolved_region = JUMPSTART_DEFAULT_REGION_NAME if region is None else region

    specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        scope=scope,
        region=resolved_region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
    )

    regional_artifact_uris = specs.training_model_package_artifact_uris
    if regional_artifact_uris is None:
        return None

    # `.get` yields None when the resolved region has no entry.
    return regional_artifact_uris.get(resolved_region)

src/sagemaker/jumpstart/artifacts/model_uris.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,45 @@ def _retrieve_model_uri(
9292
model_s3_uri = f"s3://{bucket}/{model_artifact_key}"
9393

9494
return model_s3_uri
95+
96+
97+
def _model_supports_training_model_uri(
    model_id: str,
    model_version: str,
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
) -> bool:
    """Reports whether the model is trained with the model uri field.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the support status for model uri with training.
        model_version (str): Version of the JumpStart model for which to retrieve the
            support status for model uri with training.
        region (Optional[str]): Region for which to retrieve the support status for
            model uri with training. When None, ``JUMPSTART_DEFAULT_REGION_NAME`` is used.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).

    Returns:
        bool: The value of ``use_training_model_artifact()`` on the model specs
            retrieved for the training scope.
    """
    training_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        scope=JumpStartScriptScope.TRAINING,
        region=JUMPSTART_DEFAULT_REGION_NAME if region is None else region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
    )
    return training_specs.use_training_model_artifact()

src/sagemaker/jumpstart/artifacts/prepack.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

src/sagemaker/jumpstart/artifacts/script_uris.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,45 @@ def _retrieve_script_uri(
9090
script_s3_uri = f"s3://{bucket}/{model_script_key}"
9191

9292
return script_s3_uri
93+
94+
95+
def _model_supports_inference_script_uri(
    model_id: str,
    model_version: str,
    region: Optional[str],
    tolerate_vulnerable_model: bool = False,
    tolerate_deprecated_model: bool = False,
) -> bool:
    """Reports whether the model runs inference with the script uri field.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the support status for script uri with inference.
        model_version (str): Version of the JumpStart model for which to retrieve the
            support status for script uri with inference.
        region (Optional[str]): Region for which to retrieve the support status for
            script uri with inference. When None, ``JUMPSTART_DEFAULT_REGION_NAME`` is used.
        tolerate_vulnerable_model (bool): True if vulnerable versions of model
            specifications should be tolerated (exception not raised). If False, raises an
            exception if the script used by this version of the model has dependencies with
            known security vulnerabilities. (Default: False).
        tolerate_deprecated_model (bool): True if deprecated versions of model
            specifications should be tolerated (exception not raised). If False, raises
            an exception if the version of the model is deprecated. (Default: False).

    Returns:
        bool: The value of ``use_inference_script_uri()`` on the model specs
            retrieved for the inference scope.
    """
    inference_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        scope=JumpStartScriptScope.INFERENCE,
        region=JUMPSTART_DEFAULT_REGION_NAME if region is None else region,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
    )
    return inference_specs.use_inference_script_uri()

src/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@
147147

148148
JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart"
149149

150+
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY = "SageMakerGatedModelS3Uri"
150151

151152
CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP: Dict[MIMEType, SerializerType] = {
152153
MIMEType.X_IMAGE: SerializerType.RAW_BYTES,

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
model_uris,
2424
script_uris,
2525
)
26-
from sagemaker.jumpstart.artifacts import _model_supports_incremental_training
26+
from sagemaker.jumpstart.artifacts import (
27+
_model_supports_incremental_training,
28+
_retrieve_model_package_model_artifact_s3_uri,
29+
)
2730
from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
2831
from sagemaker.session import Session
2932
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
@@ -37,11 +40,13 @@
3740
from sagemaker.jumpstart.artifacts import (
3841
_retrieve_estimator_init_kwargs,
3942
_retrieve_estimator_fit_kwargs,
43+
_model_supports_training_model_uri,
4044
)
4145
from sagemaker.jumpstart.constants import (
4246
JUMPSTART_DEFAULT_REGION_NAME,
4347
JUMPSTART_LOGGER,
4448
TRAINING_ENTRY_POINT_SCRIPT_NAME,
49+
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
4550
)
4651
from sagemaker.jumpstart.enums import JumpStartScriptScope
4752
from sagemaker.jumpstart.factory import model
@@ -187,6 +192,7 @@ def get_init_kwargs(
187192
estimator_init_kwargs = _add_metric_definitions_to_kwargs(estimator_init_kwargs)
188193
estimator_init_kwargs = _add_estimator_extra_kwargs(estimator_init_kwargs)
189194
estimator_init_kwargs = _add_role_to_kwargs(estimator_init_kwargs)
195+
estimator_init_kwargs = _add_env_to_kwargs(estimator_init_kwargs)
190196

191197
return estimator_init_kwargs
192198

@@ -446,32 +452,39 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
446452
def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
    """Sets model uri in kwargs based on default or override, returns full kwargs.

    When the model does not train with a model uri (as reported by
    ``_model_supports_training_model_uri``), ``kwargs`` is returned untouched.
    Otherwise the default training model uri is retrieved and used unless the
    caller already supplied ``kwargs.model_uri``. A warning is logged when the
    caller supplies a non-default uri for a model that does not support
    incremental training.

    Args:
        kwargs (JumpStartEstimatorInitKwargs): Estimator init kwargs; ``model_uri``
            may be set in place.

    Returns:
        JumpStartEstimatorInitKwargs: The same ``kwargs`` object.
    """
    if not _model_supports_training_model_uri(
        model_id=kwargs.model_id,
        model_version=kwargs.model_version,
        region=kwargs.region,
        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
    ):
        return kwargs

    # Resolve the default artifact for the same region used by the support
    # checks in this function; without `region` the lookup is not tied to
    # `kwargs.region`.
    # NOTE(review): assumes `model_uris.retrieve` accepts `region` and falls
    # back to the default region when it is None -- confirm against its signature.
    default_model_uri = model_uris.retrieve(
        region=kwargs.region,
        model_scope=JumpStartScriptScope.TRAINING,
        model_id=kwargs.model_id,
        model_version=kwargs.model_version,
        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
    )

    if (
        kwargs.model_uri is not None
        and kwargs.model_uri != default_model_uri
        and not _model_supports_incremental_training(
            model_id=kwargs.model_id,
            model_version=kwargs.model_version,
            region=kwargs.region,
            tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
            tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
        )
    ):
        JUMPSTART_LOGGER.warning(
            "'%s' does not support incremental training but is being trained with"
            " non-default model artifact.",
            kwargs.model_id,
        )

    # A caller-supplied uri always wins over the default.
    kwargs.model_uri = kwargs.model_uri or default_model_uri

    return kwargs
477490

@@ -501,6 +514,31 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart
501514
return kwargs
502515

503516

517+
def _add_env_to_kwargs(
    kwargs: JumpStartEstimatorInitKwargs,
) -> JumpStartEstimatorInitKwargs:
    """Sets environment in kwargs based on default or override, returns full kwargs.

    If a training model package artifact URI is found for the model, it is added to
    ``kwargs.environment`` under ``SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY``.
    Entries already present in ``kwargs.environment`` take precedence over that
    default. When no URI is found, ``kwargs`` is returned unchanged.
    """
    gated_artifact_uri = _retrieve_model_package_model_artifact_s3_uri(
        model_id=kwargs.model_id,
        model_version=kwargs.model_version,
        region=kwargs.region,
        scope=JumpStartScriptScope.TRAINING,
        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
    )

    if not gated_artifact_uri:
        return kwargs

    caller_env = {} if kwargs.environment is None else kwargs.environment

    # Default goes first so any caller-provided value for the same key overrides it.
    kwargs.environment = {
        SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY: gated_artifact_uri,
        **caller_env,
    }
    return kwargs
540+
541+
504542
def _add_entry_point_to_kwargs(
505543
kwargs: JumpStartEstimatorInitKwargs,
506544
) -> JumpStartEstimatorInitKwargs:

0 commit comments

Comments
 (0)