Skip to content
6 changes: 5 additions & 1 deletion src/sagemaker/modules/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from __future__ import absolute_import

from typing import Optional, Union
from typing import Optional, Union, List
from pydantic import BaseModel, model_validator, ConfigDict

import sagemaker_core.shapes as shapes
Expand Down Expand Up @@ -96,12 +96,16 @@ class SourceCode(BaseConfig):
command (Optional[str]):
The command(s) to execute in the training job container. Example: "python my_script.py".
If not specified, entry_script must be provided.
ignore_patterns: (Optional[List[str]]) :
The ignore patterns to ignore specific files/folders when uploading to S3. If not specified,
default to: ['.env', '.git', '__pycache__', '.DS_Store'].
"""

source_dir: Optional[str] = None
requirements: Optional[str] = None
entry_script: Optional[str] = None
command: Optional[str] = None
ignore_patterns: Optional[List[str]] = [".env", ".git", "__pycache__", ".DS_Store"]


class Compute(shapes.ResourceConfig):
Expand Down
45 changes: 37 additions & 8 deletions src/sagemaker/modules/train/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ class ModelTrainer(BaseModel):
from sagemaker.modules.train import ModelTrainer
from sagemaker.modules.configs import SourceCode, Compute, InputData

source_code = SourceCode(source_dir="source", entry_script="train.py")
ignore_patterns = ['.env', '.git', '__pycache__', '.DS_Store', 'data']
source_code = SourceCode(source_dir="source", entry_script="train.py", ignore_patterns=ignore_patterns)
training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image"
model_trainer = ModelTrainer(
training_image=training_image,
Expand Down Expand Up @@ -654,6 +655,7 @@ def train(
channel_name=SM_CODE,
data_source=self.source_code.source_dir,
key_prefix=input_data_key_prefix,
ignore_patterns=self.source_code.ignore_patterns,
)
final_input_data_config.append(source_code_channel)

Expand All @@ -675,6 +677,7 @@ def train(
channel_name=SM_DRIVERS,
data_source=tmp_dir.name,
key_prefix=input_data_key_prefix,
ignore_patterns=self.source_code.ignore_patterns,
)
final_input_data_config.append(sm_drivers_channel)

Expand Down Expand Up @@ -755,7 +758,11 @@ def train(
local_container.train(wait)

def create_input_data_channel(
self, channel_name: str, data_source: DataSourceType, key_prefix: Optional[str] = None
self,
channel_name: str,
data_source: DataSourceType,
key_prefix: Optional[str] = None,
ignore_patterns: Optional[List[str]] = None,
) -> Channel:
"""Create an input data channel for the training job.

Expand All @@ -771,6 +778,9 @@ def create_input_data_channel(

If specified, local data will be uploaded to:
``s3://<default_bucket_path>/<key_prefix>/<channel_name>/``
ignore_patterns: (Optional[List[str]]) :
The ignore patterns to ignore specific files/folders when uploading to S3.
If not specified, default to: ['.env', '.git', '__pycache__', '.DS_Store'].
"""
channel = None
if isinstance(data_source, str):
Expand Down Expand Up @@ -810,11 +820,28 @@ def create_input_data_channel(
)
if self.sagemaker_session.default_bucket_prefix:
key_prefix = f"{self.sagemaker_session.default_bucket_prefix}/{key_prefix}"
s3_uri = self.sagemaker_session.upload_data(
path=data_source,
bucket=self.sagemaker_session.default_bucket(),
key_prefix=key_prefix,
)
if ignore_patterns and _is_valid_path(data_source, path_type="Directory"):
tmp_dir = TemporaryDirectory()
copied_path = os.path.join(
tmp_dir.name, os.path.basename(os.path.normpath(data_source))
)
shutil.copytree(
data_source,
copied_path,
dirs_exist_ok=True,
ignore=shutil.ignore_patterns(*ignore_patterns),
)
s3_uri = self.sagemaker_session.upload_data(
path=copied_path,
bucket=self.sagemaker_session.default_bucket(),
key_prefix=key_prefix,
)
else:
s3_uri = self.sagemaker_session.upload_data(
path=data_source,
bucket=self.sagemaker_session.default_bucket(),
key_prefix=key_prefix,
)
channel = Channel(
channel_name=channel_name,
data_source=DataSource(
Expand Down Expand Up @@ -861,7 +888,9 @@ def _get_input_data_config(
channels.append(input_data)
elif isinstance(input_data, InputData):
channel = self.create_input_data_channel(
input_data.channel_name, input_data.data_source, key_prefix=key_prefix
input_data.channel_name,
input_data.data_source,
key_prefix=key_prefix,
)
channels.append(channel)
else:
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/sagemaker/modules/train/test_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,17 @@ def model_trainer():
},
"should_throw": False,
},
{
"init_params": {
"training_image": DEFAULT_IMAGE,
"source_code": SourceCode(
source_dir=DEFAULT_SOURCE_DIR,
command="python custom_script.py",
ignore_patterns=["data"],
),
},
"should_throw": False,
},
],
ids=[
"no_params",
Expand All @@ -213,6 +224,7 @@ def model_trainer():
"supported_source_code_local_tar_file",
"supported_source_code_s3_dir",
"supported_source_code_s3_tar_file",
"supported_source_code_ignore_patterns",
],
)
def test_model_trainer_param_validation(test_case, modules_session):
Expand Down