Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changelog/3731.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I'd add it's a backport in the PR title

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed

provider: Fixes STS region resolution when using cross-region authentication
```
24 changes: 21 additions & 3 deletions .github/workflows/acceptance-tests-runner.yml
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,23 @@ jobs:
needs: [ change-detection, get-provider-version ]
if: ${{ needs.change-detection.outputs.assume_role == 'true' || inputs.test_group == 'assume_role' }}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include:
# Secret and STS Endpoint in same region
- name: same-region-us-east-1
aws_region: US_EAST_1
sts_endpoint: https://sts.us-east-1.amazonaws.com/
# Secret and STS Endpoint in different regions(Cross-region)
- name: cross-sts-us-east-1-secret-eu-north-1
aws_region: EU_NORTH_1
sts_endpoint: https://sts.us-east-1.amazonaws.com/
# Global STS endpoint (signs as us-east-1), secrets in eu-west-1
- name: global-sts-secret-eu-west-1
aws_region: EU_WEST_1
sts_endpoint: https://sts.amazonaws.com
name: assume_role – ${{ matrix.name }}
permissions: {}
steps:
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8
Expand All @@ -478,19 +495,20 @@ jobs:
AWS_ACCESS_KEY_ID: ${{ secrets.aws_access_key_id }}
ASSUME_ROLE_ARN: ${{ vars.ASSUME_ROLE_ARN }}
run: bash ./scripts/generate-credentials-with-sts-assume-role.sh
- name: Acceptance Tests
- name: Acceptance Tests (matrix)
env:
MONGODB_ATLAS_PUBLIC_KEY: ""
MONGODB_ATLAS_PRIVATE_KEY: ""
ASSUME_ROLE_ARN: ${{ vars.ASSUME_ROLE_ARN }}
AWS_REGION: ${{ vars.AWS_REGION }}
STS_ENDPOINT: ${{ vars.STS_ENDPOINT }}
AWS_REGION: ${{ matrix.aws_region }}
STS_ENDPOINT: ${{ matrix.sts_endpoint }}
SECRET_NAME: ${{ inputs.aws_secret_name }}
AWS_ACCESS_KEY_ID: ${{ steps.sts-assume-role.outputs.aws_access_key_id }}
AWS_SECRET_ACCESS_KEY: ${{ steps.sts-assume-role.outputs.aws_secret_access_key }}
AWS_SESSION_TOKEN: ${{ steps.sts-assume-role.outputs.AWS_SESSION_TOKEN }}
MONGODB_ATLAS_LAST_VERSION: ${{ needs.get-provider-version.outputs.provider_version }}
ACCTEST_PACKAGES: ./internal/provider
ACCTEST_REGEX_RUN: ^TestAccSTSAssumeRole_basic$
run: make testacc

autogen:
Expand Down
80 changes: 56 additions & 24 deletions internal/provider/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"encoding/json"
"fmt"
"log"
"net/url"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
Expand All @@ -12,48 +14,39 @@ import (
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/secretsmanager"
"github.com/aws/aws-sdk-go/service/sts"

"github.com/mongodb/terraform-provider-mongodbatlas/internal/config"
)

const (
endPointSTSDefault = "https://sts.amazonaws.com"
endPointSTSHostnameDefault = "sts.amazonaws.com"
DefaultRegionSTS = "us-east-1"
minSegmentsForSTSRegionalHost = 4
)

func configureCredentialsSTS(cfg *config.Config, secret, region, awsAccessKeyID, awsSecretAccessKey, awsSessionToken, endpoint string) (config.Config, error) {
ep, err := endpoints.GetSTSRegionalEndpoint("regional")
if err != nil {
log.Printf("GetSTSRegionalEndpoint error: %s", err)
return *cfg, err
}

defaultResolver := endpoints.DefaultResolver()
stsCustResolverFn := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
if service == endpoints.StsServiceID {
if endpoint == "" {
return endpoints.ResolvedEndpoint{
URL: endPointSTSDefault,
SigningRegion: region,
}, nil
stsCustResolverFn := func(service, _ string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
if service == sts.EndpointsID {
resolved, err := ResolveSTSEndpoint(endpoint, region)
if err != nil {
return endpoints.ResolvedEndpoint{}, err
}
return endpoints.ResolvedEndpoint{
URL: endpoint,
SigningRegion: region,
}, nil
return resolved, nil
}

return defaultResolver.EndpointFor(service, region, optFns...)
}

sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String(region),
Credentials: credentials.NewStaticCredentials(awsAccessKeyID, awsSecretAccessKey, awsSessionToken),
STSRegionalEndpoint: ep,
EndpointResolver: endpoints.ResolverFunc(stsCustResolverFn),
Region: aws.String(region),
Credentials: credentials.NewStaticCredentials(awsAccessKeyID, awsSecretAccessKey, awsSessionToken),
EndpointResolver: endpoints.ResolverFunc(stsCustResolverFn),
}))

creds := stscreds.NewCredentials(sess, cfg.AssumeRole.RoleARN)

_, err = sess.Config.Credentials.Get()
_, err := sess.Config.Credentials.Get()
if err != nil {
log.Printf("Session get credentials error: %s", err)
return *cfg, err
Expand Down Expand Up @@ -87,6 +80,45 @@ func configureCredentialsSTS(cfg *config.Config, secret, region, awsAccessKeyID,
return *cfg, nil
}

func DeriveSTSRegionFromEndpoint(ep string) string {
if ep == "" {
return ""
}
u, err := url.Parse(ep)
if err != nil {
return DefaultRegionSTS
}
host := u.Hostname() // valid values: sts.us-west-2.amazonaws.com or sts.amazonaws.com

if host == endPointSTSHostnameDefault {
return DefaultRegionSTS
}

parts := strings.Split(host, ".")
if len(parts) >= minSegmentsForSTSRegionalHost && parts[0] == "sts" {
return parts[1]
}
return DefaultRegionSTS
}

func ResolveSTSEndpoint(stsEndpoint, secretsRegion string) (endpoints.ResolvedEndpoint, error) {
ep := stsEndpoint
if ep == "" {
r := secretsRegion
if r == "" {
r = DefaultRegionSTS
}
ep = fmt.Sprintf("https://sts.%s.amazonaws.com/", r)
}

signingRegion := DeriveSTSRegionFromEndpoint(ep)

return endpoints.ResolvedEndpoint{
URL: ep,
SigningRegion: signingRegion,
}, nil
}

func secretsManagerGetSecretValue(sess *session.Session, creds *aws.Config, secret string) (string, error) {
svc := secretsmanager.New(sess, creds)
input := &secretsmanager.GetSecretValueInput{
Expand Down
94 changes: 94 additions & 0 deletions internal/provider/credentials_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package provider_test

import (
"testing"

"github.com/mongodb/terraform-provider-mongodbatlas/internal/provider"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_deriveSTSRegionFromEndpoint(t *testing.T) {
testCases := map[string]struct {
input string
expected string
}{
"empty endpoint": {
input: "",
expected: "",
},
"global endpoint": {
input: "https://sts.amazonaws.com",
expected: provider.DefaultRegionSTS,
},
"regional": {
input: "https://sts.us-east-1.amazonaws.com/",
expected: "us-east-1",
},
"regional eu-north-1": {
input: "https://sts.eu-north-1.amazonaws.com/",
expected: "eu-north-1",
},
"malformed url": {
input: "://not-a-url",
expected: provider.DefaultRegionSTS,
},
"unexpected host shape": {
input: "https://sts.something-weird",
expected: provider.DefaultRegionSTS,
},
}

for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
t.Parallel()
got := provider.DeriveSTSRegionFromEndpoint(tc.input)
if got != tc.expected {
t.Fatalf("deriveSTSRegionFromEndpoint(%q) = %q; want %q", tc.input, got, tc.expected)
}
})
}
}

func Test_resolveSTSEndpoint(t *testing.T) {
testCases := map[string]struct {
stsEndpoint string
secretsRegion string
expectedURL string
expectedSign string
}{
"explicit regional endpoint": {
stsEndpoint: "https://sts.eu-north-1.amazonaws.com/",
secretsRegion: "us-east-1",
expectedURL: "https://sts.eu-north-1.amazonaws.com/",
expectedSign: "eu-north-1",
},
"global endpoint - us-east-1 signing": {
stsEndpoint: "https://sts.amazonaws.com",
secretsRegion: "eu-west-1",
expectedURL: "https://sts.amazonaws.com",
expectedSign: provider.DefaultRegionSTS,
},
"no endpoint - uses secrets region": {
stsEndpoint: "",
secretsRegion: "us-west-2",
expectedURL: "https://sts.us-west-2.amazonaws.com/",
expectedSign: "us-west-2",
},
"no endpoint and empty region": {
stsEndpoint: "",
secretsRegion: "",
expectedURL: "https://sts.us-east-1.amazonaws.com/",
expectedSign: provider.DefaultRegionSTS,
},
}

for testName, tc := range testCases {
t.Run(testName, func(t *testing.T) {
ep, err := provider.ResolveSTSEndpoint(tc.stsEndpoint, tc.secretsRegion)
require.NoError(t, err)
assert.Equal(t, tc.expectedURL, ep.URL)
assert.Equal(t, tc.expectedSign, ep.SigningRegion)
})
}
}