@@ -119,7 +119,8 @@ class ModelTrainer(BaseModel):
119
119
from sagemaker.modules.train import ModelTrainer
120
120
from sagemaker.modules.configs import SourceCode, Compute, InputData
121
121
122
- source_code = SourceCode(source_dir="source", entry_script="train.py")
122
+ ignore_patterns = ['.env', '.git', '__pycache__', '.DS_Store', 'data']
123
+ source_code = SourceCode(source_dir="source", entry_script="train.py", ignore_patterns=ignore_patterns)
123
124
training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image"
124
125
model_trainer = ModelTrainer(
125
126
training_image=training_image,
@@ -654,6 +655,7 @@ def train(
654
655
channel_name = SM_CODE ,
655
656
data_source = self .source_code .source_dir ,
656
657
key_prefix = input_data_key_prefix ,
658
+ ignore_patterns = self .source_code .ignore_patterns ,
657
659
)
658
660
final_input_data_config .append (source_code_channel )
659
661
@@ -675,6 +677,7 @@ def train(
675
677
channel_name = SM_DRIVERS ,
676
678
data_source = tmp_dir .name ,
677
679
key_prefix = input_data_key_prefix ,
680
+ ignore_patterns = self .source_code .ignore_patterns ,
678
681
)
679
682
final_input_data_config .append (sm_drivers_channel )
680
683
@@ -755,7 +758,11 @@ def train(
755
758
local_container .train (wait )
756
759
757
760
def create_input_data_channel (
758
- self , channel_name : str , data_source : DataSourceType , key_prefix : Optional [str ] = None
761
+ self ,
762
+ channel_name : str ,
763
+ data_source : DataSourceType ,
764
+ key_prefix : Optional [str ] = None ,
765
+ ignore_patterns : Optional [List [str ]] = None ,
759
766
) -> Channel :
760
767
"""Create an input data channel for the training job.
761
768
@@ -771,6 +778,10 @@ def create_input_data_channel(
771
778
772
779
If specified, local data will be uploaded to:
773
780
``s3://<default_bucket_path>/<key_prefix>/<channel_name>/``
781
+ ignore_patterns: (Optional[List[str]]) :
782
+ Patterns matching files/folders to exclude when uploading to S3.
783
+ If not specified, defaults to: ['.env', '.git', '__pycache__', '.DS_Store',
784
+ '.cache', '.ipynb_checkpoints'].
774
785
"""
775
786
channel = None
776
787
if isinstance (data_source , str ):
@@ -810,11 +821,28 @@ def create_input_data_channel(
810
821
)
811
822
if self .sagemaker_session .default_bucket_prefix :
812
823
key_prefix = f"{ self .sagemaker_session .default_bucket_prefix } /{ key_prefix } "
813
- s3_uri = self .sagemaker_session .upload_data (
814
- path = data_source ,
815
- bucket = self .sagemaker_session .default_bucket (),
816
- key_prefix = key_prefix ,
817
- )
824
+ if ignore_patterns and _is_valid_path (data_source , path_type = "Directory" ):
825
+ tmp_dir = TemporaryDirectory ()
826
+ copied_path = os .path .join (
827
+ tmp_dir .name , os .path .basename (os .path .normpath (data_source ))
828
+ )
829
+ shutil .copytree (
830
+ data_source ,
831
+ copied_path ,
832
+ dirs_exist_ok = True ,
833
+ ignore = shutil .ignore_patterns (* ignore_patterns ),
834
+ )
835
+ s3_uri = self .sagemaker_session .upload_data (
836
+ path = copied_path ,
837
+ bucket = self .sagemaker_session .default_bucket (),
838
+ key_prefix = key_prefix ,
839
+ )
840
+ else :
841
+ s3_uri = self .sagemaker_session .upload_data (
842
+ path = data_source ,
843
+ bucket = self .sagemaker_session .default_bucket (),
844
+ key_prefix = key_prefix ,
845
+ )
818
846
channel = Channel (
819
847
channel_name = channel_name ,
820
848
data_source = DataSource (
@@ -861,7 +889,9 @@ def _get_input_data_config(
861
889
channels .append (input_data )
862
890
elif isinstance (input_data , InputData ):
863
891
channel = self .create_input_data_channel (
864
- input_data .channel_name , input_data .data_source , key_prefix = key_prefix
892
+ input_data .channel_name ,
893
+ input_data .data_source ,
894
+ key_prefix = key_prefix ,
865
895
)
866
896
channels .append (channel )
867
897
else :
0 commit comments