Skip to content

Commit 97787bf

Browse files
authored
Merge pull request #61 from yosefbs/master
Support china regions
2 parents c5207b1 + 0644f34 commit 97787bf

File tree

4 files changed

+70
-29
lines changed

4 files changed

+70
-29
lines changed

sagemaker_run_notebook/container_build.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def create_project(repo_name, role, zipfile, base_image=default_base):
3838
sts = session.client("sts")
3939
identity = sts.get_caller_identity()
4040
account = identity["Account"]
41+
partition = identity["Arn"].split(':')[1]
4142
args = {
4243
"name": f"create-sagemaker-container-{repo_name}",
4344
"description": f"Build the container {repo_name} for running notebooks in SageMaker",
@@ -56,7 +57,7 @@ def create_project(repo_name, role, zipfile, base_image=default_base):
5657
],
5758
"privilegedMode": True,
5859
},
59-
"serviceRole": f"arn:aws:iam::{account}:role/{role}",
60+
"serviceRole": f"arn:{partition}:iam::{account}:role/{role}",
6061
}
6162

6263
response = client.create_project(**args)

sagemaker_run_notebook/lambda_function.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,21 @@ def execute_notebook(
2020
):
2121
session = ensure_session()
2222
region = session.region_name
23-
24-
account = session.client("sts").get_caller_identity()["Account"]
23+
caller_id=session.client("sts").get_caller_identity()
24+
partition = caller_id["Arn"].split(':')[1]
25+
account = caller_id["Account"]
26+
domain = domain_for_region(region)
2527
if not image:
2628
image = "notebook-runner"
2729
if "/" not in image:
28-
image = f"{account}.dkr.ecr.{region}.amazonaws.com/{image}"
30+
image = f"{account}.dkr.ecr.{region}.{domain}/{image}"
2931
if ":" not in image:
3032
image = image + ":latest"
3133

3234
if not role:
3335
role = f"BasicExecuteNotebookRole-{region}"
3436
if "/" not in role:
35-
role = f"arn:aws:iam::{account}:role/{role}"
37+
role = f"arn:{partition}:iam::{account}:role/{role}"
3638

3739
if output_prefix is None:
3840
output_prefix = os.path.dirname(input_path)
@@ -149,6 +151,21 @@ def ensure_session(session=None):
149151
session = boto3.session.Session()
150152
return session
151153

154+
def domain_for_region(region):
155+
"""Get the DNS suffix for the given region.
156+
Args:
157+
region (str): AWS region name
158+
Returns:
159+
str: the DNS suffix
160+
"""
161+
if region.startswith("us-iso-"):
162+
return "c2s.ic.gov"
163+
if region.startswith("us-isob-"):
164+
return "sc2s.sgov.gov"
165+
if region.startswith("cn-"):
166+
return "amazonaws.com.cn"
167+
return "amazonaws.com"
168+
152169

