Skip to content

Commit 750c8a0

Browse files
Bug fix for OAuth m2m scopes (#178)
Updated m2m authenticator to use "all-apis" scope. Added a new constructor function for m2m authenticator that allows client to pass in additional scopes. Signed-off-by: Raymond Cypher <[email protected]>
2 parents 714e264 + 91dced9 commit 750c8a0

File tree

5 files changed

+65
-11
lines changed

5 files changed

+65
-11
lines changed

auth/oauth/m2m/m2m.go

+13-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@ import (
1717
)
1818

1919
func NewAuthenticator(clientID, clientSecret, hostName string) auth.Authenticator {
20-
scopes := oauth.GetScopes(hostName, []string{})
20+
return NewAuthenticatorWithScopes(clientID, clientSecret, hostName, []string{})
21+
}
22+
23+
func NewAuthenticatorWithScopes(clientID, clientSecret, hostName string, scopes []string) auth.Authenticator {
24+
scopes = GetScopes(hostName, scopes)
2125
return &authClient{
2226
clientID: clientID,
2327
clientSecret: clientSecret,
@@ -89,3 +93,11 @@ func GetConfig(ctx context.Context, issuerURL, clientID, clientSecret string, sc
8993

9094
return config, nil
9195
}
96+
97+
func GetScopes(hostName string, scopes []string) []string {
98+
if !oauth.HasScope(scopes, "all-apis") {
99+
scopes = append(scopes, "all-apis")
100+
}
101+
102+
return scopes
103+
}

auth/oauth/m2m/m2m_test.go

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package m2m
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestM2MScopes(t *testing.T) {
10+
t.Run("default should be [all-apis]", func(t *testing.T) {
11+
auth := NewAuthenticator("id", "secret", "staging.cloud.company.com").(*authClient)
12+
assert.Equal(t, "id", auth.clientID)
13+
assert.Equal(t, "secret", auth.clientSecret)
14+
assert.Equal(t, []string{"all-apis"}, auth.scopes)
15+
16+
auth = NewAuthenticatorWithScopes("id", "secret", "staging.cloud.company.com", nil).(*authClient)
17+
assert.Equal(t, "id", auth.clientID)
18+
assert.Equal(t, "secret", auth.clientSecret)
19+
assert.Equal(t, []string{"all-apis"}, auth.scopes)
20+
21+
auth = NewAuthenticatorWithScopes("id", "secret", "staging.cloud.company.com", []string{}).(*authClient)
22+
assert.Equal(t, "id", auth.clientID)
23+
assert.Equal(t, "secret", auth.clientSecret)
24+
assert.Equal(t, []string{"all-apis"}, auth.scopes)
25+
})
26+
27+
t.Run("should add all-apis to passed scopes", func(t *testing.T) {
28+
auth := NewAuthenticatorWithScopes("id", "secret", "staging.cloud.company.com", []string{"my-scope"}).(*authClient)
29+
assert.Equal(t, "id", auth.clientID)
30+
assert.Equal(t, "secret", auth.clientSecret)
31+
assert.Equal(t, []string{"my-scope", "all-apis"}, auth.scopes)
32+
})
33+
34+
t.Run("should not add all-apis if already in passed scopes", func(t *testing.T) {
35+
auth := NewAuthenticatorWithScopes("id", "secret", "staging.cloud.company.com", []string{"all-apis", "my-scope"}).(*authClient)
36+
assert.Equal(t, "id", auth.clientID)
37+
assert.Equal(t, "secret", auth.clientSecret)
38+
assert.Equal(t, []string{"all-apis", "my-scope"}, auth.scopes)
39+
})
40+
}

auth/oauth/oauth.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -45,27 +45,27 @@ func GetEndpoint(ctx context.Context, hostName string) (oauth2.Endpoint, error)
4545

4646
func GetScopes(hostName string, scopes []string) []string {
4747
for _, s := range []string{oidc.ScopeOfflineAccess} {
48-
if !hasScope(scopes, s) {
48+
if !HasScope(scopes, s) {
4949
scopes = append(scopes, s)
5050
}
5151
}
5252

5353
cloudType := InferCloudFromHost(hostName)
5454
if cloudType == Azure {
5555
userImpersonationScope := fmt.Sprintf("%s/user_impersonation", azureTenantId)
56-
if !hasScope(scopes, userImpersonationScope) {
56+
if !HasScope(scopes, userImpersonationScope) {
5757
scopes = append(scopes, userImpersonationScope)
5858
}
5959
} else {
60-
if !hasScope(scopes, "sql") {
60+
if !HasScope(scopes, "sql") {
6161
scopes = append(scopes, "sql")
6262
}
6363
}
6464

6565
return scopes
6666
}
6767

68-
func hasScope(scopes []string, scope string) bool {
68+
func HasScope(scopes []string, scope string) bool {
6969
for _, s := range scopes {
7070
if s == scope {
7171
return true

auth/oauth/u2m/authenticator.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ import (
2525
)
2626

2727
const (
28-
azureClientId = "96eecda7-19ea-49cc-abb5-240097d554f5"
29-
azureRedirctURL = "localhost:8030"
28+
azureClientId = "96eecda7-19ea-49cc-abb5-240097d554f5"
29+
azureRedirectURL = "localhost:8030"
3030

31-
awsClientId = "databricks-sql-connector"
32-
awsRedirctURL = "localhost:8030"
31+
awsClientId = "databricks-sql-connector"
32+
awsRedirectURL = "localhost:8030"
3333
)
3434

3535
func NewAuthenticator(hostName string, timeout time.Duration) (auth.Authenticator, error) {
@@ -39,10 +39,10 @@ func NewAuthenticator(hostName string, timeout time.Duration) (auth.Authenticato
3939
var clientID, redirectURL string
4040
if cloud == oauth.AWS {
4141
clientID = awsClientId
42-
redirectURL = awsRedirctURL
42+
redirectURL = awsRedirectURL
4343
} else if cloud == oauth.Azure {
4444
clientID = azureClientId
45-
redirectURL = azureRedirctURL
45+
redirectURL = azureRedirectURL
4646
} else {
4747
return nil, errors.New("unhandled cloud type: " + cloud.String())
4848
}

driverctx/ctx.go

+2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ func CorrelationIdFromContext(ctx context.Context) string {
3838
}
3939

4040
// NewContextWithConnId creates a new context with connectionId value.
41+
// The connection ID will be displayed in log messages and other dianostic information.
4142
func NewContextWithConnId(ctx context.Context, connId string) context.Context {
4243
if callback, ok := ctx.Value(ConnIdCallbackKey).(IdCallbackFunc); ok {
4344
callback(connId)
@@ -59,6 +60,7 @@ func ConnIdFromContext(ctx context.Context) string {
5960
}
6061

6162
// NewContextWithQueryId creates a new context with queryId value.
63+
// The query id will be displayed in log messages and other diagnostic information.
6264
func NewContextWithQueryId(ctx context.Context, queryId string) context.Context {
6365
if callback, ok := ctx.Value(QueryIdCallbackKey).(IdCallbackFunc); ok {
6466
callback(queryId)

0 commit comments

Comments
 (0)