Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ public List<RuntimeClientPlugin> getClientPlugins(GenerationContext context) {
.build())
// TODO: Initialize with the provider chain?
.nullable(true)
.initialize(writer -> {
writer.addImport("smithy_aws_core.credentials_resolvers", "CredentialsResolverChain");
writer.write("self.aws_credentials_identity_resolver = aws_credentials_identity_resolver or CredentialsResolverChain(config=self)");
})
.build())
.addConfigProperty(REGION)
.authScheme(new Sigv4AuthScheme())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from .chain import CredentialsResolverChain
from .environment import EnvironmentCredentialsResolver
from .imds import IMDSCredentialsResolver
from .static import StaticCredentialsResolver

__all__ = (
"CredentialsResolverChain",
"EnvironmentCredentialsResolver",
"IMDSCredentialsResolver",
"StaticCredentialsResolver",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence

from smithy_core.aio.interfaces.identity import IdentityResolver
from smithy_core.exceptions import SmithyIdentityException
from smithy_core.interfaces.identity import IdentityProperties

from smithy_aws_core.credentials_resolvers.environment import (
EnvironmentCredentialsSource,
)
from smithy_aws_core.credentials_resolvers.imds import IMDSCredentialsSource
from smithy_aws_core.credentials_resolvers.interfaces import (
AwsCredentialsConfig,
CredentialsSource,
)
from smithy_aws_core.identity import AWSCredentialsIdentity, AWSCredentialsResolver

_DEFAULT_SOURCES: Sequence[CredentialsSource] = (
EnvironmentCredentialsSource(),
IMDSCredentialsSource(),
)


class CredentialsResolverChain(
IdentityResolver[AWSCredentialsIdentity, IdentityProperties]
):
"""Resolves AWS Credentials from an ordered list of credentials sources."""

def __init__(
self,
*,
config: AwsCredentialsConfig,
sources: Sequence[CredentialsSource] = _DEFAULT_SOURCES,
):
self._config = config
self._sources: Sequence[CredentialsSource] = sources
self._credentials_resolver: AWSCredentialsResolver | None = None

async def get_identity(
self, *, identity_properties: IdentityProperties
) -> AWSCredentialsIdentity:
if self._credentials_resolver is not None:
return await self._credentials_resolver.get_identity(
identity_properties=identity_properties
)

for source in self._sources:
if source.is_available(config=self._config):
self._credentials_resolver = source.build_resolver(config=self._config)
return await self._credentials_resolver.get_identity(
identity_properties=identity_properties
)

raise SmithyIdentityException(
"None of the configured credentials sources were able to resolve credentials."
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from smithy_core.exceptions import SmithyIdentityException
from smithy_core.interfaces.identity import IdentityProperties

from ..identity import AWSCredentialsIdentity
from smithy_aws_core.credentials_resolvers.interfaces import (
AwsCredentialsConfig,
CredentialsSource,
)

from ..identity import AWSCredentialsIdentity, AWSCredentialsResolver


class EnvironmentCredentialsResolver(
Expand Down Expand Up @@ -41,3 +46,13 @@ async def get_identity(
)

return self._credentials


class EnvironmentCredentialsSource(CredentialsSource):
def is_available(self, config: AwsCredentialsConfig) -> bool:
return (
"AWS_ACCESS_KEY_ID" in os.environ and "AWS_SECRET_ACCESS_KEY" in os.environ
)

def build_resolver(self, config: AwsCredentialsConfig) -> AWSCredentialsResolver:
return EnvironmentCredentialsResolver()
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@
from smithy_http.aio import HTTPRequest
from smithy_http.aio.interfaces import HTTPClient

from smithy_aws_core.credentials_resolvers.interfaces import (
AwsCredentialsConfig,
CredentialsSource,
)

from .. import __version__
from ..identity import AWSCredentialsIdentity
from ..identity import AWSCredentialsIdentity, AWSCredentialsResolver

_USER_AGENT_FIELD = Field(
name="User-Agent",
Expand Down Expand Up @@ -235,3 +240,14 @@ async def get_identity(
account_id=account_id,
)
return self._credentials


class IMDSCredentialsSource(CredentialsSource):
def is_available(self, config: AwsCredentialsConfig) -> bool:
# IMDS credentials should always be the last in the chain
# We cannot check if they're available without actually making a call
return True

def build_resolver(self, config: AwsCredentialsConfig) -> AWSCredentialsResolver:
# TODO: Configure lower number of retries/lower timeout
return IMDSCredentialsResolver(http_client=config.http_client)
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Protocol

from smithy_http.aio.interfaces import HTTPClient

from smithy_aws_core.identity import AWSCredentialsResolver


class AwsCredentialsConfig(Protocol):
"""Configuration required for resolving credentials."""

http_client: HTTPClient


class CredentialsSource(Protocol):
def is_available(self, config: AwsCredentialsConfig) -> bool:
"""Returns True if credentials are available from this source."""
...

def build_resolver(self, config: AwsCredentialsConfig) -> AWSCredentialsResolver:
"""Builds a credentials resolver for the given configuration."""
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from dataclasses import dataclass
from unittest.mock import Mock

import pytest
from smithy_aws_core.credentials_resolvers import (
CredentialsResolverChain,
IMDSCredentialsResolver,
StaticCredentialsResolver,
)
from smithy_aws_core.credentials_resolvers.environment import (
EnvironmentCredentialsSource,
)
from smithy_aws_core.credentials_resolvers.interfaces import (
AwsCredentialsConfig,
CredentialsSource,
)
from smithy_aws_core.identity import AWSCredentialsIdentity, AWSCredentialsResolver
from smithy_core.exceptions import SmithyIdentityException
from smithy_core.interfaces.identity import IdentityProperties
from smithy_http.aio.interfaces import HTTPClient


@dataclass
class Config:
http_client: HTTPClient

def __init__(self):
self.http_client = Mock(spec=HTTPClient) # type: ignore


async def test_no_sources_resolve():
resolver_chain = CredentialsResolverChain(sources=[], config=Config())
with pytest.raises(SmithyIdentityException):
await resolver_chain.get_identity(identity_properties=IdentityProperties())


async def test_env_credentials_resolver_not_set(monkeypatch: pytest.MonkeyPatch):
monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
resolver_chain = CredentialsResolverChain(
sources=[EnvironmentCredentialsSource()], config=Config()
)

with pytest.raises(SmithyIdentityException):
await resolver_chain.get_identity(identity_properties=IdentityProperties())


async def test_env_credentials_resolver_partial(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "akid")
monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
resolver_chain = CredentialsResolverChain(
sources=[EnvironmentCredentialsSource()], config=Config()
)

with pytest.raises(SmithyIdentityException):
await resolver_chain.get_identity(identity_properties=IdentityProperties())


async def test_default_sources_env_credentials_resolver_success(
monkeypatch: pytest.MonkeyPatch,
):
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "akid")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "secret")
resolver_chain = CredentialsResolverChain(config=Config())

credentials = await resolver_chain.get_identity(
identity_properties=IdentityProperties()
)
assert credentials.access_key_id == "akid"
assert credentials.secret_access_key == "secret"


async def test_default_sources_imds_resolver_success(monkeypatch: pytest.MonkeyPatch):
monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)

async def mock_imds_get_identity(
self: IMDSCredentialsResolver, *, identity_properties: IdentityProperties
) -> AWSCredentialsIdentity:
return AWSCredentialsIdentity(
access_key_id="akid",
secret_access_key="secret",
)

monkeypatch.setattr(
"smithy_aws_core.credentials_resolvers.IMDSCredentialsResolver.get_identity",
mock_imds_get_identity,
)

resolver_chain = CredentialsResolverChain(config=Config())

credentials = await resolver_chain.get_identity(
identity_properties=IdentityProperties()
)
assert credentials.access_key_id == "akid"
assert credentials.secret_access_key == "secret"


async def test_multiple_sources_one_valid():
class FailingSource(CredentialsSource):
def is_available(self, config: AwsCredentialsConfig) -> bool:
return False

def build_resolver(
self, config: AwsCredentialsConfig
) -> AWSCredentialsResolver:
raise RuntimeError("Should not be called")

static_credentials = AWSCredentialsIdentity(
access_key_id="valid_akid", secret_access_key="valid_secret"
)
static_resolver = StaticCredentialsResolver(credentials=static_credentials)

class ValidSource(CredentialsSource):
def is_available(self, config: AwsCredentialsConfig) -> bool:
return True

def build_resolver(
self, config: AwsCredentialsConfig
) -> AWSCredentialsResolver:
return static_resolver

resolver_chain = CredentialsResolverChain(
sources=[FailingSource(), ValidSource()], config=Config()
)

credentials = await resolver_chain.get_identity(
identity_properties=IdentityProperties()
)
assert credentials.access_key_id == "valid_akid"
assert credentials.secret_access_key == "valid_secret"


async def test_cached_resolver_used(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("AWS_ACCESS_KEY_ID", "cached_akid")
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "cached_secret")
resolver_chain = CredentialsResolverChain(
sources=[EnvironmentCredentialsSource()], config=Config()
)

credentials1 = await resolver_chain.get_identity(
identity_properties=IdentityProperties()
)
credentials2 = await resolver_chain.get_identity(
identity_properties=IdentityProperties()
)

assert credentials1.access_key_id == credentials2.access_key_id == "cached_akid"
assert (
credentials1.secret_access_key
== credentials2.secret_access_key
== "cached_secret"
)


async def test_custom_sources_with_static_credentials():
static_credentials = AWSCredentialsIdentity(
access_key_id="static_akid",
secret_access_key="static_secret",
)
static_resolver = StaticCredentialsResolver(credentials=static_credentials)

class TestStaticSource(CredentialsSource):
def is_available(self, config: AwsCredentialsConfig) -> bool:
return True

def build_resolver(
self, config: AwsCredentialsConfig
) -> AWSCredentialsResolver:
return static_resolver

resolver_chain = CredentialsResolverChain(
sources=[TestStaticSource()],
config=Config(), # type: ignore
)

credentials = await resolver_chain.get_identity(
identity_properties=IdentityProperties()
)
assert credentials.access_key_id == "static_akid"
assert credentials.secret_access_key == "static_secret"
4 changes: 3 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.