# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import, print_function
import json
import logging
import os
import re
import sys
import time
import warnings
import boto3
import botocore.config
from botocore.exceptions import ClientError
import six
import sagemaker.logs
from sagemaker import vpc_utils
# import s3_input for backward compatibility
from sagemaker.inputs import s3_input # noqa # pylint: disable=unused-import
from sagemaker.user_agent import prepend_user_agent, update_sdk_metrics
from sagemaker.utils import (
name_from_image,
secondary_training_status_changed,
secondary_training_status_message,
sts_regional_endpoint,
)
from sagemaker import exceptions
LOGGER = logging.getLogger("sagemaker")
NOTEBOOK_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json"
_STATUS_CODE_TABLE = {
"COMPLETED": "Completed",
"INPROGRESS": "InProgress",
"FAILED": "Failed",
"STOPPED": "Stopped",
"STOPPING": "Stopping",
"STARTING": "Starting",
}
class LogState(object):
"""Placeholder docstring"""
STARTING = 1
WAIT_IN_PROGRESS = 2
TAILING = 3
JOB_COMPLETE = 4
COMPLETE = 5
class Session(object): # pylint: disable=too-many-public-methods
"""Manage interactions with the Amazon SageMaker APIs and any other AWS services needed.
This class provides convenient methods for manipulating entities and resources that Amazon
SageMaker uses, such as training jobs, endpoints, and input datasets in S3.
AWS service calls are delegated to an underlying Boto3 session, which by default
is initialized using the AWS configuration chain. When you make an Amazon SageMaker API call
that accesses an S3 bucket location and one is not specified, the ``Session`` creates a default
bucket based on a naming convention which includes the current AWS account ID.
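
    Example:
        A minimal sketch of constructing a session; the region below is illustrative
        and is normally resolved from your AWS configuration chain::

            import boto3

            boto_sess = boto3.Session(region_name="us-west-2")
            sess = Session(boto_session=boto_sess)
            print(sess.boto_region_name)  # "us-west-2"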
"""
def __init__(
self,
boto_session=None,
sagemaker_client=None,
sagemaker_runtime_client=None,
default_bucket=None,
):
"""Initialize a SageMaker ``Session``.
Args:
boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
calls are delegated to (default: None). If not provided, one is created with
default AWS configuration chain.
sagemaker_client (boto3.SageMaker.Client): Client which makes Amazon SageMaker service
calls other than ``InvokeEndpoint`` (default: None). Estimators created using this
``Session`` use this client. If not provided, one will be created using this
instance's ``boto_session``.
sagemaker_runtime_client (boto3.SageMakerRuntime.Client): Client which makes
``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created
using this ``Session`` use this client. If not provided, one will be created using
this instance's ``boto_session``.
default_bucket (str): The default Amazon S3 bucket to be used by this session.
This will be created the next time an Amazon S3 bucket is needed (by calling
:func:`default_bucket`).
                If not provided, a default bucket will be created based on the following
                format: "sagemaker-{region}-{aws-account-id}". Otherwise, the given name
                is used as-is, for example "sagemaker-my-custom-bucket".
"""
self._default_bucket = None
self._default_bucket_name_override = default_bucket
self.s3_resource = None
self.s3_client = None
self.config = None
self._initialize(
boto_session=boto_session,
sagemaker_client=sagemaker_client,
sagemaker_runtime_client=sagemaker_runtime_client,
)
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
"""Initialize this SageMaker Session.
Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client.
Sets the region_name.
"""
self.boto_session = boto_session or boto3.DEFAULT_SESSION or boto3.Session()
self._region_name = self.boto_session.region_name
if self._region_name is None:
raise ValueError(
"Must setup local AWS configuration with a region supported by SageMaker."
)
self.sagemaker_client = sagemaker_client or self.boto_session.client("sagemaker")
self._sagemaker_client_default_user_agent = prepend_user_agent(self.sagemaker_client)
if sagemaker_runtime_client is not None:
self.sagemaker_runtime_client = sagemaker_runtime_client
else:
config = botocore.config.Config(read_timeout=80)
self.sagemaker_runtime_client = self.boto_session.client(
"runtime.sagemaker", config=config
)
self._sagemaker_runtime_client_default_user_agent = prepend_user_agent(
self.sagemaker_runtime_client
)
self.local_mode = False
def sagemaker_client_sdk_metrics(self, sdk_metrics):
"""
sdk_metrics (sagemaker.SDKMetrics): An object that defines
the Python SDK telemetry metrics used to track Python SDK usage
Returns:
boto3.SageMaker.Client: Client which makes Amazon SageMaker service
calls other than ``InvokeEndpoint`` with updated user_agent string
with sdk_metrics.
"""
update_sdk_metrics(
self.sagemaker_client,
self._sagemaker_client_default_user_agent,
sdk_metrics
)
return self.sagemaker_client
@property
def boto_region_name(self):
"""Placeholder docstring"""
return self._region_name
def upload_data(self, path, bucket=None, key_prefix="data", extra_args=None):
"""Upload local file or directory to S3.
If a single file is specified for upload, the resulting S3 object key is
        ``{key_prefix}/{filename}`` (filename does not include the local path, if one was specified).
If a directory is specified for upload, the API uploads all content, recursively,
preserving relative structure of subdirectories. The resulting object key names are:
``{key_prefix}/{relative_subdirectory_path}/filename``.
Args:
path (str): Path (absolute or relative) of local file or directory to upload.
bucket (str): Name of the S3 Bucket to upload to (default: None). If not specified, the
default bucket of the ``Session`` is used (if default bucket does not exist, the
``Session`` creates it).
key_prefix (str): Optional S3 object key name prefix (default: 'data'). S3 uses the
                prefix to create a directory structure for the bucket content that it displays in
the S3 console.
extra_args (dict): Optional extra arguments that may be passed to the upload operation.
                Similar to the ExtraArgs parameter in the S3 ``upload_file`` function. Please refer to the
ExtraArgs parameter documentation here:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-uploading-files.html#the-extraargs-parameter
Returns:
str: The S3 URI of the uploaded file(s). If a file is specified in the path argument,
the URI format is: ``s3://{bucket name}/{key_prefix}/{original_file_name}``.
If a directory is specified in the path argument, the URI format is
``s3://{bucket name}/{key_prefix}``.
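
        Example:
            A minimal sketch of uploading a single file; the local path and prefix
            below are illustrative::

                sess = Session()
                uri = sess.upload_data(path="data/train.csv", key_prefix="training")
                # uri is of the form "s3://sagemaker-<region>-<account-id>/training/train.csv"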
"""
# Generate a tuple for each file that we want to upload of the form (local_path, s3_key).
files = []
key_suffix = None
if os.path.isdir(path):
for dirpath, _, filenames in os.walk(path):
for name in filenames:
local_path = os.path.join(dirpath, name)
s3_relative_prefix = (
"" if path == dirpath else os.path.relpath(dirpath, start=path) + "/"
)
s3_key = "{}/{}{}".format(key_prefix, s3_relative_prefix, name)
files.append((local_path, s3_key))
else:
_, name = os.path.split(path)
s3_key = "{}/{}".format(key_prefix, name)
files.append((path, s3_key))
key_suffix = name
bucket = bucket or self.default_bucket()
if self.s3_resource is None:
s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_resource
for local_path, s3_key in files:
s3.Object(bucket, s3_key).upload_file(local_path, ExtraArgs=extra_args)
s3_uri = "s3://{}/{}".format(bucket, key_prefix)
# If a specific file was used as input (instead of a directory), we return the full S3 key
# of the uploaded object. This prevents unintentionally using other files under the same
# prefix during training.
if key_suffix:
s3_uri = "{}/{}".format(s3_uri, key_suffix)
return s3_uri
def upload_string_as_file_body(self, body, bucket, key, kms_key=None):
"""Upload a string as a file body.
Args:
body (str): String representing the body of the file.
            bucket (str): Name of the S3 bucket to upload to.
            key (str): S3 object key. This is the S3 path to the file.
kms_key (str): The KMS key to use for encrypting the file.
Returns:
str: The S3 URI of the uploaded file.
The URI format is: ``s3://{bucket name}/{key}``.
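
        Example:
            A minimal sketch; the bucket and key below are placeholders::

                sess = Session()
                uri = sess.upload_string_as_file_body(
                    body='{"name": "value"}', bucket="my-bucket", key="configs/config.json"
                )
                # uri == "s3://my-bucket/configs/config.json"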
"""
if self.s3_resource is None:
s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_resource
s3_object = s3.Object(bucket_name=bucket, key=key)
if kms_key is not None:
s3_object.put(Body=body, SSEKMSKeyId=kms_key, ServerSideEncryption="aws:kms")
else:
s3_object.put(Body=body)
s3_uri = "s3://{}/{}".format(bucket, key)
return s3_uri
def download_data(self, path, bucket, key_prefix="", extra_args=None):
"""Download file or directory from S3.
Args:
path (str): Local path where the file or directory should be downloaded to.
bucket (str): Name of the S3 Bucket to download from.
key_prefix (str): Optional S3 object key name prefix.
extra_args (dict): Optional extra arguments that may be passed to the
download operation. Please refer to the ExtraArgs parameter in the boto3
documentation here:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-example-download-file.html
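
        Example:
            A minimal sketch; the bucket and prefix below are placeholders::

                sess = Session()
                sess.download_data(
                    path="./local_data", bucket="my-bucket", key_prefix="training/"
                )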
"""
# Initialize the S3 client.
if self.s3_client is None:
s3 = self.boto_session.client("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_client
# Initialize the variables used to loop through the contents of the S3 bucket.
keys = []
next_token = ""
base_parameters = {"Bucket": bucket, "Prefix": key_prefix}
        # Loop through the contents of the bucket, 1,000 objects at a time, gathering
        # all keys into the "keys" list.
while next_token is not None:
request_parameters = base_parameters.copy()
if next_token != "":
request_parameters.update({"ContinuationToken": next_token})
response = s3.list_objects_v2(**request_parameters)
            # "Contents" is absent from the response when the prefix matches no objects.
            contents = response.get("Contents", [])
# For each object, save its key or directory.
for s3_object in contents:
key = s3_object.get("Key")
keys.append(key)
next_token = response.get("NextContinuationToken")
# For each object key, create the directory on the local machine if needed, and then
# download the file.
for key in keys:
tail_s3_uri_path = os.path.basename(key_prefix)
if not os.path.splitext(key_prefix)[1]:
tail_s3_uri_path = os.path.relpath(key, key_prefix)
destination_path = os.path.join(path, tail_s3_uri_path)
if not os.path.exists(os.path.dirname(destination_path)):
os.makedirs(os.path.dirname(destination_path))
s3.download_file(
Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args
)
def read_s3_file(self, bucket, key_prefix):
"""Read a single file from S3.
Args:
bucket (str): Name of the S3 Bucket to download from.
key_prefix (str): S3 object key name prefix.
Returns:
str: The body of the s3 file as a string.
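
        Example:
            A minimal sketch; the bucket and key below are placeholders::

                sess = Session()
                text = sess.read_s3_file(bucket="my-bucket", key_prefix="configs/config.json")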
"""
if self.s3_client is None:
s3 = self.boto_session.client("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_client
# Explicitly passing a None kms_key to boto3 throws a validation error.
s3_object = s3.get_object(Bucket=bucket, Key=key_prefix)
return s3_object["Body"].read().decode("utf-8")
def list_s3_files(self, bucket, key_prefix):
"""Lists the S3 files given an S3 bucket and key.
Args:
bucket (str): Name of the S3 Bucket to download from.
key_prefix (str): S3 object key name prefix.
Returns:
[str]: The list of files at the S3 path.
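
        Example:
            A minimal sketch; the bucket and prefix below are placeholders::

                sess = Session()
                keys = sess.list_s3_files(bucket="my-bucket", key_prefix="training/")
                # e.g. ["training/train.csv", "training/validation.csv"]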
"""
if self.s3_resource is None:
s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_resource
s3_bucket = s3.Bucket(name=bucket)
s3_objects = s3_bucket.objects.filter(Prefix=key_prefix).all()
return [s3_object.key for s3_object in s3_objects]
def default_bucket(self):
"""Return the name of the default bucket to use in relevant Amazon SageMaker interactions.
Returns:
            str: The name of the default bucket. If no override was provided when the
                ``Session`` was constructed, the name is of the form:
                ``sagemaker-{region}-{AWS account ID}``.
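
        Example:
            The bucket is created on first use if it does not already exist; the name
            below is illustrative::

                sess = Session()
                bucket = sess.default_bucket()  # e.g. "sagemaker-us-west-2-123456789012"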
"""
if self._default_bucket:
return self._default_bucket
region = self.boto_session.region_name
default_bucket = self._default_bucket_name_override
if not default_bucket:
account = self.boto_session.client(
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
).get_caller_identity()["Account"]
default_bucket = "sagemaker-{}-{}".format(region, account)
self._create_s3_bucket_if_it_does_not_exist(bucket_name=default_bucket, region=region)
self._default_bucket = default_bucket
return self._default_bucket
def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
"""Creates an S3 Bucket if it does not exist.
Also swallows a few common exceptions that indicate that the bucket already exists or
that it is being created.
Args:
bucket_name (str): Name of the S3 bucket to be created.
region (str): The region in which to create the bucket.
Raises:
botocore.exceptions.ClientError: If S3 throws an unexpected exception during bucket
creation.
If the exception is due to the bucket already existing or
already being created, no exception is raised.
"""
if self.s3_resource is None:
s3 = self.boto_session.resource("s3", region_name=region)
else:
s3 = self.s3_resource
bucket = s3.Bucket(name=bucket_name)
if bucket.creation_date is None:
try:
if region == "us-east-1":
# 'us-east-1' cannot be specified because it is the default region:
# https://github.com/boto/boto3/issues/125
s3.create_bucket(Bucket=bucket_name)
else:
s3.create_bucket(
Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region}
)
LOGGER.info("Created S3 bucket: %s", bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
if error_code == "BucketAlreadyOwnedByYou":
pass
elif (
error_code == "OperationAborted"
and "conflicting conditional operation" in message
):
# If this bucket is already being concurrently created, we don't need to create
# it again.
pass
else:
raise
def train( # noqa: C901
self,
input_mode,
input_config,
role,
job_name,
output_config,
resource_config,
vpc_config,
hyperparameters,
stop_condition,
tags,
metric_definitions,
enable_network_isolation=False,
image=None,
algorithm_arn=None,
encrypt_inter_container_traffic=False,
train_use_spot_instances=False,
checkpoint_s3_uri=None,
checkpoint_local_path=None,
experiment_config=None,
debugger_rule_configs=None,
debugger_hook_config=None,
tensorboard_output_config=None,
enable_sagemaker_metrics=None,
sdk_metrics=None,
):
"""Create an Amazon SageMaker training job.
Args:
input_mode (str): The input mode that the algorithm supports. Valid modes:
* 'File' - Amazon SageMaker copies the training dataset from the S3 location to
a directory in the Docker container.
* 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a
Unix-named pipe.
input_config (list): A list of Channel objects. Each channel is a named input source.
                Please refer to the format details described at:
https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training
jobs and APIs that create Amazon SageMaker endpoints use this role to access
training data and model artifacts. You must grant sufficient permissions to this
role.
job_name (str): Name of the training job being created.
output_config (dict): The S3 URI where you want to store the training results and
optional KMS key ID.
resource_config (dict): Contains values for ResourceConfig:
* instance_count (int): Number of EC2 instances to use for training.
The key in resource_config is 'InstanceCount'.
* instance_type (str): Type of EC2 instance to use for training, for example,
'ml.c4.xlarge'. The key in resource_config is 'InstanceType'.
vpc_config (dict): Contains values for VpcConfig:
* subnets (list[str]): List of subnet ids.
The key in vpc_config is 'Subnets'.
* security_group_ids (list[str]): List of security group ids.
The key in vpc_config is 'SecurityGroupIds'.
hyperparameters (dict): Hyperparameters for model training. The hyperparameters are
made accessible as a dict[str, str] to the training code on SageMaker. For
convenience, this accepts other types for keys and values, but ``str()`` will be
called to convert them before training.
            stop_condition (dict): Defines when training should finish. Contains entries
                understood by the service, such as ``MaxRuntimeInSeconds``.
tags (list[dict]): List of tags for labeling a training job. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
metric_definitions (list[dict]): A list of dictionaries that defines the metric(s)
used to evaluate the training jobs. Each dictionary contains two keys: 'Name' for
the name of the metric, and 'Regex' for the regular expression used to extract the
metric from the logs.
            enable_network_isolation (bool): Whether to run the training job with
                network isolation (default: ``False``).
image (str): Docker image containing training code.
algorithm_arn (str): Algorithm Arn from Marketplace.
encrypt_inter_container_traffic (bool): Specifies whether traffic between training
containers is encrypted for the training job (default: ``False``).
            train_use_spot_instances (bool): Whether to use Spot instances for training
                (default: ``False``).
            checkpoint_s3_uri (str): The S3 URI to which any checkpoints written by the
                algorithm are persisted during training (default: ``None``).
checkpoint_local_path (str): The local path that the algorithm
writes its checkpoints to. SageMaker will persist all files
under this path to `checkpoint_s3_uri` continually during
training. On job startup the reverse happens - data from the
s3 location is downloaded to this path before the algorithm is
started. If the path is unset then SageMaker assumes the
checkpoints will be provided under `/opt/ml/checkpoints/`.
(default: ``None``).
experiment_config (dict): Experiment management configuration. Dictionary contains
three optional keys, 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
(default: ``None``)
enable_sagemaker_metrics (bool): enable SageMaker Metrics Time
Series. For more information see:
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
                (default: ``None``).
            debugger_rule_configs (list[dict]): A list of Debugger rule configurations
                to attach to the training job (default: ``None``).
            debugger_hook_config (dict): Configuration of the Debugger hook that emits
                tensors during training (default: ``None``).
            tensorboard_output_config (dict): Configuration for uploading TensorBoard
                output produced during training (default: ``None``).
            sdk_metrics (sagemaker.SDKMetrics): SDK telemetry metrics to include in the
                user agent of the SageMaker client used for this call (default: ``None``).
Returns:
str: ARN of the training job, if it is created.
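
        Example:
            A minimal sketch of the request shape; the role, image, job name, and S3
            URIs below are placeholders::

                sess = Session()
                sess.train(
                    input_mode="File",
                    input_config=[
                        {
                            "ChannelName": "training",
                            "DataSource": {
                                "S3DataSource": {
                                    "S3DataType": "S3Prefix",
                                    "S3Uri": "s3://my-bucket/training/",
                                }
                            },
                        }
                    ],
                    role="arn:aws:iam::123456789012:role/SageMakerRole",
                    job_name="my-training-job",
                    output_config={"S3OutputPath": "s3://my-bucket/output/"},
                    resource_config={
                        "InstanceCount": 1,
                        "InstanceType": "ml.c4.xlarge",
                        "VolumeSizeInGB": 30,
                    },
                    vpc_config=None,
                    hyperparameters={"epochs": "10"},
                    stop_condition={"MaxRuntimeInSeconds": 3600},
                    tags=None,
                    metric_definitions=None,
                    image="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",
                )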
"""
train_request = {
"AlgorithmSpecification": {"TrainingInputMode": input_mode},
"OutputDataConfig": output_config,
"TrainingJobName": job_name,
"StoppingCondition": stop_condition,
"ResourceConfig": resource_config,
"RoleArn": role,
}
if image and algorithm_arn:
raise ValueError(
"image and algorithm_arn are mutually exclusive."
"Both were provided: image: %s algorithm_arn: %s" % (image, algorithm_arn)
)
if image is None and algorithm_arn is None:
raise ValueError("either image or algorithm_arn is required. None was provided.")
if image is not None:
train_request["AlgorithmSpecification"]["TrainingImage"] = image
if algorithm_arn is not None:
train_request["AlgorithmSpecification"]["AlgorithmName"] = algorithm_arn
if input_config is not None:
train_request["InputDataConfig"] = input_config
if metric_definitions is not None:
train_request["AlgorithmSpecification"]["MetricDefinitions"] = metric_definitions
if enable_sagemaker_metrics is not None:
train_request["AlgorithmSpecification"][
"EnableSageMakerMetricsTimeSeries"
] = enable_sagemaker_metrics
if hyperparameters and len(hyperparameters) > 0:
train_request["HyperParameters"] = hyperparameters
if tags is not None:
train_request["Tags"] = tags
if vpc_config is not None:
train_request["VpcConfig"] = vpc_config
if experiment_config and len(experiment_config) > 0:
train_request["ExperimentConfig"] = experiment_config
if enable_network_isolation:
train_request["EnableNetworkIsolation"] = enable_network_isolation
if encrypt_inter_container_traffic:
train_request["EnableInterContainerTrafficEncryption"] = encrypt_inter_container_traffic
if train_use_spot_instances:
train_request["EnableManagedSpotTraining"] = train_use_spot_instances
if checkpoint_s3_uri:
checkpoint_config = {"S3Uri": checkpoint_s3_uri}
if checkpoint_local_path:
checkpoint_config["LocalPath"] = checkpoint_local_path
train_request["CheckpointConfig"] = checkpoint_config
if debugger_rule_configs is not None:
train_request["DebugRuleConfigurations"] = debugger_rule_configs
if debugger_hook_config is not None:
train_request["DebugHookConfig"] = debugger_hook_config
if tensorboard_output_config is not None:
train_request["TensorBoardOutputConfig"] = tensorboard_output_config
LOGGER.info("Creating training-job with name: %s", job_name)
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
self.sagemaker_client_sdk_metrics(sdk_metrics=sdk_metrics)\
.create_training_job(**train_request)
def process(
self,
inputs,
output_config,
job_name,
resources,
stopping_condition,
app_specification,
environment,
network_config,
role_arn,
tags,
experiment_config=None,
):
"""Create an Amazon SageMaker processing job.
Args:
inputs ([dict]): List of up to 10 ProcessingInput dictionaries.
output_config (dict): A config dictionary, which contains a list of up
to 10 ProcessingOutput dictionaries, as well as an optional KMS key ID.
job_name (str): The name of the processing job. The name must be unique
within an AWS Region in an AWS account. Names should have minimum
length of 1 and maximum length of 63 characters.
resources (dict): Encapsulates the resources, including ML instances
and storage, to use for the processing job.
stopping_condition (dict[str,int]): Specifies a limit to how long
the processing job can run, in seconds.
app_specification (dict[str,str]): Configures the processing job to
run the given image. Details are in the processing container
specification.
environment (dict): Environment variables to start the processing
container with.
network_config (dict): Specifies networking options, such as network
traffic encryption between processing containers, whether to allow
inbound and outbound network calls to and from processing containers,
and VPC subnets and security groups to use for VPC-enabled processing
jobs.
role_arn (str): The Amazon Resource Name (ARN) of an IAM role that
Amazon SageMaker can assume to perform tasks on your behalf.
tags ([dict[str,str]]): A list of dictionaries containing key-value
pairs.
experiment_config (dict): Experiment management configuration. Dictionary contains
three optional keys, 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
(default: ``None``)
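
        Example:
            A minimal sketch of the request shape; the role, image, and job name below
            are placeholders::

                sess = Session()
                sess.process(
                    inputs=None,
                    output_config={"Outputs": []},
                    job_name="my-processing-job",
                    resources={
                        "ClusterConfig": {
                            "InstanceCount": 1,
                            "InstanceType": "ml.m5.xlarge",
                            "VolumeSizeInGB": 30,
                        }
                    },
                    stopping_condition={"MaxRuntimeInSeconds": 3600},
                    app_specification={
                        "ImageUri": "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest"
                    },
                    environment=None,
                    network_config=None,
                    role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
                    tags=None,
                )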
"""
process_request = {
"ProcessingJobName": job_name,
"ProcessingResources": resources,
"AppSpecification": app_specification,
"RoleArn": role_arn,
}
if inputs:
process_request["ProcessingInputs"] = inputs
if output_config["Outputs"]:
process_request["ProcessingOutputConfig"] = output_config
if environment is not None:
process_request["Environment"] = environment
if network_config is not None:
process_request["NetworkConfig"] = network_config
if stopping_condition is not None:
process_request["StoppingCondition"] = stopping_condition
if tags is not None:
process_request["Tags"] = tags
if experiment_config:
process_request["ExperimentConfig"] = experiment_config
LOGGER.info("Creating processing-job with name %s", job_name)
LOGGER.debug("process request: %s", json.dumps(process_request, indent=4))
self.sagemaker_client_sdk_metrics([]).create_processing_job(**process_request)
def create_monitoring_schedule(
self,
monitoring_schedule_name,
schedule_expression,
statistics_s3_uri,
constraints_s3_uri,
monitoring_inputs,
monitoring_output_config,
instance_count,
instance_type,
volume_size_in_gb,
volume_kms_key,
image_uri,
entrypoint,
arguments,
record_preprocessor_source_uri,
post_analytics_processor_source_uri,
max_runtime_in_seconds,
environment,
network_config,
role_arn,
tags,
):
"""Create an Amazon SageMaker monitoring schedule.
Args:
monitoring_schedule_name (str): The name of the monitoring schedule. The name must be
unique within an AWS Region in an AWS account. Names should have a minimum length
of 1 and a maximum length of 63 characters.
schedule_expression (str): The cron expression that dictates the monitoring execution
schedule.
statistics_s3_uri (str): The S3 uri of the statistics file to use.
constraints_s3_uri (str): The S3 uri of the constraints file to use.
monitoring_inputs ([dict]): List of MonitoringInput dictionaries.
monitoring_output_config (dict): A config dictionary, which contains a list of
MonitoringOutput dictionaries, as well as an optional KMS key ID.
instance_count (int): The number of instances to run.
instance_type (str): The type of instance to run.
volume_size_in_gb (int): Size of the volume in GB.
volume_kms_key (str): KMS key to use when encrypting the volume.
image_uri (str): The image uri to use for monitoring executions.
entrypoint (str): The entrypoint to the monitoring execution image.
arguments (str): The arguments to pass to the monitoring execution image.
record_preprocessor_source_uri (str or None): The S3 uri that points to the script that
pre-processes the dataset (only applicable to first-party images).
post_analytics_processor_source_uri (str or None): The S3 uri that points to the script
that post-processes the dataset (only applicable to first-party images).
max_runtime_in_seconds (int): Specifies a limit to how long
the processing job can run, in seconds.
environment (dict): Environment variables to start the monitoring execution
container with.
network_config (dict): Specifies networking options, such as network
traffic encryption between processing containers, whether to allow
inbound and outbound network calls to and from processing containers,
and VPC subnets and security groups to use for VPC-enabled processing
jobs.
role_arn (str): The Amazon Resource Name (ARN) of an IAM role that
Amazon SageMaker can assume to perform tasks on your behalf.
tags ([dict[str,str]]): A list of dictionaries containing key-value
pairs.
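
        Example:
            A minimal sketch; optional pieces are passed as ``None``, and the endpoint,
            image, role, and schedule below are placeholders::

                sess = Session()
                sess.create_monitoring_schedule(
                    monitoring_schedule_name="my-monitoring-schedule",
                    schedule_expression="cron(0 * ? * * *)",
                    statistics_s3_uri=None,
                    constraints_s3_uri=None,
                    monitoring_inputs=[
                        {
                            "EndpointInput": {
                                "EndpointName": "my-endpoint",
                                "LocalPath": "/opt/ml/processing/input",
                            }
                        }
                    ],
                    monitoring_output_config=None,
                    instance_count=1,
                    instance_type="ml.m5.xlarge",
                    volume_size_in_gb=30,
                    volume_kms_key=None,
                    image_uri="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",
                    entrypoint=None,
                    arguments=None,
                    record_preprocessor_source_uri=None,
                    post_analytics_processor_source_uri=None,
                    max_runtime_in_seconds=3600,
                    environment=None,
                    network_config=None,
                    role_arn="arn:aws:iam::123456789012:role/SageMakerRole",
                    tags=None,
                )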
"""
monitoring_schedule_request = {
"MonitoringScheduleName": monitoring_schedule_name,
"MonitoringScheduleConfig": {
"MonitoringJobDefinition": {
"MonitoringInputs": monitoring_inputs,
"MonitoringResources": {
"ClusterConfig": {
"InstanceCount": instance_count,
"InstanceType": instance_type,
"VolumeSizeInGB": volume_size_in_gb,
}
},
"MonitoringAppSpecification": {"ImageUri": image_uri},
"RoleArn": role_arn,
}
},
}
if schedule_expression is not None:
monitoring_schedule_request["MonitoringScheduleConfig"]["ScheduleConfig"] = {
"ScheduleExpression": schedule_expression
}
if monitoring_output_config is not None:
monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"MonitoringOutputConfig"
] = monitoring_output_config
if statistics_s3_uri is not None or constraints_s3_uri is not None:
monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"BaselineConfig"
] = {}
if statistics_s3_uri is not None:
monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"BaselineConfig"
]["StatisticsResource"] = {"S3Uri": statistics_s3_uri}
if constraints_s3_uri is not None:
monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"BaselineConfig"
]["ConstraintsResource"] = {"S3Uri": constraints_s3_uri}
if record_preprocessor_source_uri is not None:
monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"MonitoringAppSpecification"
]["RecordPreprocessorSourceUri"] = record_preprocessor_source_uri
if post_analytics_processor_source_uri is not None:
monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"MonitoringAppSpecification"
]["PostAnalyticsProcessorSourceUri"] = post_analytics_processor_source_uri
if entrypoint is not None:
monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"MonitoringAppSpecification"
]["ContainerEntrypoint"] = entrypoint
if arguments is not None:
monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"MonitoringAppSpecification"
]["ContainerArguments"] = arguments
if volume_kms_key is not None:
monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"MonitoringResources"
]["ClusterConfig"]["VolumeKmsKeyId"] = volume_kms_key
if max_runtime_in_seconds is not None:
monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"StoppingCondition"
] = {"MaxRuntimeInSeconds": max_runtime_in_seconds}
if environment is not None:
monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"Environment"
] = environment
if network_config is not None:
monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"NetworkConfig"
] = network_config
if tags is not None:
monitoring_schedule_request["Tags"] = tags
LOGGER.info("Creating monitoring schedule name %s.", monitoring_schedule_name)
LOGGER.debug(
"monitoring_schedule_request= %s", json.dumps(monitoring_schedule_request, indent=4)
)
self.sagemaker_client.create_monitoring_schedule(**monitoring_schedule_request)
def update_monitoring_schedule(
self,
monitoring_schedule_name,
schedule_expression=None,
statistics_s3_uri=None,
constraints_s3_uri=None,
monitoring_inputs=None,
monitoring_output_config=None,
instance_count=None,
instance_type=None,
volume_size_in_gb=None,
volume_kms_key=None,
image_uri=None,
entrypoint=None,
arguments=None,
record_preprocessor_source_uri=None,
post_analytics_processor_source_uri=None,
max_runtime_in_seconds=None,
environment=None,
network_config=None,
role_arn=None,
):
"""Update an Amazon SageMaker monitoring schedule.
Args:
monitoring_schedule_name (str): The name of the monitoring schedule. The name must be
unique within an AWS Region in an AWS account. Names should have a minimum length
of 1 and a maximum length of 63 characters.
schedule_expression (str): The cron expression that dictates the monitoring execution
schedule.
statistics_s3_uri (str): The S3 uri of the statistics file to use.
constraints_s3_uri (str): The S3 uri of the constraints file to use.
monitoring_inputs ([dict]): List of MonitoringInput dictionaries.
monitoring_output_config (dict): A config dictionary, which contains a list of
MonitoringOutput dictionaries, as well as an optional KMS key ID.
instance_count (int): The number of instances to run.
instance_type (str): The type of instance to run.
volume_size_in_gb (int): Size of the volume in GB.
volume_kms_key (str): KMS key to use when encrypting the volume.
image_uri (str): The image uri to use for monitoring executions.
entrypoint (str): The entrypoint to the monitoring execution image.
arguments (str): The arguments to pass to the monitoring execution image.
record_preprocessor_source_uri (str or None): The S3 uri that points to the script that
pre-processes the dataset (only applicable to first-party images).
post_analytics_processor_source_uri (str or None): The S3 uri that points to the script
that post-processes the dataset (only applicable to first-party images).
max_runtime_in_seconds (int): Specifies a limit to how long
the processing job can run, in seconds.
environment (dict): Environment variables to start the monitoring execution
container with.
network_config (dict): Specifies networking options, such as network
traffic encryption between processing containers, whether to allow
inbound and outbound network calls to and from processing containers,
and VPC subnets and security groups to use for VPC-enabled processing
jobs.
            role_arn (str): The Amazon Resource Name (ARN) of an IAM role that
                Amazon SageMaker can assume to perform tasks on your behalf.
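
        Example:
            A minimal sketch that changes only the instance type; all other settings
            are carried over from the existing schedule::

                sess = Session()
                sess.update_monitoring_schedule(
                    monitoring_schedule_name="my-monitoring-schedule",
                    instance_type="ml.m5.2xlarge",
                )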
"""
existing_desc = self.sagemaker_client.describe_monitoring_schedule(
MonitoringScheduleName=monitoring_schedule_name
)
existing_schedule_config = None
if (
existing_desc.get("MonitoringScheduleConfig") is not None
and existing_desc["MonitoringScheduleConfig"].get("ScheduleConfig") is not None
and existing_desc["MonitoringScheduleConfig"]["ScheduleConfig"]["ScheduleExpression"]
is not None
):
existing_schedule_config = existing_desc["MonitoringScheduleConfig"]["ScheduleConfig"][
"ScheduleExpression"
]
request_schedule_expression = schedule_expression or existing_schedule_config
request_monitoring_inputs = (
monitoring_inputs
or existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"MonitoringInputs"
]
)
request_instance_count = (
instance_count
or existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"MonitoringResources"
]["ClusterConfig"]["InstanceCount"]
)
request_instance_type = (
instance_type
or existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"MonitoringResources"
]["ClusterConfig"]["InstanceType"]
)
request_volume_size_in_gb = (
volume_size_in_gb
or existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"MonitoringResources"
]["ClusterConfig"]["VolumeSizeInGB"]
)
request_image_uri = (
image_uri
or existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"MonitoringAppSpecification"
]["ImageUri"]
)
request_role_arn = (
role_arn
or existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"]["RoleArn"]
)
monitoring_schedule_request = {
"MonitoringScheduleName": monitoring_schedule_name,
"MonitoringScheduleConfig": {
"MonitoringJobDefinition": {
"MonitoringInputs": request_monitoring_inputs,
"MonitoringResources": {
"ClusterConfig": {
"InstanceCount": request_instance_count,
"InstanceType": request_instance_type,
"VolumeSizeInGB": request_volume_size_in_gb,
}
},
"MonitoringAppSpecification": {"ImageUri": request_image_uri},
"RoleArn": request_role_arn,
}
},
}
if existing_schedule_config is not None:
monitoring_schedule_request["MonitoringScheduleConfig"]["ScheduleConfig"] = {
"ScheduleExpression": request_schedule_expression
}
existing_monitoring_output_config = existing_desc["MonitoringScheduleConfig"][
"MonitoringJobDefinition"
].get("MonitoringOutputConfig")
if monitoring_output_config is not None or existing_monitoring_output_config is not None:
monitoring_schedule_request["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"MonitoringOutputConfig"
] = (monitoring_output_config or existing_monitoring_output_config)
existing_statistics_s3_uri = None
existing_constraints_s3_uri = None
if (
existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"].get(
"BaselineConfig"
)
is not None
):
if (
existing_desc["MonitoringScheduleConfig"]["MonitoringJobDefinition"][
"BaselineConfig"
].get("StatisticsResource")
is not None
):
existing_statistics_s3_uri = existing_desc["MonitoringScheduleConfig"][