Skip to content

Commit 903cb8a

Browse files
authored
Add Owner ID check for bucket with path when prefix is provided (#5146)
* Fix Flake8 Violations * Add Owner ID check for bucket with path when prefix is provided **Description** Previously we called head_bucket to perform the owner ID check, but this doesn't take into consideration cases where the S3 path is provided through the prefix. This change makes sure that directory-level permissions are supported. **Testing Done** Tested through unit tests, integ tests and manual testing through the installation file. Yes * Address PR comment * Codestyle fixes * Minor fix * Codestyle fixes * Fix Unit tests
1 parent a896bc6 commit 903cb8a

File tree

2 files changed

+53
-5
lines changed

2 files changed

+53
-5
lines changed

src/sagemaker/session.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,6 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
635635

636636
elif self._default_bucket_set_by_sdk:
637637
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False)
638-
639638
expected_bucket_owner_id = self.account_id()
640639
self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id)
641640

@@ -649,9 +648,16 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket
649648
650649
"""
651650
try:
652-
s3.meta.client.head_bucket(
653-
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
654-
)
651+
if self.default_bucket_prefix:
652+
s3.meta.client.list_objects_v2(
653+
Bucket=bucket_name,
654+
Prefix=self.default_bucket_prefix,
655+
ExpectedBucketOwner=expected_bucket_owner_id,
656+
)
657+
else:
658+
s3.meta.client.head_bucket(
659+
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
660+
)
655661
except ClientError as e:
656662
error_code = e.response["Error"]["Code"]
657663
message = e.response["Error"]["Message"]
@@ -682,7 +688,12 @@ def general_bucket_check_if_user_has_permission(
682688
bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not
683689
"""
684690
try:
685-
s3.meta.client.head_bucket(Bucket=bucket_name)
691+
if self.default_bucket_prefix:
692+
s3.meta.client.list_objects_v2(
693+
Bucket=bucket_name, Prefix=self.default_bucket_prefix
694+
)
695+
else:
696+
s3.meta.client.head_bucket(Bucket=bucket_name)
686697
except ClientError as e:
687698
error_code = e.response["Error"]["Code"]
688699
message = e.response["Error"]["Message"]

tests/unit/test_default_bucket.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,19 @@ def sagemaker_session():
3939
return sagemaker_session
4040

4141

42+
@pytest.fixture()
43+
def sagemaker_session_with_bucket_name_and_prefix():
44+
boto_mock = MagicMock(name="boto_session", region_name=REGION)
45+
boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID}
46+
sagemaker_session = sagemaker.Session(
47+
boto_session=boto_mock,
48+
default_bucket="XXXXXXXXXXXXX",
49+
default_bucket_prefix="sample-prefix",
50+
)
51+
sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
52+
return sagemaker_session
53+
54+
4255
def test_default_bucket_s3_create_call(sagemaker_session):
4356
error = ClientError(
4457
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
@@ -96,6 +109,30 @@ def test_default_bucket_s3_needs_bucket_owner_access(sagemaker_session, datetime
96109
assert sagemaker_session._default_bucket is None
97110

98111

112+
def test_default_bucket_with_prefix_s3_needs_bucket_owner_access(
113+
sagemaker_session_with_bucket_name_and_prefix, datetime_obj, caplog
114+
):
115+
with pytest.raises(ClientError):
116+
error = ClientError(
117+
error_response={"Error": {"Code": "403", "Message": "Forbidden"}},
118+
operation_name="foo",
119+
)
120+
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource(
121+
"s3"
122+
).meta.client.list_objects_v2.side_effect = error
123+
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").Bucket(
124+
name=DEFAULT_BUCKET_NAME
125+
).creation_date = None
126+
sagemaker_session_with_bucket_name_and_prefix.default_bucket()
127+
128+
error_message = "Please try again after adding appropriate access."
129+
assert error_message in caplog.text
130+
assert sagemaker_session_with_bucket_name_and_prefix._default_bucket is None
131+
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource(
132+
"s3"
133+
).meta.client.list_objects_v2.assert_called_once()
134+
135+
99136
def test_default_bucket_s3_custom_bucket_input(sagemaker_session, datetime_obj, caplog):
100137
sagemaker_session._default_bucket_name_override = "custom-bucket-override"
101138
error = ClientError(

0 commit comments

Comments
 (0)