153170
def lambda_handler(event, context):
154171
job = execute_notebook(

sagemaker_run_notebook/run_notebook.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@
2727
import botocore
2828
import boto3
2929

30-
from .utils import default_bucket, get_execution_role
30+
from .utils import default_bucket, domain_for_region, get_execution_role
3131

3232
abbrev_image_pat = re.compile(
33-
r"(?P<account>\d+).dkr.ecr.(?P<region>[^.]+).amazonaws.com/(?P<image>[^:/]+)(?P<tag>:[^:]+)?"
33+
r"(?P<account>\d+).dkr.ecr.(?P<region>[^.]+).(amazonaws.com|amazonaws.com.cn)/(?P<image>[^:/]+)(?P<tag>:[^:]+)?"
3434
)
3535

3636

@@ -46,7 +46,8 @@ def abbreviate_image(image):
4646
return image
4747

4848

49-
abbrev_role_pat = re.compile(r"arn:aws:iam::(?P<account>\d+):role/(?P<name>[^/]+)")
49+
abbrev_role_pat = re.compile(
50+
r"arn:aws([^:]*):iam::(?P<account>\d+):role/(?P<name>[^/]+)")
5051

5152

5253
def abbreviate_role(role):
@@ -124,13 +125,17 @@ def execute_notebook(
124125
if not role:
125126
role = get_execution_role(session)
126127
elif "/" not in role:
127-
account = session.client("sts").get_caller_identity()["Account"]
128-
role = "arn:aws:iam::{}:role/{}".format(account, role)
128+
identity = session.client("sts").get_caller_identity()
129+
account = identity["Account"]
130+
partition = identity["Arn"].split(':')[1]
131+
role = "arn:{}:iam::{}:role/{}".format(partition, account, role)
129132

130133
if "/" not in image:
131134
account = session.client("sts").get_caller_identity()["Account"]
132135
region = session.region_name
133-
image = "{}.dkr.ecr.{}.amazonaws.com/{}:latest".format(account, region, image)
136+
domain = domain_for_region(region)
137+
image = "{}.dkr.ecr.{}.{}/{}:latest".format(
138+
account, region, domain, image)
134139

135140
if notebook == None:
136141
notebook = input_path
@@ -140,7 +145,8 @@ def execute_notebook(
140145
timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
141146

142147
job_name = (
143-
("papermill-" + re.sub(r"[^-a-zA-Z0-9]", "-", nb_name))[: 62 - len(timestamp)]
148+
("papermill-" + re.sub(r"[^-a-zA-Z0-9]",
149+
"-", nb_name))[: 62 - len(timestamp)]
144150
+ "-"
145151
+ timestamp
146152
)
@@ -628,8 +634,10 @@ def create_lambda(role=None, session=None):
628634
# time.sleep(30) # wait for eventual consistency, we hope
629635

630636
if "/" not in role:
631-
account = session.client("sts").get_caller_identity()["Account"]
632-
role = "arn:aws:iam::{}:role/{}".format(account, role)
637+
identity = session.client("sts").get_caller_identity()
638+
account = identity["Account"]
639+
partition = identity["Arn"].split(':')[1]
640+
role = "arn:{}:iam::{}:role/{}".format(partition, account, role)
633641

634642
code_bytes = zip_bytes(code_file)
635643

@@ -780,7 +788,9 @@ def proc(extras):
780788
if "/" not in image:
781789
account = session.client("sts").get_caller_identity()["Account"]
782790
region = session.region_name
783-
image = "{}.dkr.ecr.{}.amazonaws.com/{}:latest".format(account, region, image)
791+
domain = domain_for_region(region)
792+
image = "{}.dkr.ecr.{}.{}/{}:latest".format(
793+
account, region, domain, image)
784794

785795
if not role:
786796
try:
@@ -789,8 +799,10 @@ def proc(extras):
789799
role = "BasicExecuteNotebookRole-{}".format(session.region_name)
790800

791801
if "/" not in role:
792-
account = session.client("sts").get_caller_identity()["Account"]
793-
role = "arn:aws:iam::{}:role/{}".format(account, role)
802+
identity = session.client("sts").get_caller_identity()
803+
account = identity["Account"]
804+
partition = identity["Arn"].split(':')[1]
805+
role = "arn:{}:iam::{}:role/{}".format(partition, account, role)
794806

795807
if input_path is None:
796808
input_path = upload_notebook(notebook)
@@ -849,7 +861,7 @@ def schedule(
849861
850862
Creates a CloudWatch Event rule to invoke the installed Lambda either on the provided schedule or in response
851863
to the provided event. \
852-
864+
853865
:meth:`schedule` can upload a local notebook file to run or use one previously uploaded to S3.
854866
To find jobs run by the schedule, see :meth:`list_runs` using the `rule` argument to filter to
855867
a specific rule. To download the results, see :meth:`download_notebook` (or :meth:`download_all`
@@ -905,7 +917,9 @@ def proc(extras):
905917
if "/" not in image:
906918
account = session.client("sts").get_caller_identity()["Account"]
907919
region = session.region_name
908-
image = "{}.dkr.ecr.{}.amazonaws.com/{}:latest".format(account, region, image)
920+
domain = domain_for_region(region)
921+
image = "{}.dkr.ecr.{}.{}/{}:latest".format(
922+
account, region, domain, image)
909923

910924
if not role:
911925
try:
@@ -914,8 +928,10 @@ def proc(extras):
914928
role = "BasicExecuteNotebookRole-{}".format(session.region_name)
915929

916930
if "/" not in role:
917-
account = session.client("sts").get_caller_identity()["Account"]
918-
role = "arn:aws:iam::{}:role/{}".format(account, role)
931+
identity = session.client("sts").get_caller_identity()
932+
account = identity["Account"]
933+
partition = identity["Arn"].split(':')[1]
934+
role = "arn:{}:iam::{}:role/{}".format(partition, account, role)
919935

920936
if input_path is None:
921937
input_path = upload_notebook(notebook)
@@ -945,11 +961,12 @@ def proc(extras):
945961
Description='Rule to run the Jupyter notebook "{}"'.format(notebook),
946962
**kwargs,
947963
)
948-
949-
account = session.client("sts").get_caller_identity()["Account"]
964+
identity = session.client("sts").get_caller_identity()
965+
account = identity["Account"]
966+
partition = identity["Arn"].split(':')[1]
950967
region = session.region_name
951-
target_arn = "arn:aws:lambda:{}:{}:function:{}".format(
952-
region, account, lambda_function_name
968+
target_arn = "arn:{}:lambda:{}:{}:function:{}".format(
969+
partition, region, account, lambda_function_name
953970
)
954971

955972
result = events.put_targets(
@@ -1082,7 +1099,7 @@ def base_image(s):
10821099
return s
10831100

10841101

1085-
role_pat = re.compile(r"arn:aws:iam::([0-9]+):role/(.*)$")
1102+
role_pat = re.compile(r"arn:aws([^:]*):iam::([0-9]+):role/(.*)$")
10861103

10871104

10881105
def base_role(s):

sagemaker_run_notebook/utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,11 @@ def sts_regional_endpoint(region):
116116
Returns:
117117
str: AWS STS regional endpoint
118118
"""
119-
domain = _domain_for_region(region)
119+
domain = domain_for_region(region)
120120
return "https://sts.{}.{}".format(region, domain)
121121

122122

123-
def _domain_for_region(region):
123+
def domain_for_region(region):
124124
"""Get the DNS suffix for the given region.
125125
126126
Args:
@@ -129,7 +129,13 @@ def _domain_for_region(region):
129129
Returns:
130130
str: the DNS suffix
131131
"""
132-
return "c2s.ic.gov" if region == "us-iso-east-1" else "amazonaws.com"
132+
if region.startswith("us-iso-"):
133+
return "c2s.ic.gov"
134+
if region.startswith("us-isob-"):
135+
return "sc2s.sgov.gov"
136+
if region.startswith("cn-"):
137+
return "amazonaws.com.cn"
138+
return "amazonaws.com"
133139

134140

135141
def get_execution_role(session):

0 commit comments

Comments
 (0)