Skip to content

Commit

Permalink
Lambda: Use proper account ID and region during ARN construction (loc…
Browse files Browse the repository at this point in the history
  • Loading branch information
viren-nadkarni authored Oct 25, 2023
1 parent f4a95d2 commit a975ce2
Show file tree
Hide file tree
Showing 11 changed files with 297 additions and 136 deletions.
140 changes: 98 additions & 42 deletions localstack/services/lambda_/lambda_api.py

Large diffs are not rendered by default.

27 changes: 16 additions & 11 deletions localstack/utils/aws/arns.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from localstack.aws.accounts import DEFAULT_AWS_ACCOUNT_ID, get_aws_account_id
from localstack.aws.connect import connect_to
from localstack.utils.aws.aws_stack import get_region, get_valid_regions
from localstack.utils.aws.aws_stack import get_region

# set up logger
LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -169,9 +169,9 @@ def event_bus_arn(bus_name, account_id=None, region_name=None):
return _resource_arn(bus_name, pattern, account_id=account_id, region_name=region_name)


def lambda_function_arn(function_name, account_id=None, region_name=None):
def lambda_function_arn(function_name: str, account_id: str, region_name: str) -> str:
return lambda_function_or_layer_arn(
"function", function_name, account_id=account_id, region_name=region_name
"function", function_name, version=None, account_id=account_id, region_name=region_name
)


Expand All @@ -182,13 +182,17 @@ def lambda_layer_arn(layer_name, version=None, region_name=None, account_id=None


def lambda_function_or_layer_arn(
type, entity_name, version=None, account_id=None, region_name=None
):
type: str,
entity_name: str,
version: Optional[str],
account_id: str,
region_name: str,
) -> str:
pattern = "arn:([a-z-]+):lambda:.*:.*:(function|layer):.*"
if re.match(pattern, entity_name):
return entity_name
if ":" in entity_name:
client = connect_to().lambda_
client = connect_to(aws_access_key_id=account_id, region_name=region_name).lambda_
entity_name, _, alias = entity_name.rpartition(":")
try:
alias_response = client.get_alias(FunctionName=entity_name, Name=alias)
Expand All @@ -199,8 +203,6 @@ def lambda_function_or_layer_arn(
LOG.info(f"{msg}: {e}")
raise Exception(msg)

account_id = account_id or get_aws_account_id()
region_name = region_name or get_region()
result = f"arn:aws:lambda:{region_name}:{account_id}:{type}:{entity_name}"
if version:
result = f"{result}:{version}"
Expand Down Expand Up @@ -232,9 +234,12 @@ def fix_arn(arn):
"""Function that attempts to "canonicalize" the given ARN. This includes converting
resource names to ARNs, replacing incorrect regions, account IDs, etc."""
if arn.startswith("arn:aws:lambda"):
parts = arn.split(":")
region = parts[3] if parts[3] in get_valid_regions() else get_region()
return lambda_function_arn(lambda_function_name(arn), region_name=region)
arn_data = parse_arn(arn)
return lambda_function_arn(
lambda_function_name(arn),
account_id=arn_data["account"],
region_name=arn_data["region"],
)
LOG.warning("Unable to fix/canonicalize ARN: %s", arn)
return arn

Expand Down
5 changes: 2 additions & 3 deletions localstack/utils/testutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,9 @@ def create_lambda_api_gateway_integration(

# create Lambda
zip_file = create_lambda_archive(handler_file, get_content=True, runtime=runtime)
create_lambda_function(
func_arn = create_lambda_function(
func_name=func_name, zip_file=zip_file, runtime=runtime, client=lambda_client
)
func_arn = arns.lambda_function_arn(func_name)
)["CreateFunctionResponse"]["FunctionArn"]
target_arn = arns.apigateway_invocations_arn(func_arn, TEST_AWS_REGION_NAME)

# connect API GW to Lambda
Expand Down
41 changes: 16 additions & 25 deletions tests/aws/services/apigateway/test_apigateway_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,11 @@ def test_api_gateway_lambda_integration(
function input event.
"""
fn_name = f"test-{short_uid()}"
create_lambda_function(
lambda_arn = create_lambda_function(
func_name=fn_name,
handler_file=TEST_LAMBDA_AWS_PROXY,
runtime=Runtime.python3_9,
)
lambda_arn = arns.lambda_function_arn(fn_name, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME)
)["CreateFunctionResponse"]["FunctionArn"]

api_id, _, root = create_rest_apigw(name="aws lambda api")
resource_id, _ = create_rest_resource(
Expand Down Expand Up @@ -571,10 +570,9 @@ def test_api_gateway_lambda_asynchronous_invocation(
rest_api_id, _, _ = create_rest_apigw(name=api_gateway_name)

fn_name = f"test-{short_uid()}"
create_lambda_function(
lambda_arn = create_lambda_function(
handler_file=TEST_LAMBDA_NODEJS, func_name=fn_name, runtime=Runtime.nodejs16_x
)
lambda_arn = arns.lambda_function_arn(fn_name, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME)
)["CreateFunctionResponse"]["FunctionArn"]

spec_file = load_file(TEST_IMPORT_REST_API_ASYNC_LAMBDA)
spec_file = spec_file.replace("${lambda_invocation_arn}", lambda_arn)
Expand Down Expand Up @@ -653,16 +651,13 @@ def test_malformed_response_apigw_invocation(self, create_lambda_function, aws_c
lambda_resource = "/api/v1/{proxy+}"
lambda_path = "/api/v1/hello/world"

create_lambda_function(
lambda_uri = create_lambda_function(
func_name=lambda_name,
zip_file=testutil.create_zip_file(TEST_LAMBDA_NODEJS_APIGW_502, get_content=True),
runtime=Runtime.nodejs16_x,
handler="apigw_502.handler",
)
)["CreateFunctionResponse"]["FunctionArn"]

lambda_uri = arns.lambda_function_arn(
lambda_name, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME
)
target_uri = f"arn:aws:apigateway:{TEST_AWS_REGION_NAME}:lambda:path/2015-03-31/functions/{lambda_uri}/invocations"
result = testutil.connect_api_gateway_to_http_with_lambda_proxy(
"test_gateway",
Expand Down Expand Up @@ -1017,14 +1012,11 @@ def test_apigateway_with_step_function_integration(

# create lambda
fn_name = f"lambda-sfn-apigw-{short_uid()}"
create_lambda_function(
lambda_arn = create_lambda_function(
handler_file=TEST_LAMBDA_PYTHON_ECHO,
func_name=fn_name,
runtime=Runtime.python3_9,
)
lambda_arn = arns.lambda_function_arn(
function_name=fn_name, account_id=aws_account_id, region_name=region_name
)
)["CreateFunctionResponse"]["FunctionArn"]

# create state machine and permissions for step function to invoke lambda
role_name = f"sfn_role-{short_uid()}"
Expand Down Expand Up @@ -1383,10 +1375,9 @@ def test_apigw_test_invoke_method_api(
):
# create test Lambda
fn_name = f"test-{short_uid()}"
create_lambda_function(
lambda_arn_1 = create_lambda_function(
handler_file=TEST_LAMBDA_NODEJS, func_name=fn_name, runtime=Runtime.nodejs16_x
)
lambda_arn_1 = arns.lambda_function_arn(fn_name, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME)
)["CreateFunctionResponse"]["FunctionArn"]

# create REST API and test resource
rest_api_id, _, _ = create_rest_apigw(name="test", description="test")
Expand Down Expand Up @@ -1725,23 +1716,24 @@ def test_rest_api_multi_region(
lambda_name = f"lambda-{short_uid()}"
lambda_eu_west_1_client = aws_client_factory(region_name="eu-west-1").lambda_
lambda_us_west_1_client = aws_client_factory(region_name="us-west-1").lambda_
testutil.create_lambda_function(
lambda_eu_arn = testutil.create_lambda_function(
handler_file=TEST_LAMBDA_NODEJS,
func_name=lambda_name,
runtime=Runtime.nodejs16_x,
region_name="eu-west-1",
client=lambda_eu_west_1_client,
)
testutil.create_lambda_function(
)["CreateFunctionResponse"]["FunctionArn"]

lambda_us_arn = testutil.create_lambda_function(
handler_file=TEST_LAMBDA_NODEJS,
func_name=lambda_name,
runtime=Runtime.nodejs16_x,
region_name="us-west-1",
client=lambda_us_west_1_client,
)
)["CreateFunctionResponse"]["FunctionArn"]

lambda_eu_west_1_client.get_waiter("function_active_v2").wait(FunctionName=lambda_name)
lambda_us_west_1_client.get_waiter("function_active_v2").wait(FunctionName=lambda_name)
lambda_eu_arn = arns.lambda_function_arn(lambda_name, TEST_AWS_ACCOUNT_ID, "eu-west-1")
uri_eu = arns.apigateway_invocations_arn(lambda_eu_arn, region_name="eu-west-1")

integration_uri, _ = create_rest_api_integration(
Expand All @@ -1754,7 +1746,6 @@ def test_rest_api_multi_region(
uri=uri_eu,
)

lambda_us_arn = arns.lambda_function_arn(lambda_name, TEST_AWS_ACCOUNT_ID, "us-west-1")
uri_us = arns.apigateway_invocations_arn(lambda_us_arn, region_name="us-west-1")

integration_uri, _ = create_rest_api_integration(
Expand Down
4 changes: 3 additions & 1 deletion tests/aws/services/cloudformation/resources/test_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,9 @@ def test_cfn_handle_serverless_api_resource(self, deploy_cfn_template, aws_clien
resource = rs["items"][0]

uri = resource["resourceMethods"]["GET"]["methodIntegration"]["uri"]
lambda_arn = arns.lambda_function_arn(lambda_func_names[0]) # TODO
lambda_arn = arns.lambda_function_arn(
lambda_func_names[0], account_id=TEST_AWS_ACCOUNT_ID, region_name=TEST_AWS_REGION_NAME
)
assert lambda_arn in uri

# TODO: refactor
Expand Down
7 changes: 4 additions & 3 deletions tests/aws/services/firehose/test_firehose.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from localstack import config
from localstack.testing.pytest import markers
from localstack.utils.aws import arns
from localstack.utils.aws.arns import lambda_function_arn
from localstack.utils.strings import short_uid, to_bytes, to_str
from localstack.utils.sync import poll_condition, retry

Expand Down Expand Up @@ -42,7 +41,9 @@ def test_firehose_http(
if lambda_processor_enabled:
# create processor func
func_name = f"proc-{short_uid()}"
create_lambda_function(handler_file=PROCESSOR_LAMBDA, func_name=func_name)
func_arn = create_lambda_function(handler_file=PROCESSOR_LAMBDA, func_name=func_name)[
"CreateFunctionResponse"
]["FunctionArn"]

# define firehose configs
# records = []
Expand Down Expand Up @@ -70,7 +71,7 @@ def test_firehose_http(
"Parameters": [
{
"ParameterName": "LambdaArn",
"ParameterValue": lambda_function_arn(func_name),
"ParameterValue": func_arn,
}
],
}
Expand Down
6 changes: 5 additions & 1 deletion tests/aws/services/lambda_/test_lambda_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from localstack.aws.accounts import get_aws_account_id
from localstack.aws.api.lambda_ import Runtime
from localstack.constants import TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME
from localstack.services.lambda_ import lambda_api
from localstack.services.lambda_.lambda_api import (
LAMBDA_TEST_ROLE,
Expand Down Expand Up @@ -129,7 +130,10 @@ def test_add_lambda_multiple_permission(self, create_lambda_function, aws_client
statements = versions[0]["Document"]["Statement"]
for i in range(len(statement_ids)):
assert action == statements[i]["Action"]
assert lambda_api.func_arn(function_name) == statements[i]["Resource"]
assert (
lambda_api.func_arn(TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME, function_name)
== statements[i]["Resource"]
)
assert principal == statements[i]["Principal"]["Service"]
assert (
arns.s3_bucket_arn("test-bucket")
Expand Down
5 changes: 3 additions & 2 deletions tests/aws/services/lambda_/test_lambda_whitebox.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import localstack.services.lambda_.lambda_api
from localstack import config
from localstack.constants import TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME
from localstack.services.lambda_ import lambda_api, lambda_executors
from localstack.services.lambda_.lambda_api import do_set_function_code, use_docker
from localstack.services.lambda_.lambda_utils import LAMBDA_RUNTIME_PYTHON39
Expand Down Expand Up @@ -247,7 +248,7 @@ def test_code_updated_on_redeployment(self, aws_client):
def test_prime_and_destroy_containers(self, aws_client):
executor = lambda_api.LAMBDA_EXECUTOR
func_name = f"test_prime_and_destroy_containers_{short_uid()}"
func_arn = lambda_api.func_arn(func_name)
func_arn = lambda_api.func_arn(TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME, func_name)

# make sure existing containers are gone
executor.cleanup()
Expand Down Expand Up @@ -318,7 +319,7 @@ def test_prime_and_destroy_containers(self, aws_client):
def test_destroy_idle_containers(self, aws_client):
executor = lambda_api.LAMBDA_EXECUTOR
func_name = "test_destroy_idle_containers"
func_arn = lambda_api.func_arn(func_name)
func_arn = lambda_api.func_arn(TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME, func_name)

# make sure existing containers are gone
executor.destroy_existing_docker_containers()
Expand Down
8 changes: 3 additions & 5 deletions tests/aws/services/logs/test_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,11 @@ def test_put_subscription_filter_lambda(
)

test_lambda_name = f"test-lambda-function-{short_uid()}"
create_lambda_function(
func_arn = create_lambda_function(
handler_file=TEST_LAMBDA_PYTHON_ECHO,
func_name=test_lambda_name,
runtime=Runtime.python3_9,
)
)["CreateFunctionResponse"]["FunctionArn"]
aws_client.lambda_.invoke(FunctionName=test_lambda_name, Payload=b"{}")
# get account-id to set the correct policy
account_id = aws_client.sts.get_caller_identity()["Account"]
Expand All @@ -301,9 +301,7 @@ def test_put_subscription_filter_lambda(
logGroupName=logs_log_group,
filterName="test",
filterPattern="",
destinationArn=arns.lambda_function_arn(
test_lambda_name, account_id=account_id, region_name=config.AWS_REGION_US_EAST_1
),
destinationArn=func_arn,
)
snapshot.match("put_subscription_filter", result)

Expand Down
4 changes: 2 additions & 2 deletions tests/aws/test_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from localstack.constants import TEST_AWS_REGION_NAME
from localstack.constants import TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME
from localstack.testing.pytest import markers
from localstack.utils.aws import arns
from localstack.utils.common import retry, run
Expand Down Expand Up @@ -182,7 +182,7 @@ def test_apigateway_deployed(self, aws_client, setup_and_teardown):
assert method in proxy_resource["resourceMethods"]
resource_method = proxy_resource["resourceMethods"][method]
assert (
arns.lambda_function_arn(function_name)
arns.lambda_function_arn(function_name, TEST_AWS_ACCOUNT_ID, TEST_AWS_REGION_NAME)
in resource_method["methodIntegration"]["uri"]
)

Expand Down
Loading

0 comments on commit a975ce2

Please sign in to comment.