Skip to content

Commit

Permalink
Fix firehose / s3 client creation of clients with role (localstack#8159)
Browse files Browse the repository at this point in the history
  • Loading branch information
dfangl authored Apr 19, 2023
1 parent 8b63c2c commit 13a1419
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 3 deletions.
11 changes: 10 additions & 1 deletion localstack/services/firehose/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
opensearch_domain_name,
s3_bucket_name,
)
from localstack.utils.aws.client_types import ServicePrincipal
from localstack.utils.common import (
TIMESTAMP_FORMAT_MICROS,
first_char_to_lower,
Expand Down Expand Up @@ -719,7 +720,15 @@ def _put_records_to_s3_bucket(
bucket = s3_bucket_name(s3_destination_description["BucketARN"])
prefix = s3_destination_description.get("Prefix", "")

s3 = connect_to().s3.request_metadata(source_arn=stream_name, service_principal="firehose")
if role_arn := s3_destination_description.get("RoleARN"):
factory = connect_to.with_assumed_role(
role_arn=role_arn, service_principal=ServicePrincipal.firehose
)
else:
factory = connect_to()
s3 = factory.s3.request_metadata(
source_arn=stream_name, service_principal=ServicePrincipal.firehose
)
batched_data = b"".join([base64.b64decode(r.get("Data") or r.get("data")) for r in records])

obj_path = self._get_s3_object_path(stream_name, prefix)
Expand Down
1 change: 1 addition & 0 deletions localstack/services/sns/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class SnsSubscription(TypedDict):
FilterPolicyScope: Literal["MessageAttributes", "MessageBody"]
RawMessageDelivery: Literal["true", "false"]
ConfirmationWasAuthenticated: Literal["true", "false"]
SubscriptionRoleArn: Optional[str]


class SnsStore(BaseStore):
Expand Down
11 changes: 9 additions & 2 deletions localstack/services/sns/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
sqs_queue_url_for_arn,
)
from localstack.utils.aws.aws_responses import create_sqs_system_attributes
from localstack.utils.aws.client_types import ServicePrincipal
from localstack.utils.aws.dead_letter_queue import sns_error_to_dead_letter_queue
from localstack.utils.cloudwatch.cloudwatch_util import store_cloudwatch_logs
from localstack.utils.objects import not_none_or
Expand Down Expand Up @@ -582,8 +583,14 @@ def _publish(self, context: SnsPublishContext, subscriber: SnsSubscription):
message_body = self.prepare_message(context.message, subscriber)
try:
region = extract_region_from_arn(subscriber["Endpoint"])
firehose_client = connect_to(region_name=region).firehose.request_metadata(
source_arn=subscriber["TopicArn"], service_principal="sns"
if role_arn := subscriber.get("SubscriptionRoleArn"):
factory = connect_to.with_assumed_role(
role_arn=role_arn, service_principal=ServicePrincipal.sns, region_name=region
)
else:
factory = connect_to(region_name=region)
firehose_client = factory.firehose.request_metadata(
source_arn=subscriber["TopicArn"], service_principal=ServicePrincipal.sns
)
endpoint = subscriber["Endpoint"]
if endpoint:
Expand Down
1 change: 1 addition & 0 deletions localstack/utils/aws/client_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,5 +228,6 @@ class ServicePrincipal(str):

awslambda = "lambda"
apigateway = "apigateway"
firehose = "firehose"
sqs = "sqs"
sns = "sns"

0 comments on commit 13a1419

Please sign in to comment.